diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index a857e80b..40165b96 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -1,5 +1,7 @@ import os +# 默认使用的知识库 +DEFAULT_KNOWLEDGE_BASE = "samples" # 默认向量库类型。可选:faiss, milvus(离线) & zilliz(在线), pg. DEFAULT_VS_TYPE = "faiss" @@ -19,6 +21,9 @@ VECTOR_SEARCH_TOP_K = 3 # 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 SCORE_THRESHOLD = 1 +# 默认搜索引擎。可选:bing, duckduckgo, metaphor +DEFAULT_SEARCH_ENGINE = "duckduckgo" + # 搜索引擎匹配结题数量 SEARCH_ENGINE_TOP_K = 3 @@ -36,6 +41,10 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" # 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG BING_SUBSCRIPTION_KEY = "" +# metaphor搜索需要KEY +METAPHOR_API_KEY = "" + + # 是否开启中文标题加强,以及标题增强的相关配置 # 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 @@ -47,7 +56,6 @@ KB_INFO = { "知识库名称": "知识库介绍", "samples": "关于本项目issue的解答", } - # 通常情况下不需要更改以下内容 # 知识库默认存储路径 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 0111fc5d..78a10e9a 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -44,7 +44,7 @@ MODEL_PATH = { "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4", "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", - "baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat", + "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat", "baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat", "baichuan-7b": "baichuan-inc/Baichuan-7B", @@ -112,7 +112,8 @@ TEMPERATURE = 0.7 # TOP_P = 0.95 # ChatOpenAI暂不支持该参数 -ONLINE_LLM_MODEL = { +LANGCHAIN_LLM_MODEL = { + # 不需要走Fschat封装的,Langchain直接支持的模型。 # 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443): # Max retries exceeded with url: /v1/chat/completions # 则需要将urllib3版本修改为1.25.11 @@ -128,11 +129,29 @@ ONLINE_LLM_MODEL = { # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI. 
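
For reference, a minimal sketch of how the new kb_config defaults introduced above (DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, METAPHOR_API_KEY) can be consumed. It assumes the configs package re-exports these names, which the webui and search-engine changes later in this diff already rely on; the kb_list value is made up purely for illustration.

    from configs import DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, METAPHOR_API_KEY

    kb_list = ["samples", "my_docs"]  # stand-in for api.list_knowledge_bases()

    # Preselect the configured default knowledge base when present, otherwise fall back to the first entry.
    index = kb_list.index(DEFAULT_KNOWLEDGE_BASE) if DEFAULT_KNOWLEDGE_BASE in kb_list else 0
    print(f"knowledge base: {kb_list[index]}, search engine: {DEFAULT_SEARCH_ENGINE}")

    # metaphor search stays disabled until a key is filled in.
    if not METAPHOR_API_KEY:
        print("METAPHOR_API_KEY is empty, metaphor search will return no results")
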
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置 # 比如: "openai_proxy": 'http://127.0.0.1:4780' - "gpt-3.5-turbo": { + + # 这些配置文件的名字不能改动 + "Azure-OpenAI": { + "deployment_name": "your Azure deployment name", + "model_version": "0701", + "openai_api_type": "azure", + "api_base_url": "https://your Azure point.azure.com", + "api_version": "2023-07-01-preview", + "api_key": "your Azure api key", + "openai_proxy": "", + }, + "OpenAI": { + "model_name": "your openai model name(such as gpt-4)", "api_base_url": "https://api.openai.com/v1", "api_key": "your OPENAI_API_KEY", - "openai_proxy": "your OPENAI_PROXY", + "openai_proxy": "", }, + "Anthropic": { + "model_name": "your claude model name(such as claude2-100k)", + "api_key":"your ANTHROPIC_API_KEY", + } +} +ONLINE_LLM_MODEL = { # 线上模型。请在server_config中为每个在线API设置不同的端口 # 具体注册及api key获取请前往 http://open.bigmodel.cn "zhipu-api": { diff --git a/init_database.py b/init_database.py index 9e807a8e..c2cd1d49 100644 --- a/init_database.py +++ b/init_database.py @@ -1,3 +1,5 @@ +import sys +sys.path.append(".") from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files from configs.model_config import NLTK_DATA_PATH import nltk diff --git a/requirements.txt b/requirements.txt index 76312473..3da9c845 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,15 @@ -langchain>=0.0.314 +langchain>=0.0.319 langchain-experimental>=0.0.30 -fschat[model_worker]==0.2.30 -openai +fschat[model_worker]==0.2.31 +xformers>=0.0.22.post4 +openai>=0.28.1 sentence_transformers transformers>=4.34 torch>=2.0.1 # 推荐2.1 torchvision torchaudio -fastapi>=0.103.2 -nltk~=3.8.1 +fastapi>=0.104 +nltk>=3.8.1 uvicorn~=0.23.1 starlette~=0.27.0 pydantic~=1.10.11 @@ -43,7 +44,7 @@ pandas~=2.0.3 streamlit>=1.26.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.1.11 -streamlit-chatbox>=1.1.9 +streamlit-chatbox==1.1.10 streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 watchdog diff --git a/requirements_api.txt b/requirements_api.txt index af4e7e08..b8a7f6d4 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,13 +1,14 @@ -langchain==0.0.313 -langchain-experimental==0.0.30 -fschat[model_worker]==0.2.30 -openai +langchain>=0.0.319 +langchain-experimental>=0.0.30 +fschat[model_worker]==0.2.31 +xformers>=0.0.22.post4 +openai>=0.28.1 sentence_transformers>=2.2.2 transformers>=4.34 -torch>=2.0.1 +torch>=2.1 torchvision torchaudio -fastapi>=0.103.1 +fastapi>=0.104 nltk~=3.8.1 uvicorn~=0.23.1 starlette~=0.27.0 diff --git a/requirements_webui.txt b/requirements_webui.txt index 9caf085a..a36fec49 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -1,11 +1,11 @@ numpy~=1.24.4 pandas~=2.0.3 -streamlit>=1.26.0 +streamlit>=1.27.2 streamlit-option-menu>=0.3.6 -streamlit-antd-components>=0.1.11 -streamlit-chatbox>=1.1.9 +streamlit-antd-components>=0.2.3 +streamlit-chatbox==1.1.10 streamlit-aggrid>=0.3.4.post3 -httpx~=0.24.1 -nltk +httpx>=0.25.0 +nltk>=3.8.1 watchdog websockets diff --git a/server/agent/callbacks.py b/server/agent/callbacks.py index 3a82b9c7..49ce9730 100644 --- a/server/agent/callbacks.py +++ b/server/agent/callbacks.py @@ -97,8 +97,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): llm_token="", ) self.queue.put_nowait(dumps(self.cur_tool)) - - async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any, + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List], + *, + run_id: UUID, + 
parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: self.cur_tool.update( status=Status.start, diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py index 22469c6b..fdac6e21 100644 --- a/server/agent/custom_template.py +++ b/server/agent/custom_template.py @@ -4,7 +4,6 @@ from langchain.prompts import StringPromptTemplate from typing import List from langchain.schema import AgentAction, AgentFinish from server.agent import model_container -begin = False class CustomPromptTemplate(StringPromptTemplate): # The template to use template: str @@ -38,7 +37,7 @@ class CustomOutputParser(AgentOutputParser): def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction: # Check if agent should finish - support_agent = ["gpt","Qwen","qwen-api","baichuan-api"] + support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型 if not any(agent in model_container.MODEL for agent in support_agent) and self.begin: self.begin = False stop_words = ["Observation:"] diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py index 8bb5cac6..7031b71b 100644 --- a/server/agent/tools/__init__.py +++ b/server/agent/tools/__init__.py @@ -2,7 +2,6 @@ from .search_knowledge_simple import knowledge_search_simple from .search_all_knowledge_once import knowledge_search_once from .search_all_knowledge_more import knowledge_search_more -from .travel_assistant import travel_assistant from .calculate import calculate from .translator import translate from .weather import weathercheck diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index c78add68..5a71478b 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), - # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), ): diff --git a/server/chat/chat.py b/server/chat/chat.py index 3ec68558..4402185a 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), - # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index c39b147e..19ca871f 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -31,8 +31,7 @@ async def knowledge_base_chat(query: 
str = Body(..., description="用户输入", stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), - # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index 4a46ddd9..7efb0a8b 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel): messages: List[OpenAiMessage] temperature: float = 0.7 n: int = 1 - max_tokens: int = 1024 + max_tokens: int = None stop: List[str] = [] stream: bool = False presence_penalty: int = 0 diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 83ed65e4..e1ccaa48 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -1,6 +1,7 @@ from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper -from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, - LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE) +from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY, + LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE, + TEXT_SPLITTER_NAME, OVERLAP_SIZE) from fastapi import Body from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool @@ -11,7 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio from langchain.prompts.chat import ChatPromptTemplate -from typing import List, Optional +from typing import List, Optional, Dict from server.chat.utils import History from langchain.docstore.document import Document import json @@ -32,8 +33,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K): return search.results(text, result_len) +def metaphor_search( + text: str, + result_len: int = SEARCH_ENGINE_TOP_K, + splitter_name: str = "SpacyTextSplitter", + chunk_size: int = 500, + chunk_overlap: int = OVERLAP_SIZE, +) -> List[Dict]: + from metaphor_python import Metaphor + from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool + from server.knowledge_base.utils import make_text_splitter + + if not METAPHOR_API_KEY: + return [] + + client = Metaphor(METAPHOR_API_KEY) + search = client.search(text, num_results=result_len, use_autoprompt=True) + contents = search.get_contents().contents + + # metaphor 返回的内容都是长文本,需要分词再检索 + docs = [Document(page_content=x.extract, + metadata={"link": x.url, "title": x.title}) + for x in contents] + text_splitter = make_text_splitter(splitter_name=splitter_name, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap) + splitted_docs = text_splitter.split_documents(docs) + + # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档 + if len(splitted_docs) > result_len: + vs = memo_faiss_pool.new_vector_store() + vs.add_documents(splitted_docs) + splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0) + + docs = [{"snippet": x.page_content, + "link": x.metadata["link"], + "title": x.metadata["title"]} + for x in splitted_docs] + return docs + + SEARCH_ENGINES = {"bing": bing_search, "duckduckgo": duckduckgo_search, + "metaphor": metaphor_search, } @@ -72,8 +114,7 @@ async def 
search_engine_chat(query: str = Body(..., description="用户输入", stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), - # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 + max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): if search_engine_name not in SEARCH_ENGINES.keys(): diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 801e4a6b..c00ac4f6 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -140,7 +140,7 @@ if __name__ == "__main__": ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings) pprint(ids) elif r == 2: # search docs - docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0) + docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0) pprint(docs) if r == 3: # delete docs logger.warning(f"清除 {vs_name} by {name}") diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 02212c88..c73d0219 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,7 +1,5 @@ import os - from transformers import AutoTokenizer - from configs import ( EMBEDDING_MODEL, KB_ROOT_PATH, diff --git a/server/llm_api.py b/server/llm_api.py index dc9ddced..453f1473 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,5 +1,5 @@ from fastapi import Body -from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT +from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models, get_httpx_client, get_model_worker_config) @@ -16,7 +16,7 @@ def list_running_models( with get_httpx_client() as client: r = client.post(controller_address + "/list_models") models = r.json()["models"] - data = {m: get_model_worker_config(m) for m in models} + data = {m: get_model_config(m).data for m in models} return BaseResponse(data=data) except Exception as e: logger.error(f'{e.__class__.__name__}: {e}', @@ -52,7 +52,6 @@ def get_model_config( 获取LLM模型配置项(合并后的) ''' config = get_model_worker_config(model_name=model_name) - # 删除ONLINE_MODEL配置中的敏感信息 del_keys = set(["worker_class"]) for k in config: diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py index e1dce6a0..c4e090e8 100644 --- a/server/model_workers/SparkApi.py +++ b/server/model_workers/SparkApi.py @@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature): "chat": { "domain": domain, "random_threshold": 0.5, - "max_tokens": 2048, + "max_tokens": None, "auditing": "default", "temperature": temperature, } diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py index 879a51ec..1c6a6f1d 100644 --- a/server/model_workers/baichuan.py +++ b/server/model_workers/baichuan.py @@ -1,15 +1,15 @@ # import os # import sys # sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -import requests import json import time import hashlib from server.model_workers.base import ApiModelWorker +from server.utils import get_model_worker_config, get_httpx_client from fastchat import conversation as conv import sys import json -from 
typing import List, Literal +from typing import List, Literal, Dict from configs import TEMPERATURE @@ -20,29 +20,29 @@ def calculate_md5(input_string): return encrypted -def do_request(): - url = "https://api.baichuan-ai.com/v1/stream/chat" - api_key = "" - secret_key = "" +def request_baichuan_api( + messages: List[Dict[str, str]], + api_key: str = None, + secret_key: str = None, + version: str = "Baichuan2-53B", + temperature: float = TEMPERATURE, + model_name: str = "baichuan-api", +): + config = get_model_worker_config(model_name) + api_key = api_key or config.get("api_key") + secret_key = secret_key or config.get("secret_key") + version = version or config.get("version") + url = "https://api.baichuan-ai.com/v1/stream/chat" data = { - "model": "Baichuan2-53B", - "messages": [ - { - "role": "user", - "content": "世界第一高峰是" - } - ], - "parameters": { - "temperature": 0.1, - "top_k": 10 - } + "model": version, + "messages": messages, + "parameters": {"temperature": temperature} } json_data = json.dumps(data) time_stamp = int(time.time()) signature = calculate_md5(secret_key + json_data + str(time_stamp)) - headers = { "Content-Type": "application/json", "Authorization": "Bearer " + api_key, @@ -52,18 +52,17 @@ def do_request(): "X-BC-Sign-Algo": "MD5", } - response = requests.post(url, data=json_data, headers=headers) - - if response.status_code == 200: - print("请求成功!") - print("响应header:", response.headers) - print("响应body:", response.text) - else: - print("请求失败,状态码:", response.status_code) + with get_httpx_client() as client: + with client.stream("POST", url, headers=headers, json=data) as response: + for line in response.iter_lines(): + if not line.strip(): + continue + resp = json.loads(line) + yield resp class BaiChuanWorker(ApiModelWorker): - BASE_URL = "https://api.baichuan-ai.com/v1/chat" + BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat" SUPPORT_MODELS = ["Baichuan2-53B"] def __init__( @@ -95,54 +94,34 @@ class BaiChuanWorker(ApiModelWorker): self.secret_key = config.get("secret_key") def generate_stream_gate(self, params): - data = { - "model": self.version, - "messages": [ - { - "role": "user", - "content": params["prompt"] - } - ], - "parameters": { - "temperature": params.get("temperature",TEMPERATURE), - "top_k": params.get("top_k",1) - } - } + super().generate_stream_gate(params) - json_data = json.dumps(data) - time_stamp = int(time.time()) - signature = calculate_md5(self.secret_key + json_data + str(time_stamp)) - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer " + self.api_key, - "X-BC-Request-Id": "your requestId", - "X-BC-Timestamp": str(time_stamp), - "X-BC-Signature": signature, - "X-BC-Sign-Algo": "MD5", - } + messages = self.prompt_to_messages(params["prompt"]) - response = requests.post(self.BASE_URL, data=json_data, headers=headers) + text = "" + for resp in request_baichuan_api(messages=messages, + api_key=self.api_key, + secret_key=self.secret_key, + version=self.version, + temperature=params.get("temperature")): + if resp["code"] == 0: + text += resp["data"]["messages"][-1]["content"] + yield json.dumps( + { + "error_code": resp["code"], + "text": text + }, + ensure_ascii=False + ).encode() + b"\0" + else: + yield json.dumps( + { + "error_code": resp["code"], + "text": resp["msg"] + }, + ensure_ascii=False + ).encode() + b"\0" - if response.status_code == 200: - resp = eval(response.text) - yield json.dumps( - { - "error_code": resp["code"], - "text": resp["data"]["messages"][-1]["content"] - }, - ensure_ascii=False - 
).encode() + b"\0" - else: - yield json.dumps( - { - "error_code": resp["code"], - "text": resp["msg"] - }, - ensure_ascii=False - ).encode() + b"\0" - - - def get_embeddings(self, params): # TODO: 支持embeddings print("embedding") diff --git a/server/model_workers/base.py b/server/model_workers/base.py index ea141046..2b39bd6b 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -1,13 +1,13 @@ from configs.basic_config import LOG_PATH import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH -from fastchat.serve.model_worker import BaseModelWorker +from fastchat.serve.base_model_worker import BaseModelWorker import uuid import json import sys from pydantic import BaseModel import fastchat -import threading +import asyncio from typing import Dict, List @@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker): worker_addr=worker_addr, **kwargs) self.context_len = context_len + self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency) self.init_heart_beat() def count_token(self, params): @@ -62,15 +63,6 @@ class ApiModelWorker(BaseModelWorker): print("embedding") print(params) - # workaround to make program exit with Ctrl+c - # it should be deleted after pr is merged by fastchat - def init_heart_beat(self): - self.register_to_controller() - self.heart_beat_thread = threading.Thread( - target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, - ) - self.heart_beat_thread.start() - # help methods def get_config(self): from server.utils import get_model_worker_config diff --git a/server/utils.py b/server/utils.py index 2f6dfc49..5dff5f1b 100644 --- a/server/utils.py +++ b/server/utils.py @@ -5,12 +5,11 @@ from fastapi import FastAPI from pathlib import Path import asyncio from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE, - MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, - logger, log_verbose, + MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose, FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT) import os from concurrent.futures import ThreadPoolExecutor, as_completed -from langchain.chat_models import ChatOpenAI +from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic import httpx from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union @@ -40,19 +39,64 @@ def get_ChatOpenAI( verbose: bool = True, **kwargs: Any, ) -> ChatOpenAI: - config = get_model_worker_config(model_name) - model = ChatOpenAI( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - openai_api_key=config.get("api_key", "EMPTY"), - openai_api_base=config.get("api_base_url", fschat_openai_api_address()), - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - openai_proxy=config.get("openai_proxy"), - **kwargs - ) + ## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装 + config_models = list_config_llm_models() + if model_name in config_models.get("langchain", {}): + config = config_models["langchain"][model_name] + if model_name == "Azure-OpenAI": + model = AzureChatOpenAI( + streaming=streaming, + verbose=verbose, + callbacks=callbacks, + deployment_name=config.get("deployment_name"), + model_version=config.get("model_version"), + openai_api_type=config.get("openai_api_type"), + openai_api_base=config.get("api_base_url"), + openai_api_version=config.get("api_version"), + openai_api_key=config.get("api_key"), + openai_proxy=config.get("openai_proxy"), + temperature=temperature, + max_tokens=max_tokens, + ) + + elif model_name == "OpenAI": + model = 
ChatOpenAI( + streaming=streaming, + verbose=verbose, + callbacks=callbacks, + model_name=config.get("model_name"), + openai_api_base=config.get("api_base_url"), + openai_api_key=config.get("api_key"), + openai_proxy=config.get("openai_proxy"), + temperature=temperature, + max_tokens=max_tokens, + ) + elif model_name == "Anthropic": + model = ChatAnthropic( + streaming=streaming, + verbose=verbose, + callbacks=callbacks, + model_name=config.get("model_name"), + anthropic_api_key=config.get("api_key"), + + ) + ## TODO 支持其他的Langchain原生支持的模型 + else: + ## 非Langchain原生支持的模型,走Fschat封装 + config = get_model_worker_config(model_name) + model = ChatOpenAI( + streaming=streaming, + verbose=verbose, + callbacks=callbacks, + openai_api_key=config.get("api_key", "EMPTY"), + openai_api_base=config.get("api_base_url", fschat_openai_api_address()), + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + openai_proxy=config.get("openai_proxy"), + **kwargs + ) + return model @@ -249,8 +293,9 @@ def MakeFastAPIOffline( redoc_favicon_url=favicon, ) + # 从model_config中获取模型信息 + -# 从model_config中获取模型信息 def list_embed_models() -> List[str]: ''' get names of configured embedding models @@ -266,9 +311,9 @@ def list_config_llm_models() -> Dict[str, Dict]: workers = list(FSCHAT_MODEL_WORKERS) if LLM_MODEL not in workers: workers.insert(0, LLM_MODEL) - return { "local": MODEL_PATH["llm_model"], + "langchain": LANGCHAIN_LLM_MODEL, "online": ONLINE_LLM_MODEL, "worker": workers, } @@ -300,8 +345,9 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]: return str(path) return path_str # THUDM/chatglm06b + # 从server_config中获取服务信息 + -# 从server_config中获取服务信息 def get_model_worker_config(model_name: str = None) -> dict: ''' 加载model worker的配置项。 @@ -316,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict: config.update(FSCHAT_MODEL_WORKERS.get(model_name, {})) # 在线模型API + if model_name in LANGCHAIN_LLM_MODEL: + config["langchain_model"] = True + config["worker_class"] = "" + if model_name in ONLINE_LLM_MODEL: config["online_api"] = True if provider := config.get("provider"): @@ -389,7 +439,7 @@ def webui_address() -> str: return f"http://{host}:{port}" -def get_prompt_template(type:str,name: str) -> Optional[str]: +def get_prompt_template(type: str, name: str) -> Optional[str]: ''' 从prompt_config中加载模板内容 type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。 @@ -459,8 +509,9 @@ def set_httpx_config( import urllib.request urllib.request.getproxies = _get_proxies + # 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch + -# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch def detect_device() -> Literal["cuda", "mps", "cpu"]: try: import torch @@ -568,6 +619,8 @@ def get_server_configs() -> Dict: 获取configs中的原始配置项,供前端使用 ''' from configs.kb_config import ( + DEFAULT_KNOWLEDGE_BASE, + DEFAULT_SEARCH_ENGINE, DEFAULT_VS_TYPE, CHUNK_SIZE, OVERLAP_SIZE, diff --git a/startup.py b/startup.py index 878e7eca..4bb34496 100644 --- a/startup.py +++ b/startup.py @@ -68,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: controller_address: worker_address: - + 对于Langchain支持的模型: + langchain_model:True + 不会使用fschat 对于online_api: online_api:True worker_class: `provider` @@ -78,31 +80,34 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: """ import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.model_worker import worker_id, logger import argparse - logger.setLevel(log_level) 
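
A rough usage sketch for the LANGCHAIN_LLM_MODEL routing added to get_ChatOpenAI above: entries named "Azure-OpenAI", "OpenAI" and "Anthropic" are built directly as LangChain chat models instead of going through the fschat OpenAI-compatible endpoint. The sketch assumes the "OpenAI" entry in configs/model_config.py has been filled in with a real model name and API key, and the .predict() call is just one possible invocation style, not something this diff prescribes.

    from server.utils import get_ChatOpenAI, get_model_worker_config

    # LANGCHAIN_LLM_MODEL entries are flagged so startup.py can skip creating an fschat worker for them.
    cfg = get_model_worker_config("OpenAI")
    print(cfg.get("langchain_model"))  # True for LANGCHAIN_LLM_MODEL entries

    # Returns ChatOpenAI / AzureChatOpenAI / ChatAnthropic depending on the configured name.
    model = get_ChatOpenAI(model_name="OpenAI", temperature=0.7, max_tokens=None)
    print(model.predict("Say hello in one short sentence."))
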
parser = argparse.ArgumentParser() args = parser.parse_args([]) for k, v in kwargs.items(): setattr(args, k, v) - + if worker_class := kwargs.get("langchain_model"): #Langchian支持的模型不用做操作 + from fastchat.serve.base_model_worker import app + worker = "" # 在线模型API - if worker_class := kwargs.get("worker_class"): - from fastchat.serve.model_worker import app + elif worker_class := kwargs.get("worker_class"): + from fastchat.serve.base_model_worker import app + worker = worker_class(model_names=args.model_names, controller_addr=args.controller_address, worker_addr=args.worker_address) - sys.modules["fastchat.serve.model_worker"].worker = worker + # sys.modules["fastchat.serve.base_model_worker"].worker = worker + sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level) # 本地模型 else: from configs.model_config import VLLM_MODEL_DICT if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm": import fastchat.serve.vllm_worker - from fastchat.serve.vllm_worker import VLLMWorker,app + from fastchat.serve.vllm_worker import VLLMWorker, app from vllm import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs + args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加 args.tokenizer_mode = 'auto' args.trust_remote_code= True @@ -126,8 +131,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: args.engine_use_ray = False args.disable_log_requests = False - # 0.2.0 vllm后要加的参数 - args.max_model_len = 8192 # 模型可以处理的最大序列长度。请根据你的大模型设置, + # 0.2.0 vllm后要加的参数, 但是这里不需要 + args.max_model_len = None args.revision = None args.quantization = None args.max_log_len = None @@ -155,10 +160,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: conv_template = args.conv_template, ) sys.modules["fastchat.serve.vllm_worker"].engine = engine - sys.modules["fastchat.serve.vllm_worker"].worker = worker + # sys.modules["fastchat.serve.vllm_worker"].worker = worker + sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level) else: - from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker + from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id + args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3" args.max_gpu_memory = "22GiB" args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量 @@ -221,8 +228,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: ) sys.modules["fastchat.serve.model_worker"].args = args sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config - - sys.modules["fastchat.serve.model_worker"].worker = worker + # sys.modules["fastchat.serve.model_worker"].worker = worker + sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level) MakeFastAPIOffline(app) app.title = f"FastChat LLM Server ({args.model_names[0]})" diff --git a/tests/online_api/test_baichuan.py b/tests/online_api/test_baichuan.py new file mode 100644 index 00000000..536466ee --- /dev/null +++ b/tests/online_api/test_baichuan.py @@ -0,0 +1,16 @@ +import sys +from pathlib import Path +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) + +from server.model_workers.baichuan import request_baichuan_api +from pprint import pprint + + +def test_qwen(): + messages = [{"role": "user", "content": "hello"}] + + for x in request_baichuan_api(messages): + print(type(x)) + pprint(x) + assert x["code"] == 0 \ No newline at end of file diff --git a/webui.py b/webui.py index 776d5e60..85d6cb40 100644 --- 
a/webui.py +++ b/webui.py @@ -21,13 +21,6 @@ if __name__ == "__main__": } ) - if not chat_box.chat_inited: - running_models = api.list_running_models() - st.toast( - f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n" - f"当前运行中的模型`{running_models}`, 您可以开始提问了." - ) - pages = { "对话": { "icon": "chat", diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index adc60293..8628aac8 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -3,7 +3,8 @@ from webui_pages.utils import * from streamlit_chatbox import * from datetime import datetime import os -from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES +from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, + DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL) from typing import List, Dict chat_box = ChatBox( @@ -40,7 +41,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool): 返回类型为(model_name, is_local_model) ''' running_models = api.list_running_models() - if not running_models: return "", False @@ -50,12 +50,17 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool): local_models = [k for k, v in running_models.items() if not v.get("online_api")] if local_models: return local_models[0], True - - return running_models[0], False + return list(running_models)[0], False def dialogue_page(api: ApiRequest): - chat_box.init_session() + if not chat_box.chat_inited: + default_model = get_default_llm_model(api)[0] + st.toast( + f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n" + f"当前运行的模型`{default_model}`, 您可以开始提问了." + ) + chat_box.init_session() with st.sidebar: # TODO: 对话模型与会话绑定 @@ -74,16 +79,17 @@ def dialogue_page(api: ApiRequest): "搜索引擎问答", "自定义Agent问答", ], - index=3, + index=0, on_change=on_mode_change, key="dialogue_mode", ) def on_llm_change(): - config = api.get_model_config(llm_model) - if not config.get("online_api"): # 只有本地model_worker可以切换模型 - st.session_state["prev_llm_model"] = llm_model - st.session_state["cur_llm_model"] = st.session_state.llm_model + if llm_model: + config = api.get_model_config(llm_model) + if not config.get("online_api"): # 只有本地model_worker可以切换模型 + st.session_state["prev_llm_model"] = llm_model + st.session_state["cur_llm_model"] = st.session_state.llm_model def llm_model_format_func(x): if x in running_models: @@ -91,16 +97,18 @@ def dialogue_page(api: ApiRequest): return x running_models = list(api.list_running_models()) + running_models += LANGCHAIN_LLM_MODEL.keys() available_models = [] config_models = api.list_config_models() worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型 for m in worker_models: if m not in running_models and m != "default": available_models.append(m) - for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型(如GPT) + for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型 if not v.get("provider") and k not in running_models: - print(k, v) available_models.append(k) + for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型 + available_models.append(k) llm_models = running_models + available_models index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0])) llm_model = st.selectbox("选择LLM模型:", @@ -111,7 +119,8 @@ def dialogue_page(api: ApiRequest): key="llm_model", ) if (st.session_state.get("prev_llm_model") != llm_model - and not 
api.get_model_config(llm_model).get("online_api") + and not llm_model in config_models.get("online", {}) + and not llm_model in config_models.get("langchain", {}) and llm_model not in running_models): with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): prev_model = st.session_state.get("prev_llm_model") @@ -156,9 +165,13 @@ def dialogue_page(api: ApiRequest): if dialogue_mode == "知识库问答": with st.expander("知识库配置", True): kb_list = api.list_knowledge_bases() + index = 0 + if DEFAULT_KNOWLEDGE_BASE in kb_list: + index = kb_list.index(DEFAULT_KNOWLEDGE_BASE) selected_kb = st.selectbox( "请选择知识库:", kb_list, + index=index, on_change=on_kb_change, key="selected_kb", ) @@ -167,11 +180,15 @@ def dialogue_page(api: ApiRequest): elif dialogue_mode == "搜索引擎问答": search_engine_list = api.list_search_engines() + if DEFAULT_SEARCH_ENGINE in search_engine_list: + index = search_engine_list.index(DEFAULT_SEARCH_ENGINE) + else: + index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0 with st.expander("搜索引擎配置", True): search_engine = st.selectbox( label="请选择搜索引擎", options=search_engine_list, - index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0, + index=index, ) se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K) @@ -210,9 +227,9 @@ def dialogue_page(api: ApiRequest): ]) text = "" ans = "" - support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型 + support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型 if not any(agent in llm_model for agent in support_agent): - ans += "正在思考... \n\n 该模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n" + ans += "正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n" chat_box.update_msg(ans, element_index=0, streaming=False) for d in api.agent_chat(prompt, history=history, diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 7b9e161c..8190dba1 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -245,7 +245,7 @@ class ApiRequest: stream: bool = True, model: str = LLM_MODEL, temperature: float = TEMPERATURE, - max_tokens: int = 1024, + max_tokens: int = None, **kwargs: Any, ): ''' @@ -278,7 +278,7 @@ class ApiRequest: stream: bool = True, model: str = LLM_MODEL, temperature: float = TEMPERATURE, - max_tokens: int = 1024, + max_tokens: int = None, prompt_name: str = "default", **kwargs, ): @@ -308,7 +308,7 @@ class ApiRequest: stream: bool = True, model: str = LLM_MODEL, temperature: float = TEMPERATURE, - max_tokens: int = 1024, + max_tokens: int = None, prompt_name: str = "default", ): ''' @@ -340,7 +340,7 @@ class ApiRequest: stream: bool = True, model: str = LLM_MODEL, temperature: float = TEMPERATURE, - max_tokens: int = 1024, + max_tokens: int = None, prompt_name: str = "default", ): ''' @@ -378,7 +378,7 @@ class ApiRequest: stream: bool = True, model: str = LLM_MODEL, temperature: float = TEMPERATURE, - max_tokens: int = 1024, + max_tokens: int = None, prompt_name: str = "default", ): '''
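
For reference, a sketch of calling the new metaphor_search helper added in server/chat/search_engine_chat.py. It assumes metaphor_python is installed and METAPHOR_API_KEY is set in kb_config; the query string is arbitrary.

    from server.chat.search_engine_chat import metaphor_search

    # Returns [] while METAPHOR_API_KEY is empty; otherwise a list of
    # {"snippet", "link", "title"} dicts, re-ranked through a temporary FAISS
    # store whenever the split documents exceed result_len.
    results = metaphor_search("retrieval augmented generation survey", result_len=3)
    for r in results:
        print(r["title"], r["link"])
        print(r["snippet"][:100], "...")
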
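
The reworked Baichuan worker now streams through request_baichuan_api, which yields the parsed JSON chunks of the official streaming endpoint. A minimal sketch, assuming valid api_key/secret_key are configured for "baichuan-api" (they default to the model worker config, mirroring tests/online_api/test_baichuan.py); the accumulation logic follows what generate_stream_gate does above.

    from server.model_workers.baichuan import request_baichuan_api

    messages = [{"role": "user", "content": "世界第一高峰是?"}]
    text = ""
    for resp in request_baichuan_api(messages):
        if resp["code"] == 0:
            text += resp["data"]["messages"][-1]["content"]  # chunks arrive incrementally
        else:
            print("baichuan api error:", resp["msg"])
            break
    print(text)
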
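
One behaviour change worth noting for API consumers: /llm_model/list_running_models now returns a mapping from model name to its sanitised config rather than a bare list, which is why get_default_llm_model above switches to list(running_models)[0]. A small illustration of the shape, with made-up model names and fields.

    # Illustrative shape of api.list_running_models() after this change.
    running_models = {
        "chatglm2-6b": {"model_path": "THUDM/chatglm2-6b"},
        "zhipu-api": {"online_api": True},
    }

    # A dict is not indexable by position, so running_models[0] would raise a KeyError;
    # the fixed code takes the first key instead.
    first_model = list(running_models)[0]
    local_models = [name for name, cfg in running_models.items() if not cfg.get("online_api")]
    print(first_model, local_models)
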
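
Finally, the webui ApiRequest helpers above now default max_tokens to None, which the server passes through so the model's own maximum applies instead of the old 1024 cap. A hedged usage sketch: the chat_chat method name, the yielded chunk type, and the base URL are assumptions about the existing ApiRequest class, not something this diff introduces.

    from webui_pages.utils import ApiRequest

    api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed default API server address, adjust as needed

    # max_tokens is left at its new default of None, so generation is no longer truncated at 1024 tokens.
    for chunk in api.chat_chat("你好,请介绍一下你自己", model="chatglm2-6b", temperature=0.7):
        print(chunk, end="", flush=True)
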