diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 89ad80cc..c176e8c3 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -151,7 +151,23 @@ MODEL_PLATFORMS = [ "image_models": [], "multimodal_models": [], }, - + { + "platform_name": "ollama", + "platform_type": "ollama", + "api_base_url": "http://{host}:{port}/v1", + "api_key": "sk-", + "api_concurrencies": 5, + "llm_models": [ + # Qwen API,其它更多模型请参考https://ollama.com/library + "qwen:7b", + ], + "embed_models": [ + # 必须升级ollama到0.1.29以上,低版本向量服务有问题 + "nomic-embed-text" + ], + "image_models": [], + "multimodal_models": [], + }, # { # "platform_name": "loom", # "platform_type": "loom", diff --git a/server/api_server/api_schemas.py b/server/api_server/api_schemas.py index 8400b8ba..9423bd62 100644 --- a/server/api_server/api_schemas.py +++ b/server/api_server/api_schemas.py @@ -24,6 +24,9 @@ class OpenAIBaseInput(BaseModel): extra_body: Optional[Dict] = None timeout: Optional[float] = None + class Config: + extra = "allow" + class OpenAIChatInput(OpenAIBaseInput): messages: List[ChatCompletionMessageParam] diff --git a/server/api_server/kb_routes.py b/server/api_server/kb_routes.py index c482bba9..e08db0c5 100644 --- a/server/api_server/kb_routes.py +++ b/server/api_server/kb_routes.py @@ -11,7 +11,6 @@ from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_do search_docs, update_info) from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store, summary_doc_ids_to_vector_store) -from server.knowledge_base.model.kb_document_model import DocumentWithVSId from server.utils import BaseResponse, ListResponse @@ -38,7 +37,7 @@ kb_router.get("/list_files", )(list_files) kb_router.post("/search_docs", - response_model=List[DocumentWithVSId], + response_model=List[dict], summary="搜索知识库" )(search_docs) diff --git a/server/api_server/openai_routes.py b/server/api_server/openai_routes.py index 64638a7d..6ef61988 100644 --- a/server/api_server/openai_routes.py +++ b/server/api_server/openai_routes.py @@ -98,7 +98,12 @@ async def create_chat_completions( body: OpenAIChatInput, ): async with acquire_model_client(body.model) as client: - return await openai_request(client.chat.completions.create, body) + result = await openai_request(client.chat.completions.create, body) + # result["related_docs"] = ["doc1"] + # result["choices"][0]["message"]["related_docs"] = ["doc1"] + # print(result) + # breakpoint() + return result @openai_router.post("/completions") @@ -115,9 +120,9 @@ async def create_embeddings( request: Request, body: OpenAIEmbeddingsInput, ): - params = body.dict(exclude_unset=True) + params = body.model_dump(exclude_unset=True) client = get_OpenAIClient(model_name=body.model) - return (await client.embeddings.create(**params)).dict() + return (await client.embeddings.create(**params)).model_dump() @openai_router.post("/images/generations") diff --git a/server/chat/utils.py b/server/chat/utils.py index 89a7762b..d51004ab 100644 --- a/server/chat/utils.py +++ b/server/chat/utils.py @@ -1,5 +1,5 @@ from functools import lru_cache -from server.pydantic_v1 import BaseModel, Field +from server.pydantic_v2 import BaseModel, Field from langchain.prompts.chat import ChatMessagePromptTemplate from configs import logger, log_verbose from typing import List, Tuple, Dict, Union diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 8b200254..683fb041 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -29,7 +29,7 @@ def search_docs( ge=0, le=1), file_name: str = Body("", description="文件名称,支持 sql 通配符"), metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), -) -> List[DocumentWithVSId]: +) -> List[Dict]: kb = KBServiceFactory.get_service_by_name(knowledge_base_name) data = [] if kb is not None: @@ -41,7 +41,7 @@ def search_docs( for d in data: if "vector" in d.metadata: del d.metadata["vector"] - return data + return [x.dict() for x in data] def list_files( diff --git a/server/utils.py b/server/utils.py index d4490d0c..5f3a64da 100644 --- a/server/utils.py +++ b/server/utils.py @@ -202,6 +202,7 @@ def get_Embeddings( local_wrap: bool = False, # use local wrapped api ) -> Embeddings: from langchain_community.embeddings.openai import OpenAIEmbeddings + from langchain_community.embeddings import OllamaEmbeddings from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154 model_info = get_model_info(model_name=embed_model) @@ -220,6 +221,10 @@ def get_Embeddings( ) if model_info.get("platform_type") == "openai": return OpenAIEmbeddings(**params) + elif model_info.get("platform_type") == "ollama": + return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''), + model=embed_model, + ) else: return LocalAIEmbeddings(**params) except Exception as e: