mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-01-23 23:13:38 +08:00
commit
4ae2e81c38
@ -160,7 +160,7 @@ def local_chat(
|
||||
messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
if force_think:
|
||||
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)])
|
||||
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
|
||||
input_tensor = torch.cat(
|
||||
[input_tensor, token_thinks], dim=1
|
||||
)
|
||||
|
||||
@ -90,6 +90,7 @@ class ArgumentParser:
|
||||
# user config
|
||||
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
|
||||
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
|
||||
parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
|
||||
|
||||
# web config
|
||||
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
|
||||
@ -121,4 +122,5 @@ class ArgumentParser:
|
||||
self.cfg.server_ip = args.host
|
||||
self.cfg.server_port = args.port
|
||||
self.cfg.backend_type = args.type
|
||||
self.cfg.user_force_think = args.force_think
|
||||
return args
|
||||
|
||||
@ -10,6 +10,7 @@ from transformers import (
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.server.schemas.base import ObjectID
|
||||
from ktransformers.server.utils.multi_timer import Profiler
|
||||
import torch
|
||||
@ -323,10 +324,19 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
#input_ids = torch.tensor([[6366]], device=input_ids.device)
|
||||
else:
|
||||
raise ValueError("local_messages should be List or str")
|
||||
if Config().user_force_think:
|
||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, token_thinks], dim=1
|
||||
)
|
||||
|
||||
self.profiler.pause_timer("tokenize")
|
||||
|
||||
self.profiler.create_and_start_timer("prefill")
|
||||
if Config().user_force_think:
|
||||
t = "<think>\n"
|
||||
print(t,end="",flush=True)
|
||||
yield t
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
@ -337,7 +347,7 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
for t in self.generate():
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
yield t
|
||||
yield t
|
||||
print("")
|
||||
self.profiler.pause_timer("decode")
|
||||
self.report_last_time_performance()
|
||||
|
||||
@ -83,6 +83,7 @@ class Config(metaclass=Singleton):
|
||||
self.user_config: dict = cfg.get("user", {})
|
||||
self.user_secret_key = self.user_config.get("secret_key", "")
|
||||
self.user_algorithm = self.user_config.get("algorithm", "")
|
||||
self.user_force_think = self.user_config.get("force_think", False)
|
||||
|
||||
# model config
|
||||
self.model: dict = cfg.get("model", {})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user