mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
dev分支解决pydantic版本冲突问题,增加ollama配置,支持ollama会话和向量接口 (#3508)
* dev分支解决pydantic版本冲突问题,增加ollama配置,支持ollama会话和向量接口 1、因dev版本的pydantic升级到了v2版本,由于在class History(BaseModel)中使用了from server.pydantic_v1,而fastapi的引用已变为pydantic的v2版本,所以fastapi用v2版本去校验用v1版本定义的对象,当会话历史histtory不为空的时候,会报错:TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given。经测试,解方法为在class History(BaseModel)中也使用v2版本即可; 2、配置文件参照其它平台配置,增加了ollama平台相关配置,会话模型用户可根据实际情况自行添加,向量模型目前支持nomic-embed-text(必须升级ollama到0.1.29以上)。 3、因ollama官方只在会话部分对openai api做了兼容,向量api暂未适配,好在langchain官方库支持OllamaEmbeddings,因而在get_Embeddings方法中添加了相关支持代码。 * 修复 pydantic 升级到 v2 后 DocumentWithVsID 和 /v1/embeddings 兼容性问题 --------- Co-authored-by: srszzw <srszzw@163.com> Co-authored-by: liunux4odoo <liunux@qq.com>
This commit is contained in:
parent
51691ee008
commit
9b62b1c72b
@ -151,7 +151,23 @@ MODEL_PLATFORMS = [
|
||||
"image_models": [],
|
||||
"multimodal_models": [],
|
||||
},
|
||||
|
||||
{
|
||||
"platform_name": "ollama",
|
||||
"platform_type": "ollama",
|
||||
"api_base_url": "http://{host}:{port}/v1",
|
||||
"api_key": "sk-",
|
||||
"api_concurrencies": 5,
|
||||
"llm_models": [
|
||||
# Qwen API,其它更多模型请参考https://ollama.com/library
|
||||
"qwen:7b",
|
||||
],
|
||||
"embed_models": [
|
||||
# 必须升级ollama到0.1.29以上,低版本向量服务有问题
|
||||
"nomic-embed-text"
|
||||
],
|
||||
"image_models": [],
|
||||
"multimodal_models": [],
|
||||
},
|
||||
# {
|
||||
# "platform_name": "loom",
|
||||
# "platform_type": "loom",
|
||||
|
||||
@ -24,6 +24,9 @@ class OpenAIBaseInput(BaseModel):
|
||||
extra_body: Optional[Dict] = None
|
||||
timeout: Optional[float] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OpenAIChatInput(OpenAIBaseInput):
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
|
||||
@ -11,7 +11,6 @@ from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_do
|
||||
search_docs, update_info)
|
||||
from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
|
||||
summary_doc_ids_to_vector_store)
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
|
||||
|
||||
@ -38,7 +37,7 @@ kb_router.get("/list_files",
|
||||
)(list_files)
|
||||
|
||||
kb_router.post("/search_docs",
|
||||
response_model=List[DocumentWithVSId],
|
||||
response_model=List[dict],
|
||||
summary="搜索知识库"
|
||||
)(search_docs)
|
||||
|
||||
|
||||
@ -98,7 +98,12 @@ async def create_chat_completions(
|
||||
body: OpenAIChatInput,
|
||||
):
|
||||
async with acquire_model_client(body.model) as client:
|
||||
return await openai_request(client.chat.completions.create, body)
|
||||
result = await openai_request(client.chat.completions.create, body)
|
||||
# result["related_docs"] = ["doc1"]
|
||||
# result["choices"][0]["message"]["related_docs"] = ["doc1"]
|
||||
# print(result)
|
||||
# breakpoint()
|
||||
return result
|
||||
|
||||
|
||||
@openai_router.post("/completions")
|
||||
@ -115,9 +120,9 @@ async def create_embeddings(
|
||||
request: Request,
|
||||
body: OpenAIEmbeddingsInput,
|
||||
):
|
||||
params = body.dict(exclude_unset=True)
|
||||
params = body.model_dump(exclude_unset=True)
|
||||
client = get_OpenAIClient(model_name=body.model)
|
||||
return (await client.embeddings.create(**params)).dict()
|
||||
return (await client.embeddings.create(**params)).model_dump()
|
||||
|
||||
|
||||
@openai_router.post("/images/generations")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import lru_cache
|
||||
from server.pydantic_v1 import BaseModel, Field
|
||||
from server.pydantic_v2 import BaseModel, Field
|
||||
from langchain.prompts.chat import ChatMessagePromptTemplate
|
||||
from configs import logger, log_verbose
|
||||
from typing import List, Tuple, Dict, Union
|
||||
|
||||
@ -29,7 +29,7 @@ def search_docs(
|
||||
ge=0, le=1),
|
||||
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
|
||||
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
|
||||
) -> List[DocumentWithVSId]:
|
||||
) -> List[Dict]:
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
data = []
|
||||
if kb is not None:
|
||||
@ -41,7 +41,7 @@ def search_docs(
|
||||
for d in data:
|
||||
if "vector" in d.metadata:
|
||||
del d.metadata["vector"]
|
||||
return data
|
||||
return [x.dict() for x in data]
|
||||
|
||||
|
||||
def list_files(
|
||||
|
||||
@ -202,6 +202,7 @@ def get_Embeddings(
|
||||
local_wrap: bool = False, # use local wrapped api
|
||||
) -> Embeddings:
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
|
||||
|
||||
model_info = get_model_info(model_name=embed_model)
|
||||
@ -220,6 +221,10 @@ def get_Embeddings(
|
||||
)
|
||||
if model_info.get("platform_type") == "openai":
|
||||
return OpenAIEmbeddings(**params)
|
||||
elif model_info.get("platform_type") == "ollama":
|
||||
return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''),
|
||||
model=embed_model,
|
||||
)
|
||||
else:
|
||||
return LocalAIEmbeddings(**params)
|
||||
except Exception as e:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user