From e536e1420d65dbe293d374d1c88c9f0715f9828d Mon Sep 17 00:00:00 2001 From: liam Date: Wed, 12 Feb 2025 11:42:55 +0800 Subject: [PATCH 1/3] :zap: update force_think --- ktransformers/server/args.py | 1 + ktransformers/server/backend/interfaces/transformers.py | 9 +++++++++ ktransformers/server/config/config.py | 1 + 3 files changed, 11 insertions(+) diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 38aa20d..660a782 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -90,6 +90,7 @@ class ArgumentParser: # user config parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) + parser.add_argument("--force_think", type=bool, default=self.cfg.force_think) # web config parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 81fa6e5..fd997b4 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -10,6 +10,7 @@ from transformers import ( BitsAndBytesConfig, ) +from ktransformers.server.config.config import Config from ktransformers.server.schemas.base import ObjectID from ktransformers.server.utils.multi_timer import Profiler import torch @@ -323,10 +324,18 @@ class TransformersInterface(BackendInterfaceBase): #input_ids = torch.tensor([[6366]], device=input_ids.device) else: raise ValueError("local_messages should be List or str") + if Config().force_think: + token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)]) + input_ids = torch.cat( + [input_ids, token_thinks], dim=1 + ) self.profiler.pause_timer("tokenize") self.profiler.create_and_start_timer("prefill") + if Config().force_think: + print("<think>\n") + yield "<think>\n" for t in 
self.prefill(input_ids, self.check_is_new(thread_id)): if t is not None: print(t, end="",flush=True) diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index 7ce616b..7dc9921 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -83,6 +83,7 @@ class Config(metaclass=Singleton): self.user_config: dict = cfg.get("user", {}) self.user_secret_key = self.user_config.get("secret_key", "") self.user_algorithm = self.user_config.get("algorithm", "") + self.user_force_think = self.user_config.get("force_think", False) # model config self.model: dict = cfg.get("model", {}) From 6f3a39be08f30fd7d41fd9aa19c9e8a9a6cec353 Mon Sep 17 00:00:00 2001 From: liam Date: Wed, 12 Feb 2025 12:10:16 +0800 Subject: [PATCH 2/3] :zap: update force_think config --- ktransformers/server/args.py | 2 +- ktransformers/server/backend/interfaces/transformers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 660a782..e90ca2f 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -90,7 +90,7 @@ class ArgumentParser: # user config parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) - parser.add_argument("--force_think", type=bool, default=self.cfg.force_think) + parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think) # web config parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index fd997b4..01a6b84 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -324,7 +324,7 @@ class TransformersInterface(BackendInterfaceBase): 
#input_ids = torch.tensor([[6366]], device=input_ids.device) else: raise ValueError("local_messages should be List or str") - if Config().force_think: + if Config().user_force_think: token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)]) input_ids = torch.cat( [input_ids, token_thinks], dim=1 @@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase): self.profiler.pause_timer("tokenize") self.profiler.create_and_start_timer("prefill") - if Config().force_think: + if Config().user_force_think: print("<think>\n") yield "<think>\n" for t in self.prefill(input_ids, self.check_is_new(thread_id)): From 4385e85096f315b8e65afdf224e3c49cd058941c Mon Sep 17 00:00:00 2001 From: liam Date: Wed, 12 Feb 2025 12:43:53 +0800 Subject: [PATCH 3/3] :zap: support force thinking --- ktransformers/local_chat.py | 2 +- ktransformers/server/args.py | 1 + ktransformers/server/backend/interfaces/transformers.py | 9 +++++---- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 4e006b6..676ea67 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -160,7 +160,7 @@ def local_chat( messages, add_generation_prompt=True, return_tensors="pt" ) if force_think: - token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)]) + token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device) input_tensor = torch.cat( [input_tensor, token_thinks], dim=1 ) diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index e90ca2f..44fe7d2 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -122,4 +122,5 @@ class ArgumentParser: self.cfg.server_ip = args.host self.cfg.server_port = args.port self.cfg.backend_type = args.type + self.cfg.user_force_think = args.force_think return args diff --git a/ktransformers/server/backend/interfaces/transformers.py 
b/ktransformers/server/backend/interfaces/transformers.py index 01a6b84..f18581a 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -325,7 +325,7 @@ class TransformersInterface(BackendInterfaceBase): else: raise ValueError("local_messages should be List or str") if Config().user_force_think: - token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)]) + token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device) input_ids = torch.cat( [input_ids, token_thinks], dim=1 ) @@ -334,8 +334,9 @@ class TransformersInterface(BackendInterfaceBase): self.profiler.create_and_start_timer("prefill") if Config().user_force_think: - print("<think>\n") - yield "<think>\n" + t = "<think>\n" + print(t,end="",flush=True) + yield t for t in self.prefill(input_ids, self.check_is_new(thread_id)): if t is not None: print(t, end="",flush=True) @@ -346,7 +347,7 @@ class TransformersInterface(BackendInterfaceBase): for t in self.generate(): if t is not None: print(t, end="",flush=True) - yield t + yield t print("") self.profiler.pause_timer("decode") self.report_last_time_performance()