mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-01-19 12:43:16 +08:00
Merge pull request #644 from wtdcode/temperature_top_p_from_request
Allow temperature and top_p from /v1/chat/completions
This commit is contained in:
commit
5e3c6b4f97
@ -31,13 +31,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
|
|||||||
if create.stream:
|
if create.stream:
|
||||||
async def inner():
|
async def inner():
|
||||||
chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
|
chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
|
||||||
async for token in interface.inference(input_message,id):
|
async for token in interface.inference(input_message,id,create.temperature,create.top_p):
|
||||||
chunk.set_token(token)
|
chunk.set_token(token)
|
||||||
yield chunk
|
yield chunk
|
||||||
return chat_stream_response(request,inner())
|
return chat_stream_response(request,inner())
|
||||||
else:
|
else:
|
||||||
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
|
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
|
||||||
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
|
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
|
||||||
async for token in interface.inference(input_message,id):
|
async for token in interface.inference(input_message,id,create.temperature,create.top_p):
|
||||||
comp.append_token(token)
|
comp.append_token(token)
|
||||||
return comp
|
return comp
|
||||||
|
|||||||
@ -20,7 +20,7 @@ async def create_completion(request:Request,create:CompletionCreate):
|
|||||||
|
|
||||||
if create.stream:
|
if create.stream:
|
||||||
async def inner():
|
async def inner():
|
||||||
async for token in interface.inference(create.prompt,id):
|
async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
|
||||||
d = {'choices':[{'delta':{'content':token}}]}
|
d = {'choices':[{'delta':{'content':token}}]}
|
||||||
yield f"data:{json.dumps(d)}\n\n"
|
yield f"data:{json.dumps(d)}\n\n"
|
||||||
d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
|
d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
|
||||||
@ -28,6 +28,6 @@ async def create_completion(request:Request,create:CompletionCreate):
|
|||||||
return stream_response(request,inner())
|
return stream_response(request,inner())
|
||||||
else:
|
else:
|
||||||
comp = CompletionObject(id=id,object='text_completion',created=int(time()))
|
comp = CompletionObject(id=id,object='text_completion',created=int(time()))
|
||||||
async for token in interface.inference(create.prompt,id):
|
async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
|
||||||
comp.append_token(token)
|
comp.append_token(token)
|
||||||
return comp
|
return comp
|
||||||
|
|||||||
@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
|
|||||||
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
||||||
from ktransformers.local_chat import custom_models, default_optimize_rules
|
from ktransformers.local_chat import custom_models, default_optimize_rules
|
||||||
from ktransformers.util.utils import get_device
|
from ktransformers.util.utils import get_device
|
||||||
|
from typing import Optional
|
||||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
|
||||||
|
|
||||||
|
|
||||||
warm_uped = False
|
warm_uped = False
|
||||||
|
|
||||||
class KTransformersThreadContext(TransformersThreadContext):
|
class KTransformersThreadContext(TransformersThreadContext):
|
||||||
@ -128,7 +128,7 @@ class KTransformersInterface(TransformersInterface):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
|
||||||
input_ids_length = input_ids.shape[-1]
|
input_ids_length = input_ids.shape[-1]
|
||||||
logger.debug(f"input_ids: {input_ids.shape}")
|
logger.debug(f"input_ids: {input_ids.shape}")
|
||||||
|
|
||||||
@ -203,7 +203,7 @@ class KTransformersInterface(TransformersInterface):
|
|||||||
|
|
||||||
if flashinfer_enabled:
|
if flashinfer_enabled:
|
||||||
MLAWrapperSingleton.reset_buffer()
|
MLAWrapperSingleton.reset_buffer()
|
||||||
self.prepare_logits_wrapper(input_ids, device)
|
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||||
next_token = self.logits_to_token(logits[0, -1, :])
|
next_token = self.logits_to_token(logits[0, -1, :])
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ class KTransformersInterface(TransformersInterface):
|
|||||||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||||
return torch.tensor([self.seq_length - 1], device=device)
|
return torch.tensor([self.seq_length - 1], device=device)
|
||||||
|
|
||||||
async def inference(self, local_messages, thread_id: str):
|
async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
|
||||||
async with self._infer_lock:
|
async with self._infer_lock:
|
||||||
async for v in super().inference(local_messages, thread_id):
|
async for v in super().inference(local_messages, thread_id, temperature, top_p):
|
||||||
yield v
|
yield v
|
||||||
|
|||||||
@ -202,13 +202,17 @@ class TransformersInterface(BackendInterfaceBase):
|
|||||||
self.seq_length += 1
|
self.seq_length += 1
|
||||||
return self.streamer.put(new_tokens)
|
return self.streamer.put(new_tokens)
|
||||||
|
|
||||||
def prepare_logits_wrapper(self, inputs, device):
|
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||||
|
if temperature is None:
|
||||||
|
temperature = self.args.temperature
|
||||||
|
if top_p is None:
|
||||||
|
top_p = self.args.top_p
|
||||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||||
None, max_length=self.args.max_new_tokens,
|
None, max_length=self.args.max_new_tokens,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
top_k=self.args.top_k,
|
top_k=self.args.top_k,
|
||||||
top_p=self.args.top_p,
|
top_p=top_p,
|
||||||
temperature=self.args.temperature,
|
temperature=temperature,
|
||||||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||||
)
|
)
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
@ -255,7 +259,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||||||
return self.logits_to_token(logits)
|
return self.logits_to_token(logits)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||||
input_ids_length = input_ids.shape[-1]
|
input_ids_length = input_ids.shape[-1]
|
||||||
logger.debug(f"input_ids: {input_ids.shape}")
|
logger.debug(f"input_ids: {input_ids.shape}")
|
||||||
|
|
||||||
@ -323,7 +327,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||||||
else:
|
else:
|
||||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||||
|
|
||||||
self.prepare_logits_wrapper(input_ids, device)
|
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||||
next_token = self.logits_to_token(logits[0, -1, :])
|
next_token = self.logits_to_token(logits[0, -1, :])
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
|
|
||||||
@ -359,7 +363,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||||||
self.last_request_id = thread_id
|
self.last_request_id = thread_id
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def inference(self, local_messages, thread_id: str):
|
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||||
self.streamer.reset()
|
self.streamer.reset()
|
||||||
self.profiler.create_and_start_timer("tokenize")
|
self.profiler.create_and_start_timer("tokenize")
|
||||||
if isinstance(local_messages, List):
|
if isinstance(local_messages, List):
|
||||||
@ -386,7 +390,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||||||
print(think, end="",flush=True)
|
print(think, end="",flush=True)
|
||||||
yield think
|
yield think
|
||||||
|
|
||||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
|
||||||
# output think token after prefill done
|
# output think token after prefill done
|
||||||
if t is not None:
|
if t is not None:
|
||||||
print(t, end="",flush=True)
|
print(t, end="",flush=True)
|
||||||
|
|||||||
@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
|
|||||||
messages: List[Message]
|
messages: List[Message]
|
||||||
model : str
|
model : str
|
||||||
stream : bool = False
|
stream : bool = False
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
|
||||||
def get_tokenizer_messages(self):
|
def get_tokenizer_messages(self):
|
||||||
return [m.to_tokenizer_message() for m in self.messages]
|
return [m.to_tokenizer_message() for m in self.messages]
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
prompt: str | List[str]
|
prompt: str | List[str]
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
|
||||||
def get_tokenizer_messages(self):
|
def get_tokenizer_messages(self):
|
||||||
if isinstance(self.prompt,List):
|
if isinstance(self.prompt,List):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user