diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py
index f84538a..356637c 100644
--- a/ktransformers/server/api/openai/endpoints/chat.py
+++ b/ktransformers/server/api/openai/endpoints/chat.py
@@ -28,13 +28,13 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     if create.stream:
         async def inner():
             chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id):
+            async for token in interface.inference(input_message,id,create.temperature,create.top_p,create.repetition_penalty):
                 chunk.set_token(token)
                 yield chunk
         return chat_stream_response(request,inner())
     else:
         comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
         comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id):
+        async for token in interface.inference(input_message,id,create.temperature,create.top_p,create.repetition_penalty):
             comp.append_token(token)
         return comp
diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py
index 5c4dc4e..b929c4b 100644
--- a/ktransformers/server/schemas/endpoints/chat.py
+++ b/ktransformers/server/schemas/endpoints/chat.py
@@ -25,6 +25,9 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model : str
     stream : bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
 
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
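
Below is a minimal sketch of a client request that exercises the new sampling fields. It assumes (not shown in this diff) that the server exposes the usual OpenAI-compatible route `/v1/chat/completions` on `localhost:10002`; adjust the URL and model name to your deployment.

```python
# Hypothetical client-side sketch; the endpoint URL, port, and model name
# are assumptions, not taken from this diff.
import requests

payload = {
    "model": "ktransformers",  # deployment-specific model name
    "stream": False,
    "messages": [{"role": "user", "content": "Hello!"}],
    # New optional sampling fields introduced by this change; omitting any
    # of them leaves the corresponding server-side value as None.
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}

resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())
```

Note that both call sites now pass three extra positional arguments, so each backend's `interface.inference` must accept the matching `(input_message, id, temperature, top_p, repetition_penalty)` signature; a backend still implementing the old two-argument form would raise a `TypeError`. Likewise, `Optional` must already be imported in `schemas/endpoints/chat.py` for the new field annotations to resolve.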