Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced 2026-02-06 23:15:53 +08:00
dev branch: fix the pydantic version conflict; add ollama configuration with support for ollama chat and embedding APIs (#3508)
* dev branch: fix the pydantic version conflict, add ollama configuration, and support the ollama chat and embedding APIs

1. The dev branch upgraded pydantic to v2, but class History(BaseModel) still imported from server.pydantic_v1, while fastapi's imports had already moved to pydantic v2. fastapi therefore used v2 to validate an object defined with v1, and whenever the chat history was non-empty it raised: TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given. Testing showed the fix is simply to use v2 in class History(BaseModel) as well, as sketched below.

2. Modeled on the other platform entries, the config file gains an ollama platform section. Users can add chat models to it as needed; for embeddings, nomic-embed-text is currently supported (ollama must be upgraded to 0.1.29 or later).

3. ollama's official OpenAI API compatibility covers only chat; the embeddings API has not been adapted yet. Fortunately the official langchain library ships OllamaEmbeddings, so supporting code was added to the get_Embeddings method.

* Fix DocumentWithVSId and /v1/embeddings compatibility issues after the pydantic v2 upgrade

---------

Co-authored-by: srszzw <srszzw@163.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
parent 51691ee008
commit 9b62b1c72b
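A minimal sketch of the fix in point 1, assuming pydantic >= 2 is installed (the field set is illustrative; the repo's History also carries prompt-conversion helpers not shown here):

# After the fix, History inherits the v2 BaseModel, matching what fastapi
# now validates with.
from pydantic import BaseModel

class History(BaseModel):  # illustrative reduction of the repo's History
    role: str
    content: str

# v2-on-v2 validation works:
msg = History.model_validate({"role": "user", "content": "你好"})
assert msg.role == "user"

# Before the fix, History inherited the v1 BaseModel via server.pydantic_v1;
# pydantic v2 then ran it through the v1 class's legacy validator hook and
# called the v1 classmethod validate(cls, value) with an extra argument,
# producing the reported error:
# TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given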
@@ -151,7 +151,23 @@ MODEL_PLATFORMS = [
         "image_models": [],
         "multimodal_models": [],
     },
+    {
+        "platform_name": "ollama",
+        "platform_type": "ollama",
+        "api_base_url": "http://{host}:{port}/v1",
+        "api_key": "sk-",
+        "api_concurrencies": 5,
+        "llm_models": [
+            # Qwen API; see https://ollama.com/library for more models
+            "qwen:7b",
+        ],
+        "embed_models": [
+            # requires ollama 0.1.29 or later; the embedding service is broken in older versions
+            "nomic-embed-text"
+        ],
+        "image_models": [],
+        "multimodal_models": [],
+    },
     # {
     #     "platform_name": "loom",
     #     "platform_type": "loom",
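The api_base_url above points at ollama's OpenAI-compatible endpoint (ollama serves it on port 11434 by default; {host} and {port} are placeholders to fill in). A hedged sketch of a chat call against this entry with the openai Python client:

# Assumes a local ollama with `qwen:7b` already pulled (ollama pull qwen:7b)
# and `pip install openai`; host/port are example values.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:11434/v1",  # the configured api_base_url, filled in
    api_key="sk-",  # ollama ignores the key, but the client requires a non-empty one
)

resp = client.chat.completions.create(
    model="qwen:7b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)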
@@ -24,6 +24,9 @@ class OpenAIBaseInput(BaseModel):
     extra_body: Optional[Dict] = None
     timeout: Optional[float] = None
 
+    class Config:
+        extra = "allow"
+
 
 class OpenAIChatInput(OpenAIBaseInput):
     messages: List[ChatCompletionMessageParam]
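What extra = "allow" buys here: request fields not declared on the model are kept rather than rejected, so callers can pass parameters the schema does not list. A standalone sketch with a trimmed-down field set (not the repo's full class); under pydantic v2 the inner Config class still works, though ConfigDict is the preferred v2 spelling:

from pydantic import BaseModel

class OpenAIBaseInput(BaseModel):
    model: str  # reduced to one field for the sketch

    class Config:
        extra = "allow"

body = OpenAIBaseInput(model="qwen:7b", temperature=0.2)  # undeclared field is kept
print(body.model_dump())  # {'model': 'qwen:7b', 'temperature': 0.2}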
@@ -11,7 +11,6 @@ from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_do
                                                   search_docs, update_info)
 from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
                                                   summary_doc_ids_to_vector_store)
-from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 from server.utils import BaseResponse, ListResponse
 
 
@@ -38,7 +37,7 @@ kb_router.get("/list_files",
                )(list_files)
 
 kb_router.post("/search_docs",
-               response_model=List[DocumentWithVSId],
+               response_model=List[dict],
                summary="搜索知识库"
                )(search_docs)
 
@@ -98,7 +98,12 @@ async def create_chat_completions(
     body: OpenAIChatInput,
 ):
     async with acquire_model_client(body.model) as client:
-        return await openai_request(client.chat.completions.create, body)
+        result = await openai_request(client.chat.completions.create, body)
+        # result["related_docs"] = ["doc1"]
+        # result["choices"][0]["message"]["related_docs"] = ["doc1"]
+        # print(result)
+        # breakpoint()
+        return result
 
 
 @openai_router.post("/completions")
@@ -115,9 +120,9 @@ async def create_embeddings(
     request: Request,
     body: OpenAIEmbeddingsInput,
 ):
-    params = body.dict(exclude_unset=True)
+    params = body.model_dump(exclude_unset=True)
     client = get_OpenAIClient(model_name=body.model)
-    return (await client.embeddings.create(**params)).dict()
+    return (await client.embeddings.create(**params)).model_dump()
 
 
 @openai_router.post("/images/generations")
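The two replacements above are the mechanical part of the v2 migration: pydantic v2 renames .dict() to .model_dump(), keeping .dict() only as a deprecated alias. A quick sketch with an illustrative stand-in model:

from typing import Optional
from pydantic import BaseModel

class EmbeddingsInput(BaseModel):  # stand-in for OpenAIEmbeddingsInput
    model: str
    input: Optional[str] = None

body = EmbeddingsInput(model="nomic-embed-text")
print(body.model_dump(exclude_unset=True))  # {'model': 'nomic-embed-text'}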
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from server.pydantic_v1 import BaseModel, Field
+from server.pydantic_v2 import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
 from typing import List, Tuple, Dict, Union
@@ -29,7 +29,7 @@ def search_docs(
                                      ge=0, le=1),
         file_name: str = Body("", description="文件名称,支持 sql 通配符"),
         metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
-) -> List[DocumentWithVSId]:
+) -> List[Dict]:
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     data = []
     if kb is not None:
@@ -41,7 +41,7 @@ def search_docs(
         for d in data:
             if "vector" in d.metadata:
                 del d.metadata["vector"]
-    return data
+    return [x.dict() for x in data]
 
 
 def list_files(
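Returning plain dicts here (together with response_model=List[dict] in the router above) keeps the pydantic-v1-based DocumentWithVSId out of fastapi's v2 validation path, the same class of failure as point 1 of the commit message. A sketch of the pattern, assuming langchain's pydantic-v1 Document API (hence .dict() rather than .model_dump()):

from langchain.docstore.document import Document

docs = [Document(page_content="some chunk",
                 metadata={"source": "a.md", "vector": [0.1, 0.2]})]
for d in docs:
    d.metadata.pop("vector", None)  # strip raw vectors before returning
payload = [d.dict() for d in docs]  # plain dicts are safe for fastapi to serialize
print(payload[0]["page_content"])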
@@ -202,6 +202,7 @@ def get_Embeddings(
         local_wrap: bool = False,  # use local wrapped api
 ) -> Embeddings:
     from langchain_community.embeddings.openai import OpenAIEmbeddings
+    from langchain_community.embeddings import OllamaEmbeddings
     from server.localai_embeddings import LocalAIEmbeddings  # TODO: fork of lc pr #17154
 
     model_info = get_model_info(model_name=embed_model)
@@ -220,6 +221,10 @@ def get_Embeddings(
             )
         if model_info.get("platform_type") == "openai":
             return OpenAIEmbeddings(**params)
+        elif model_info.get("platform_type") == "ollama":
+            return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1', ''),
+                                    model=embed_model,
+                                    )
         else:
             return LocalAIEmbeddings(**params)
     except Exception as e:
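The .replace('/v1', '') above strips the OpenAI-compatible suffix because OllamaEmbeddings talks to ollama's native embeddings endpoint rather than the /v1 API. A hedged usage sketch, assuming a local ollama >= 0.1.29 with nomic-embed-text already pulled:

from langchain_community.embeddings import OllamaEmbeddings

embeddings = OllamaEmbeddings(
    base_url="http://127.0.0.1:11434",  # api_base_url with the /v1 suffix removed
    model="nomic-embed-text",
)
vec = embeddings.embed_query("hello world")
print(len(vec))  # nomic-embed-text produces 768-dimensional vectors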