From 175db6710ef33cb63e21fef16668becf694f382d Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 21 Jan 2024 18:24:33 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=9C=AC=E5=9C=B0fschat?= =?UTF-8?q?=E9=85=8D=E7=BD=AE,pydantic=E5=8D=87=E7=BA=A7=E5=88=B02?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/agent/agent_factory/__init__.py | 2 +- server/agent/agent_factory/agents_registry.py | 6 +++--- server/embeddings_api.py | 6 +++--- startup.py | 3 +-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/server/agent/agent_factory/__init__.py b/server/agent/agent_factory/__init__.py index a387f103..3b6ef0a3 100644 --- a/server/agent/agent_factory/__init__.py +++ b/server/agent/agent_factory/__init__.py @@ -1,2 +1,2 @@ -from .glm3_agent import create_structured_glm3_chat_agent +# from .glm3_agent import create_structured_glm3_chat_agent from .qwen_agent import create_structured_qwen_chat_agent diff --git a/server/agent/agent_factory/agents_registry.py b/server/agent/agent_factory/agents_registry.py index e2aabee6..be77bd2f 100644 --- a/server/agent/agent_factory/agents_registry.py +++ b/server/agent/agent_factory/agents_registry.py @@ -8,8 +8,7 @@ from langchain_core.messages import SystemMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import BaseTool -from server.agent.agent_factory import (create_structured_glm3_chat_agent, - create_structured_qwen_chat_agent) +from server.agent.agent_factory import ( create_structured_qwen_chat_agent) def agents_registry( @@ -24,7 +23,8 @@ def agents_registry( # Write any optimized method here. if "glm3" in llm.model_name.lower(): # An optimized method of langchain Agent that uses the glm3 series model - agent = create_structured_glm3_chat_agent(llm=llm, tools=tools) + # agent = create_structured_glm3_chat_agent(llm=llm, tools=tools) + pass elif "qwen" in llm.model_name.lower(): agent = create_structured_qwen_chat_agent(llm=llm, tools=tools) else: diff --git a/server/embeddings_api.py b/server/embeddings_api.py index e907de07..d86189fb 100644 --- a/server/embeddings_api.py +++ b/server/embeddings_api.py @@ -1,6 +1,6 @@ from langchain.docstore.document import Document from configs import EMBEDDING_MODEL, logger -from server.model_workers.base import ApiEmbeddingsParams +# from server.model_workers.base import ApiEmbeddingsParams from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models from fastapi import Body from fastapi.concurrency import run_in_threadpool @@ -30,8 +30,8 @@ def embed_texts( embed_model = config.get("embed_model") worker = worker_class() if worker_class.can_embedding(): - params = ApiEmbeddingsParams(texts=texts, to_query=to_query, embed_model=embed_model) - resp = worker.do_embeddings(params) + # params = ApiEmbeddingsParams(texts=texts, to_query=to_query) + resp = worker.do_embeddings(None) return BaseResponse(**resp) return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。") diff --git a/startup.py b/startup.py index a69a5473..aef47b7e 100644 --- a/startup.py +++ b/startup.py @@ -147,7 +147,6 @@ def parse_args() -> argparse.ArgumentParser: def dump_server_info(after_start=False, args=None): import platform import langchain - import fastchat from server.utils import api_address, webui_address print("\n") @@ -155,7 +154,7 @@ def dump_server_info(after_start=False, args=None): print(f"操作系统:{platform.platform()}.") print(f"python版本:{sys.version}") print(f"项目版本:{VERSION}") - print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}") + print(f"langchain版本:{langchain.__version__}") print("\n") print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")