From 9b62b1c72b37eadc7e347518b6ea0243df43f4e3 Mon Sep 17 00:00:00 2001
From: srszzw <741992282@qq.com>
Date: Mon, 25 Mar 2024 16:35:45 +0800
Subject: [PATCH] dev branch: fix the pydantic version conflict, add ollama
 config, and support ollama chat and embedding endpoints (#3508)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* dev branch: fix the pydantic version conflict, add ollama config, and
  support ollama chat and embedding endpoints

1. The dev branch upgraded pydantic to v2, but class History(BaseModel)
   still imported from server.pydantic_v1 while fastapi's imports had
   already moved to pydantic v2. fastapi therefore used v2 machinery to
   validate an object defined with v1, and whenever the chat history was
   non-empty it raised: TypeError: BaseModel.validate() takes 2 positional
   arguments but 3 were given. Testing confirmed that the fix is simply to
   define class History(BaseModel) against v2 as well.
2. Following the existing platform entries, the config file gains an
   ollama platform section. Chat models can be added to it as needed; for
   embeddings, nomic-embed-text is currently supported (ollama must be
   upgraded to 0.1.29 or later).
3. Ollama's official OpenAI API compatibility covers only chat; its
   embeddings API has not been adapted yet. Fortunately the official
   langchain library provides OllamaEmbeddings, so the corresponding
   support code was added to the get_Embeddings method.

* Fix DocumentWithVSId and /v1/embeddings compatibility issues after the
  pydantic upgrade to v2

---------

Co-authored-by: srszzw
Co-authored-by: liunux4odoo
---
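Note for reviewers (placed after the patch separator, so git am ignores
it): a minimal standalone sketch of what the new ollama branch in
get_Embeddings does; the host and port below are placeholders for whatever
MODEL_PLATFORMS configures.

    from langchain_community.embeddings import OllamaEmbeddings

    # The configured api_base_url points at Ollama's OpenAI-compatible
    # prefix, but OllamaEmbeddings speaks Ollama's native REST API at the
    # server root, so the trailing "/v1" has to be stripped off first.
    api_base_url = "http://127.0.0.1:11434/v1"
    embeddings = OllamaEmbeddings(
        base_url=api_base_url.replace("/v1", ""),  # -> "http://127.0.0.1:11434"
        model="nomic-embed-text",                  # needs ollama >= 0.1.29
    )

    vector = embeddings.embed_query("hello world")  # one embedding vector
    print(len(vector))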
 configs/model_config.py.example     | 18 +++++++++++++++++-
 server/api_server/api_schemas.py    |  3 +++
 server/api_server/kb_routes.py      |  3 +--
 server/api_server/openai_routes.py  | 11 ++++++++---
 server/chat/utils.py                |  2 +-
 server/knowledge_base/kb_doc_api.py |  4 ++--
 server/utils.py                     |  5 +++++
 7 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 89ad80cc..c176e8c3 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -151,7 +151,23 @@ MODEL_PLATFORMS = [
         "image_models": [],
         "multimodal_models": [],
     },
-
+    {
+        "platform_name": "ollama",
+        "platform_type": "ollama",
+        "api_base_url": "http://{host}:{port}/v1",
+        "api_key": "sk-",
+        "api_concurrencies": 5,
+        "llm_models": [
+            # Qwen API; for more models see https://ollama.com/library
+            "qwen:7b",
+        ],
+        "embed_models": [
+            # requires ollama 0.1.29 or later; the embedding service is broken in older versions
+            "nomic-embed-text"
+        ],
+        "image_models": [],
+        "multimodal_models": [],
+    },
     # {
     #     "platform_name": "loom",
     #     "platform_type": "loom",
diff --git a/server/api_server/api_schemas.py b/server/api_server/api_schemas.py
index 8400b8ba..9423bd62 100644
--- a/server/api_server/api_schemas.py
+++ b/server/api_server/api_schemas.py
@@ -24,6 +24,9 @@ class OpenAIBaseInput(BaseModel):
     extra_body: Optional[Dict] = None
     timeout: Optional[float] = None
 
+    class Config:
+        extra = "allow"
+
 
 class OpenAIChatInput(OpenAIBaseInput):
     messages: List[ChatCompletionMessageParam]
diff --git a/server/api_server/kb_routes.py b/server/api_server/kb_routes.py
index c482bba9..e08db0c5 100644
--- a/server/api_server/kb_routes.py
+++ b/server/api_server/kb_routes.py
@@ -11,7 +11,6 @@ from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_do
                                               search_docs, update_info)
 from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
                                                   summary_doc_ids_to_vector_store)
-from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 from server.utils import BaseResponse, ListResponse
 
 
@@ -38,7 +37,7 @@ kb_router.get("/list_files",
               )(list_files)
 
 kb_router.post("/search_docs",
-               response_model=List[DocumentWithVSId],
+               response_model=List[dict],
                summary="搜索知识库"
               )(search_docs)
diff --git a/server/api_server/openai_routes.py b/server/api_server/openai_routes.py
index 64638a7d..6ef61988 100644
--- a/server/api_server/openai_routes.py
+++ b/server/api_server/openai_routes.py
@@ -98,7 +98,12 @@ async def create_chat_completions(
     body: OpenAIChatInput,
 ):
     async with acquire_model_client(body.model) as client:
-        return await openai_request(client.chat.completions.create, body)
+        result = await openai_request(client.chat.completions.create, body)
+        # result["related_docs"] = ["doc1"]
+        # result["choices"][0]["message"]["related_docs"] = ["doc1"]
+        # print(result)
+        # breakpoint()
+        return result
 
 
 @openai_router.post("/completions")
@@ -115,9 +120,9 @@ async def create_embeddings(
     request: Request,
     body: OpenAIEmbeddingsInput,
 ):
-    params = body.dict(exclude_unset=True)
+    params = body.model_dump(exclude_unset=True)
     client = get_OpenAIClient(model_name=body.model)
-    return (await client.embeddings.create(**params)).dict()
+    return (await client.embeddings.create(**params)).model_dump()
 
 
 @openai_router.post("/images/generations")
diff --git a/server/chat/utils.py b/server/chat/utils.py
index 89a7762b..d51004ab 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from server.pydantic_v1 import BaseModel, Field
+from server.pydantic_v2 import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
 from typing import List, Tuple, Dict, Union
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 8b200254..683fb041 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -29,7 +29,7 @@ def search_docs(
                                        ge=0, le=1),
         file_name: str = Body("", description="文件名称,支持 sql 通配符"),
         metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
-) -> List[DocumentWithVSId]:
+) -> List[Dict]:
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     data = []
     if kb is not None:
@@ -41,7 +41,7 @@ def search_docs(
         for d in data:
             if "vector" in d.metadata:
                 del d.metadata["vector"]
-    return data
+    return [x.dict() for x in data]
 
 
 def list_files(
diff --git a/server/utils.py b/server/utils.py
index d4490d0c..5f3a64da 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -202,6 +202,7 @@ def get_Embeddings(
     local_wrap: bool = False,  # use local wrapped api
 ) -> Embeddings:
     from langchain_community.embeddings.openai import OpenAIEmbeddings
+    from langchain_community.embeddings import OllamaEmbeddings
     from server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154
 
     model_info = get_model_info(model_name=embed_model)
@@ -220,6 +221,10 @@
         )
         if model_info.get("platform_type") == "openai":
             return OpenAIEmbeddings(**params)
+        elif model_info.get("platform_type") == "ollama":
+            return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''),
+                                    model=embed_model,
+                                    )
         else:
             return LocalAIEmbeddings(**params)
     except Exception as e:
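
Reviewer note: the api_schemas.py and openai_routes.py changes lean on two
pydantic v2 idioms. A standalone sketch follows (the Point model is made up
for illustration; OpenAIBaseInput behaves the same way):

    from pydantic import BaseModel

    class Point(BaseModel):
        x: int = 0
        y: int = 0

        class Config:        # v1-style Config is still honored by pydantic v2
            extra = "allow"  # keep fields not declared on the model

    p = Point(x=1, z=5)      # "z" survives because extra = "allow"
    # model_dump() is the v2 replacement for the deprecated .dict();
    # exclude_unset drops y (never set) but keeps x and the extra field z.
    print(p.model_dump(exclude_unset=True))  # expected: {'x': 1, 'z': 5}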