From 68e7df3a251650995384740a62b678310c7c73c2 Mon Sep 17 00:00:00 2001
From: swu-hyk
Date: Wed, 26 Feb 2025 17:05:00 +0800
Subject: [PATCH 1/2] implementation of chat routing for Ollama

---
 .../server/api/ollama/completions.py | 109 +++++++++++-------
 1 file changed, 69 insertions(+), 40 deletions(-)

diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py
index e3a1a51..d0ac17e 100644
--- a/ktransformers/server/api/ollama/completions.py
+++ b/ktransformers/server/api/ollama/completions.py
@@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
-router = APIRouter(prefix='/api')
 
+router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
 class OllamaGenerateCompletionRequest(BaseModel):
@@ -40,61 +40,95 @@ class OllamaGenerateCompletionRequest(BaseModel):
     keep_alive: Optional[str] = Field(
         "5m", description="Controls how long the model will stay loaded into memory following the request.")
 
-
 class OllamaGenerationStreamResponse(BaseModel):
     model: str
     created_at: str
     response: str
     done: bool = Field(...)
 
-
 class OllamaGenerationResponse(BaseModel):
     pass
 
-
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     id = str(uuid4())
-
     interface: BackendInterfaceBase = get_interface()
     print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
-
     config = Config()
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt,id):
-                d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
-                yield d.model_dump_json()+'\n'
-                # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-                # yield f"{json.dumps(d)}\n"
-            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-            # yield f"{json.dumps(d)}\n"
-            d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
-            yield d.model_dump_json()+'\n'
-        return check_link_response(request,inner())
+            async for token in interface.inference(input.prompt, id):
+                d = OllamaGenerationStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    response=token,
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
+            d = OllamaGenerationStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                response='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
     else:
         raise NotImplementedError
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
-
+class OllamaChatCompletionMessage(BaseModel):
+    role: str
+    content: str
 
 class OllamaChatCompletionRequest(BaseModel):
-    pass
-
+    model: str = Field(..., description="The model name, which is required.")
+    messages: List[OllamaChatCompletionMessage] = Field(
+        ..., description="A list of messages to generate a response for.")
+    stream: bool = Field(True, description="If true, the response will be streamed.")
 
 class OllamaChatCompletionStreamResponse(BaseModel):
-    pass
-
+    model: str
+    created_at: str
+    message: str
+    done: bool = Field(...)
 
 class OllamaChatCompletionResponse(BaseModel):
     pass
 
-
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
-    raise NotImplementedError
+    id = str(uuid4())
+    interface: BackendInterfaceBase = get_interface()
+    config = Config()
+
+    # Convert the chat messages into a single prompt string
+    prompt = ""
+    for msg in input.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant:"
+
+    if input.stream:
+        async def inner():
+            async for token in interface.inference(prompt, id):
+                d = OllamaChatCompletionStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    message=token,
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
+            d = OllamaChatCompletionStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                message='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
+    else:
+        raise NotImplementedError("Non-streaming chat is not implemented.")
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
@@ -103,9 +137,8 @@ class OllamaModel(BaseModel):
     size: int
     # TODO: fill the rest correctly
 
-
 # mock ollama
-@router.get("/tags",tags=['ollama'])
+@router.get("/tags", tags=['ollama'])
 async def tags():
     config = Config()
     # TODO: fill this correctly, although it does not effect Tabby
@@ -138,25 +171,21 @@ class OllamaShowResponse(BaseModel):
     class Config:
         protected_namespaces = ()
 
-
-
 @router.post("/show", tags=['ollama'])
 async def show(request: Request, input: OllamaShowRequest):
     config = Config()
     # TODO: Add more info in config to return, although it does not effect Tabby
     return OllamaShowResponse(
-        modelfile = "# Modelfile generated by ...",
-        parameters = " ",
-        template = " ",
-        details = OllamaShowDetial(
-            parent_model = " ",
-            format = "gguf",
-            family = " ",
-            families = [
-                " "
-            ],
-            parameter_size = " ",
-            quantization_level = " "
+        modelfile="# Modelfile generated by ...",
+        parameters=" ",
+        template=" ",
+        details=OllamaShowDetial(
+            parent_model=" ",
+            format="gguf",
+            family=" ",
+            families=[" "],
+            parameter_size=" ",
+            quantization_level=" "
         ),
-        model_info = OllamaModelInfo()
+        model_info=OllamaModelInfo()
     )
\ No newline at end of file

From ec7e912feed51db8c247e96ea582a9427966134a Mon Sep 17 00:00:00 2001
From: swu-hyk
Date: Wed, 26 Feb 2025 19:21:30 +0800
Subject: [PATCH 2/2] modify

---
 .../server/api/ollama/completions.py | 34 ++++++++++++++++---
 1 file changed, 30 insertions(+), 4 deletions(-)

diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py
index d0ac17e..0ff6183 100644
--- a/ktransformers/server/api/ollama/completions.py
+++ b/ktransformers/server/api/ollama/completions.py
@@ -91,8 +91,16 @@ class OllamaChatCompletionRequest(BaseModel):
 class OllamaChatCompletionStreamResponse(BaseModel):
     model: str
     created_at: str
-    message: str
+    message: dict
     done: bool = Field(...)
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
+
 
 class OllamaChatCompletionResponse(BaseModel):
     pass
 
@@ -111,19 +119,37 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
 
     if input.stream:
         async def inner():
+            start_time = time()  # record the request start time in seconds
+            eval_count = 0       # number of generated tokens
+
             async for token in interface.inference(prompt, id):
+                eval_count += 1  # count each generated token
                 d = OllamaChatCompletionStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),
-                    message=token,
+                    message={"role": "assistant", "content": token},
                     done=False
                 )
                 yield d.model_dump_json() + '\n'
+            # Compute the performance statistics reported in the final chunk
+            end_time = time()
+            total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
+            prompt_eval_count = len(prompt.split())  # rough estimate of the number of prompt tokens
+            eval_duration = total_duration  # assume all time was spent on generation (simplification)
+            prompt_eval_duration = 0  # assume no separate prompt-evaluation time
+            load_duration = 0  # load time is not tracked here
+
             d = OllamaChatCompletionStreamResponse(
                 model=config.model_name,
                 created_at=str(datetime.now()),
-                message='',
-                done=True
+                message={},
+                done=True,
+                total_duration=total_duration,
+                load_duration=load_duration,
+                prompt_eval_count=prompt_eval_count,
+                prompt_eval_duration=prompt_eval_duration,
+                eval_count=eval_count,
+                eval_duration=eval_duration
            )
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
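
A minimal client sketch for exercising the /api/chat route added by this series. The host, port, and model name below are placeholders (adjust them to your server configuration); the parsing simply mirrors the OllamaChatCompletionStreamResponse fields defined in the patch.

# Streaming client sketch for the /api/chat route (placeholder URL and model name).
import json

import requests

payload = {
    "model": "my-model",  # placeholder: the model served by your instance
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with requests.post("http://localhost:10002/api/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # one OllamaChatCompletionStreamResponse per line
        if chunk["done"]:
            # the final chunk carries the statistics added in PATCH 2/2
            print("\neval_count:", chunk.get("eval_count"),
                  "total_duration(ns):", chunk.get("total_duration"))
        else:
            print(chunk["message"].get("content", ""), end="", flush=True)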