Resolve the pydantic version conflict on the dev branch, add ollama configuration, support ollama chat and embedding APIs (#3508)

* Resolve the pydantic version conflict on the dev branch, add ollama configuration, and support ollama chat and embedding APIs
1. The dev branch upgraded pydantic to v2, but class History(BaseModel) still imported from server.pydantic_v1 while FastAPI already references pydantic v2. FastAPI therefore validates a v1-defined object with v2 logic, and whenever the chat history is non-empty it raises: TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given. Testing shows the fix is simply to define class History(BaseModel) against v2 as well (see the sketch after this list).
2. Following the other platform entries, an ollama platform section was added to the configuration file. Chat models can be added by users as needed; for embedding models, nomic-embed-text is currently supported (ollama must be upgraded to 0.1.29 or later).
3. Ollama officially provides OpenAI API compatibility only for the chat endpoints; the embeddings API has not been adapted yet. Fortunately the official langchain library supports OllamaEmbeddings, so supporting code was added to the get_Embeddings method.
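A minimal sketch of the fix described in point 1: defining History against pydantic v2 lets FastAPI (which now validates with pydantic v2) accept it without the cross-version TypeError. The role/content fields below are illustrative assumptions, not the project's exact schema.

```python
# Sketch only: History defined with pydantic v2, matching the version FastAPI uses.
# Field names here are assumptions for illustration.
from pydantic import BaseModel, Field  # pydantic v2

class History(BaseModel):
    role: str = Field(..., description="one of: user / assistant / system")
    content: str = Field(..., description="message text")

# FastAPI can now validate a request body containing History items without hitting
# "BaseModel.validate() takes 2 positional arguments but 3 were given".
print(History(role="user", content="hello").model_dump())
```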

* Fix DocumentWithVSId and /v1/embeddings compatibility issues after the pydantic v2 upgrade

---------

Co-authored-by: srszzw <srszzw@163.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
srszzw 2024-03-25 16:35:45 +08:00 committed by GitHub
parent 51691ee008
commit 9b62b1c72b
7 changed files with 37 additions and 9 deletions

View File

@@ -151,7 +151,23 @@ MODEL_PLATFORMS = [
"image_models": [],
"multimodal_models": [],
},
{
"platform_name": "ollama",
"platform_type": "ollama",
"api_base_url": "http://{host}:{port}/v1",
"api_key": "sk-",
"api_concurrencies": 5,
"llm_models": [
# Qwen API; for more models see https://ollama.com/library
"qwen:7b",
],
"embed_models": [
# Requires ollama 0.1.29 or later; the embedding service is broken in older versions
"nomic-embed-text"
],
"image_models": [],
"multimodal_models": [],
},
# {
# "platform_name": "loom",
# "platform_type": "loom",

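With this platform entry, chat requests are routed to Ollama's OpenAI-compatible endpoint. As a rough usage sketch (assuming a local Ollama on its default port 11434 and `ollama pull qwen:7b` already done), the same api_base_url pattern can be exercised directly with the openai client:

```python
# Sketch, not project code: talk to Ollama's OpenAI-compatible chat API directly.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:11434/v1", api_key="sk-")  # Ollama ignores the key, but it must be non-empty
resp = client.chat.completions.create(
    model="qwen:7b",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp.choices[0].message.content)
```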
View File

@@ -24,6 +24,9 @@ class OpenAIBaseInput(BaseModel):
extra_body: Optional[Dict] = None
timeout: Optional[float] = None
class Config:
extra = "allow"
class OpenAIChatInput(OpenAIBaseInput):
messages: List[ChatCompletionMessageParam]

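The added Config with extra = "allow" lets callers pass OpenAI parameters that are not explicitly declared on the input model without triggering a validation error. A small self-contained illustration (not project code):

```python
# Illustration of extra = "allow": undeclared fields are kept instead of rejected.
from pydantic import BaseModel

class Demo(BaseModel):
    model: str

    class Config:
        extra = "allow"

d = Demo(model="qwen:7b", temperature=0.2)  # temperature is not declared on Demo
print(d.model_dump())                       # {'model': 'qwen:7b', 'temperature': 0.2}
```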
View File

@@ -11,7 +11,6 @@ from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_do
search_docs, update_info)
from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
summary_doc_ids_to_vector_store)
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import BaseResponse, ListResponse
@@ -38,7 +37,7 @@ kb_router.get("/list_files",
)(list_files)
kb_router.post("/search_docs",
response_model=List[DocumentWithVSId],
response_model=List[dict],
summary="搜索知识库"
)(search_docs)

View File

@@ -98,7 +98,12 @@ async def create_chat_completions(
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
return await openai_request(client.chat.completions.create, body)
result = await openai_request(client.chat.completions.create, body)
# result["related_docs"] = ["doc1"]
# result["choices"][0]["message"]["related_docs"] = ["doc1"]
# print(result)
# breakpoint()
return result
@openai_router.post("/completions")
@@ -115,9 +120,9 @@ async def create_embeddings(
request: Request,
body: OpenAIEmbeddingsInput,
):
params = body.dict(exclude_unset=True)
params = body.model_dump(exclude_unset=True)
client = get_OpenAIClient(model_name=body.model)
return (await client.embeddings.create(**params)).dict()
return (await client.embeddings.create(**params)).model_dump()
@openai_router.post("/images/generations")

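The hunk above is part of the pydantic v1 to v2 API migration: .dict() becomes .model_dump(), and the embeddings response is dumped the same way. A rough mapping of the renamed methods used here, shown with a hypothetical stand-in model:

```python
# v1 -> v2 renames relevant to this hunk:
#   obj.dict(...)  -> obj.model_dump(...)
#   obj.json(...)  -> obj.model_dump_json(...)
from typing import List
from pydantic import BaseModel

class EmbeddingsInput(BaseModel):   # hypothetical stand-in for OpenAIEmbeddingsInput
    model: str
    input: List[str]

body = EmbeddingsInput(model="nomic-embed-text", input=["hello"])
params = body.model_dump(exclude_unset=True)   # replaces body.dict(exclude_unset=True)
print(params)
```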
View File

@@ -1,5 +1,5 @@
from functools import lru_cache
from server.pydantic_v1 import BaseModel, Field
from server.pydantic_v2 import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
from typing import List, Tuple, Dict, Union

View File

@@ -29,7 +29,7 @@ def search_docs(
ge=0, le=1),
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[DocumentWithVSId]:
) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
data = []
if kb is not None:
@@ -41,7 +41,7 @@ def search_docs(
for d in data:
if "vector" in d.metadata:
del d.metadata["vector"]
return data
return [x.dict() for x in data]
def list_files(

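Returning plain dicts here sidesteps the version mismatch: DocumentWithVSId is still a pydantic v1 model, so letting FastAPI (pydantic v2) serialize it as a response_model would fail. A minimal self-contained illustration of the pattern, using an assumed stand-in class:

```python
# Sketch of the pattern above: strip the raw vector, then hand FastAPI plain dicts
# so no pydantic-v1 model ever reaches pydantic-v2 validation.
from pydantic.v1 import BaseModel  # pydantic v2 ships a v1 compatibility namespace

class DocWithVSId(BaseModel):      # hypothetical stand-in for DocumentWithVSId
    id: str
    page_content: str
    metadata: dict = {}

data = [DocWithVSId(id="1", page_content="hello", metadata={"vector": [0.1, 0.2]})]
for d in data:
    if "vector" in d.metadata:
        del d.metadata["vector"]
print([x.dict() for x in data])    # plain, JSON-serializable dicts
```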
View File

@@ -202,6 +202,7 @@ def get_Embeddings(
local_wrap: bool = False, # use local wrapped api
) -> Embeddings:
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
model_info = get_model_info(model_name=embed_model)
@@ -220,6 +221,10 @@ def get_Embeddings(
)
if model_info.get("platform_type") == "openai":
return OpenAIEmbeddings(**params)
elif model_info.get("platform_type") == "ollama":
return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''),
model=embed_model,
)
else:
return LocalAIEmbeddings(**params)
except Exception as e:
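
A quick usage sketch of the new ollama branch (assuming Ollama 0.1.29+ is running locally and the model was pulled with `ollama pull nomic-embed-text`); note the base URL carries no /v1 suffix, which is why the code above strips it:

```python
# Sketch only: call the langchain OllamaEmbeddings wrapper directly.
from langchain_community.embeddings import OllamaEmbeddings

embeddings = OllamaEmbeddings(
    base_url="http://127.0.0.1:11434",  # native Ollama API, no /v1 suffix
    model="nomic-embed-text",
)
vec = embeddings.embed_query("hello world")
print(len(vec))  # nomic-embed-text should yield a 768-dimensional vector
```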