From cce2b55719303cd73531a88c75378242e3c6f92f Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Thu, 18 Jan 2024 17:56:37 +0800
Subject: [PATCH] Integrate the openai plugins
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/chat/file_chat.py                |  6 ++++
 server/knowledge_base/kb_summary_api.py | 37 +++++++++++++++++++++++++
 tests/test_qwen_agent.py                |  1 -
 3 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index 275371b5..7af3eed9 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -99,6 +99,9 @@ async def file_chat(query: str = Body(..., description="user input", examples=
                                              "content": "sturdy and adorable"}]]
                                     ),
                     stream: bool = Body(False, description="stream the output"),
+                    endpoint_host: str = Body(None, description="endpoint host address"),
+                    endpoint_host_key: str = Body(None, description="endpoint key"),
+                    endpoint_host_proxy: str = Body(None, description="endpoint proxy address"),
                     model_name: str = Body(None, description="LLM model name."),
                     temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="limit the number of tokens generated by the LLM; the default None means the model maximum"),
@@ -116,6 +119,9 @@ async def file_chat(query: str = Body(..., description="user input", examples=
             max_tokens = None
 
         model = get_ChatOpenAI(
+            endpoint_host=endpoint_host,
+            endpoint_host_key=endpoint_host_key,
+            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py
index d0d49280..00974d1c 100644
--- a/server/knowledge_base/kb_summary_api.py
+++ b/server/knowledge_base/kb_summary_api.py
@@ -12,12 +12,16 @@ from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
 from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
 from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 
+
 def recreate_summary_vector_store(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         allow_empty_kb: bool = Body(True),
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
+        endpoint_host: str = Body(None, description="endpoint host address"),
+        endpoint_host_key: str = Body(None, description="endpoint key"),
+        endpoint_host_proxy: str = Body(None, description="endpoint proxy address"),
         model_name: str = Body(None, description="LLM model name."),
         temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="limit the number of tokens generated by the LLM; the default None means the model maximum"),
@@ -25,6 +29,9 @@ def recreate_summary_vector_store(
     """
     Rebuild the file summaries of a single knowledge base
     :param max_tokens:
+    :param endpoint_host:
+    :param endpoint_host_key:
+    :param endpoint_host_proxy:
     :param model_name:
     :param temperature:
     :param file_description:
@@ -47,11 +54,17 @@ def recreate_summary_vector_store(
     kb_summary.create_kb_summary()
 
     llm = get_ChatOpenAI(
+        endpoint_host=endpoint_host,
+        endpoint_host_key=endpoint_host_key,
+        endpoint_host_proxy=endpoint_host_proxy,
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
     )
     reduce_llm = get_ChatOpenAI(
+        endpoint_host=endpoint_host,
+        endpoint_host_key=endpoint_host_key,
+        endpoint_host_proxy=endpoint_host_proxy,
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -99,12 +112,18 @@ def summary_file_to_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
+        endpoint_host: str = Body(None, description="endpoint host address"),
+        endpoint_host_key: str = Body(None, description="endpoint key"),
+        endpoint_host_proxy: str = Body(None, description="endpoint proxy address"),
         model_name: str = Body(None, description="LLM model name."),
         temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="limit the number of tokens generated by the LLM; the default None means the model maximum"),
 ):
     """
     Summarize a single knowledge base by file name
+    :param endpoint_host:
+    :param endpoint_host_key:
+    :param endpoint_host_proxy:
     :param model_name:
     :param max_tokens:
     :param temperature:
@@ -127,11 +146,17 @@ def summary_file_to_vector_store(
     kb_summary.create_kb_summary()
 
     llm = get_ChatOpenAI(
+        endpoint_host=endpoint_host,
+        endpoint_host_key=endpoint_host_key,
+        endpoint_host_proxy=endpoint_host_proxy,
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
     )
     reduce_llm = get_ChatOpenAI(
+        endpoint_host=endpoint_host,
+        endpoint_host_key=endpoint_host_key,
+        endpoint_host_proxy=endpoint_host_proxy,
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -171,6 +196,9 @@ def summary_doc_ids_to_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
+        endpoint_host: str = Body(None, description="endpoint host address"),
+        endpoint_host_key: str = Body(None, description="endpoint key"),
+        endpoint_host_proxy: str = Body(None, description="endpoint proxy address"),
         model_name: str = Body(None, description="LLM model name."),
         temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="limit the number of tokens generated by the LLM; the default None means the model maximum"),
@@ -178,6 +206,9 @@ def summary_doc_ids_to_vector_store(
     """
     Summarize a single knowledge base by doc_ids
     :param knowledge_base_name:
+    :param endpoint_host:
+    :param endpoint_host_key:
+    :param endpoint_host_proxy:
     :param doc_ids:
     :param model_name:
     :param max_tokens:
@@ -192,11 +223,17 @@ def summary_doc_ids_to_vector_store(
         return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found", data={})
     else:
         llm = get_ChatOpenAI(
+            endpoint_host=endpoint_host,
+            endpoint_host_key=endpoint_host_key,
+            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
         )
         reduce_llm = get_ChatOpenAI(
+            endpoint_host=endpoint_host,
+            endpoint_host_key=endpoint_host_key,
+            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
diff --git a/tests/test_qwen_agent.py b/tests/test_qwen_agent.py
index bac8a4d3..964a3705 100644
--- a/tests/test_qwen_agent.py
+++ b/tests/test_qwen_agent.py
@@ -8,7 +8,6 @@ from pprint import pprint
 from langchain.agents import AgentExecutor
 from langchain_openai.chat_models import ChatOpenAI
 # from langchain.chat_models.openai import ChatOpenAI
-from server.utils import get_ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
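
Note: every call site in this patch forwards endpoint_host, endpoint_host_key and
endpoint_host_proxy into get_ChatOpenAI, but the matching change to server/utils.py is
not included in the diff. A minimal sketch of the signature this patch appears to assume;
the mapping onto ChatOpenAI's openai_api_base / openai_api_key / openai_proxy fields is an
assumption for illustration, not the repository's actual implementation:

    from typing import Optional
    from langchain_openai.chat_models import ChatOpenAI

    def get_ChatOpenAI(model_name: str,
                       temperature: float,
                       max_tokens: Optional[int] = None,
                       endpoint_host: Optional[str] = None,
                       endpoint_host_key: Optional[str] = None,
                       endpoint_host_proxy: Optional[str] = None,
                       **kwargs) -> ChatOpenAI:
        # Assumed behavior: an explicitly supplied endpoint overrides the
        # server's default model-endpoint resolution.
        return ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            openai_api_base=endpoint_host,      # assumed mapping
            openai_api_key=endpoint_host_key,   # assumed mapping
            openai_proxy=endpoint_host_proxy,   # assumed mapping
            **kwargs,
        )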
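
With the patch applied, a caller can point any of these endpoints at a specific
OpenAI-compatible backend per request. For example, against the file_chat API; the base
URL, port, and all field values below are placeholders for illustration, not values taken
from this diff:

    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/chat/file_chat",    # assumed deployment URL
        json={
            "query": "Summarize the uploaded document.",
            "model_name": "gpt-3.5-turbo",
            "temperature": 0.01,
            "stream": False,
            # Fields added by this patch:
            "endpoint_host": "https://api.example.com/v1",
            "endpoint_host_key": "<api-key>",
            "endpoint_host_proxy": "http://proxy.example.com:8080",
        },
    )
    print(resp.text)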