diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 9be56953..25fcf12b 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -6,9 +6,9 @@ import os
 MODEL_ROOT_PATH = ""
 
 # 选用的 Embedding 名称
-EMBEDDING_MODEL = "bge-large-zh"
+EMBEDDING_MODEL = "bge-large-zh-v1.5"
 
-# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+# Embedding 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
 EMBEDDING_DEVICE = "auto"
 
 # 选用的reranker模型
@@ -26,50 +26,33 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
 # 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
 # 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
-# chatglm3-6b输出角色标签<|user|>及自问自答的问题详见项目wiki->常见问题->Q20.
-
-LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]  # "Qwen-1_8B-Chat",
-
-# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0])
+LLM_MODELS = ["zhipu-api"]
 Agent_MODEL = None
 
-# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
-LLM_DEVICE = "auto"
+# LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
+LLM_DEVICE = "cuda"
 
-# 历史对话轮数
 HISTORY_LEN = 3
 
-# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度
-MAX_TOKENS = None
+MAX_TOKENS = 2048
 
-# LLM通用对话参数
 TEMPERATURE = 0.7
-# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
 
 ONLINE_LLM_MODEL = {
-    # 线上模型。请在server_config中为每个在线API设置不同的端口
-
     "openai-api": {
-        "model_name": "gpt-3.5-turbo",
+        "model_name": "gpt-4",
         "api_base_url": "https://api.openai.com/v1",
         "api_key": "",
         "openai_proxy": "",
     },
-    # 获取api_key请前往https://makersuite.google.com/或者google cloud,使用前先确认网络正常,使用代理请在项目启动(python startup.py -a)环境内设置https_proxy环境变量
-    "gemini-api": {
-        "api_key": "",
-        "provider": "GeminiWorker",
-    },
-
-    # 具体注册及api key获取请前往 http://open.bigmodel.cn
+    # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
     "zhipu-api": {
         "api_key": "",
-        "version": "chatglm_turbo",  # 可选包括 "chatglm_turbo"
+        "version": "glm-4",
        "provider": "ChatGLMWorker",
     },
-
     # 具体注册及api key获取请前往 https://api.minimax.chat/
     "minimax-api": {
         "group_id": "",
@@ -78,7 +61,6 @@ ONLINE_LLM_MODEL = {
         "api_key": "",
         "provider": "MiniMaxWorker",
     },
-
     # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
     "xinghuo-api": {
         "APPID": "",
@@ -99,8 +81,8 @@ ONLINE_LLM_MODEL = {
 
     # 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379
     "fangzhou-api": {
-        "version": "chatglm-6b-model",  # 当前支持 "chatglm-6b-model", 更多的见文档模型支持列表中方舟部分。
-        "version_url": "",  # 可以不填写version,直接填写在方舟申请模型发布的API地址
+        "version": "chatglm-6b-model",
+        "version_url": "",
         "api_key": "",
         "secret_key": "",
         "provider": "FangZhouWorker",
@@ -108,15 +90,15 @@ ONLINE_LLM_MODEL = {
 
     # 阿里云通义千问 API,文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
     "qwen-api": {
-        "version": "qwen-turbo",  # 可选包括 "qwen-turbo", "qwen-plus"
-        "api_key": "",  # 请在阿里云控制台模型服务灵积API-KEY管理页面创建
+        "version": "qwen-max",
+        "api_key": "",
         "provider": "QwenWorker",
-        "embed_model": "text-embedding-v1"  # embedding 模型名称
+        "embed_model": "text-embedding-v1"  # embedding 模型名称
     },
 
     # 百川 API,申请方式请参考 https://www.baichuan-ai.com/home#api-enter
     "baichuan-api": {
-        "version": "Baichuan2-53B",  # 当前支持 "Baichuan2-53B", 见官方文档。
+        "version": "Baichuan2-53B",
         "api_key": "",
         "secret_key": "",
         "provider": "BaiChuanWorker",
@@ -138,6 +120,11 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "TianGongWorker",
     },
+    # Gemini API (开发组未测试,由社群提供,只支持pro)https://makersuite.google.com/或者google cloud,使用前先确认网络正常,使用代理请在项目启动(python startup.py -a)环境内设置https_proxy环境变量
+    "gemini-api": {
+        "api_key": "",
+        "provider": "GeminiWorker",
+    }
 }
 
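The online-API section now defaults to a single `zhipu-api` entry on `glm-4`, with `gemini-api` moved to the end of `ONLINE_LLM_MODEL`. Every name listed in `LLM_MODELS` has to resolve either to one of these online workers or to a local model declared in `MODEL_PATH["llm_model"]` further down. A minimal sketch of that lookup, assuming the example file has been copied to `configs/model_config.py`; the helper `check_llm_models` is hypothetical and not part of the project:

```python
# Hypothetical sanity check: every name in LLM_MODELS should resolve either to
# an online API worker or to a local/HuggingFace model path.
from configs.model_config import LLM_MODELS, ONLINE_LLM_MODEL, MODEL_PATH


def check_llm_models() -> None:
    for name in LLM_MODELS:
        if name in ONLINE_LLM_MODEL:
            cfg = ONLINE_LLM_MODEL[name]
            assert cfg.get("api_key"), f"{name}: api_key is still empty"
        elif name in MODEL_PATH["llm_model"]:
            print(f"{name} -> local/HF path {MODEL_PATH['llm_model'][name]}")
        else:
            raise ValueError(f"{name} is neither an online API nor a known local model")


if __name__ == "__main__":
    check_llm_models()
```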
@@ -149,6 +136,7 @@ ONLINE_LLM_MODEL = {
 # - GanymedeNil/text2vec-large-chinese
 # - text2vec-large-chinese
 # 2.2 如果以上本地路径不存在,则使用huggingface模型
+
 MODEL_PATH = {
     "embed_model": {
         "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
@@ -167,7 +155,7 @@ MODEL_PATH = {
         "bge-large-zh": "BAAI/bge-large-zh",
         "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
         "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
-        "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
+        "bge-large-zh-v1.5": "/share/home/zyx/Models/bge-large-zh-v1.5",
         "piccolo-base-zh": "sensenova/piccolo-base-zh",
         "piccolo-large-zh": "sensenova/piccolo-large-zh",
         "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@@ -175,55 +163,55 @@ MODEL_PATH = {
     },
 
     "llm_model": {
-        # 以下部分模型并未完全测试,仅根据fastchat和vllm模型的模型列表推定支持
         "chatglm2-6b": "THUDM/chatglm2-6b",
         "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
-
         "chatglm3-6b": "THUDM/chatglm3-6b",
         "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
-        "chatglm3-6b-base": "THUDM/chatglm3-6b-base",
-        "Qwen-1_8B": "Qwen/Qwen-1_8B",
-        "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
-        "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
-        "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
+        "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
+        "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
+        "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
 
-        "Qwen-7B": "Qwen/Qwen-7B",
+        "Qwen-1_8B-Chat": "/media/checkpoint/Qwen-1_8B-Chat",
         "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-
-        "Qwen-14B": "Qwen/Qwen-14B",
         "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-
-        "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
-        # 在新版的transformers下需要手动修改模型的config.json文件,在quantization_config字典中
-        # 增加`disable_exllama:true` 字段才能启动qwen的量化模型
-        "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
-
-        "Qwen-72B": "Qwen/Qwen-72B",
         "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
-        "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
-        "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
 
-        "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
-        "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
-
-        "baichuan-7b": "baichuan-inc/Baichuan-7B",
-        "baichuan-13b": "baichuan-inc/Baichuan-13B",
+        "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
         "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
-
-        "aquila-7b": "BAAI/Aquila-7B",
-        "aquilachat-7b": "BAAI/AquilaChat-7B",
+        "baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
+        "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
 
         "internlm-7b": "internlm/internlm-7b",
         "internlm-chat-7b": "internlm/internlm-chat-7b",
+        "internlm2-chat-7b": "internlm/internlm2-chat-7b",
+        "internlm2-chat-20b": "internlm/internlm2-chat-20b",
+
+        "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
+        "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
+
+        "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
+
+        "agentlm-7b": "THUDM/agentlm-7b",
+        "agentlm-13b": "THUDM/agentlm-13b",
+        "agentlm-70b": "THUDM/agentlm-70b",
 
         "falcon-7b": "tiiuae/falcon-7b",
         "falcon-40b": "tiiuae/falcon-40b",
         "falcon-rw-7b": "tiiuae/falcon-rw-7b",
+
+        "aquila-7b": "BAAI/Aquila-7B",
+        "aquilachat-7b": "BAAI/AquilaChat-7B",
+
+        "open_llama_13b": "openlm-research/open_llama_13b",
+        "vicuna-13b-v1.5": "lmsys/vicuna-13b-v1.5",
+        "koala": "young-geng/koala",
+
+        "mpt-7b": "mosaicml/mpt-7b",
+        "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
+        "mpt-30b": "mosaicml/mpt-30b",
+        "opt-66b": "facebook/opt-66b",
+        "opt-iml-max-30b": "facebook/opt-iml-max-30b",
 
         "gpt2": "gpt2",
         "gpt2-xl": "gpt2-xl",
-
         "gpt-j-6b": "EleutherAI/gpt-j-6b",
         "gpt4all-j": "nomic-ai/gpt4all-j",
         "gpt-neox-20b": "EleutherAI/gpt-neox-20b",
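Entries in `MODEL_PATH` accept either a HuggingFace repo id or an absolute local directory, which is why values such as `"/share/home/zyx/Models/bge-large-zh-v1.5"` and `"/media/checkpoint/Qwen-1_8B-Chat"` work on one specific machine but need to be changed back for other deployments. A simplified resolution sketch follows; it approximates, but is not, the project's `server.utils.get_model_path`:

```python
# Simplified sketch: an absolute directory is used as-is, otherwise the value
# is looked up under MODEL_ROOT_PATH and finally treated as a HF repo id.
import os

from configs.model_config import MODEL_PATH, MODEL_ROOT_PATH


def resolve_model_path(name: str, model_type: str = "llm_model") -> str:
    value = MODEL_PATH[model_type].get(name, name)
    if os.path.isdir(value):                             # absolute local checkout
        return value
    candidate = os.path.join(MODEL_ROOT_PATH, os.path.basename(value))
    if MODEL_ROOT_PATH and os.path.isdir(candidate):     # under the shared model root
        return candidate
    return value                                         # fall back to the HF repo id
```

Under such a scheme a full URL like `"https://huggingface.co/01-ai/Yi-34B-Chat"` would fall through to the last branch and is unlikely to load; the plain repo id form is the safer value.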
@@ -231,63 +219,50 @@ MODEL_PATH = {
         "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
         "dolly-v2-12b": "databricks/dolly-v2-12b",
         "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
-
-        "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
-        "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
-        "open_llama_13b": "openlm-research/open_llama_13b",
-        "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
-        "koala": "young-geng/koala",
-
-        "mpt-7b": "mosaicml/mpt-7b",
-        "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
-        "mpt-30b": "mosaicml/mpt-30b",
-        "opt-66b": "facebook/opt-66b",
-        "opt-iml-max-30b": "facebook/opt-iml-max-30b",
-
-        "agentlm-7b": "THUDM/agentlm-7b",
-        "agentlm-13b": "THUDM/agentlm-13b",
-        "agentlm-70b": "THUDM/agentlm-70b",
-
-        "Yi-34B-Chat": "01-ai/Yi-34B-Chat",
     },
-    "reranker":{
-        "bge-reranker-large":"BAAI/bge-reranker-large",
-        "bge-reranker-base":"BAAI/bge-reranker-base",
-        #TODO 增加在线reranker,如cohere
+    "reranker": {
+        "bge-reranker-large": "BAAI/bge-reranker-large",
+        "bge-reranker-base": "BAAI/bge-reranker-base",
     }
 }
-
 # 通常情况下不需要更改以下内容
 
 # nltk 模型存储路径
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
 
+# 使用VLLM可能导致模型推理能力下降,无法完成Agent任务
 VLLM_MODEL_DICT = {
-    "aquila-7b": "BAAI/Aquila-7B",
-    "aquilachat-7b": "BAAI/AquilaChat-7B",
-
-    "baichuan-7b": "baichuan-inc/Baichuan-7B",
-    "baichuan-13b": "baichuan-inc/Baichuan-13B",
-    "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
-
     "chatglm2-6b": "THUDM/chatglm2-6b",
     "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
     "chatglm3-6b": "THUDM/chatglm3-6b",
     "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
+
+    "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
+    "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
+    "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
+
+    "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
+    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+    "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+
+    "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
+    "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
+    "baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
+    "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
+
     "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
     "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
 
-    # 注意:bloom系列的tokenizer与model是分离的,因此虽然vllm支持,但与fschat框架不兼容
-    # "bloom": "bigscience/bloom",
-    # "bloomz": "bigscience/bloomz",
-    # "bloomz-560m": "bigscience/bloomz-560m",
-    # "bloomz-7b1": "bigscience/bloomz-7b1",
-    # "bloomz-1b7": "bigscience/bloomz-1b7",
-
     "internlm-7b": "internlm/internlm-7b",
     "internlm-chat-7b": "internlm/internlm-chat-7b",
+    "internlm2-chat-7b": "internlm/Models/internlm2-chat-7b",
+    "internlm2-chat-20b": "internlm/Models/internlm2-chat-20b",
+
+    "aquila-7b": "BAAI/Aquila-7B",
+    "aquilachat-7b": "BAAI/AquilaChat-7B",
+
     "falcon-7b": "tiiuae/falcon-7b",
     "falcon-40b": "tiiuae/falcon-40b",
     "falcon-rw-7b": "tiiuae/falcon-rw-7b",
@@ -300,8 +275,6 @@ VLLM_MODEL_DICT = {
     "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
     "dolly-v2-12b": "databricks/dolly-v2-12b",
     "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
-    "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
-    "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
     "open_llama_13b": "openlm-research/open_llama_13b",
     "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
     "koala": "young-geng/koala",
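The newly added comment warns that running models through vLLM can degrade inference quality and break Agent tasks; `VLLM_MODEL_DICT` itself only whitelists which model names may be served that way and where their weights live. A hypothetical maintenance check (not part of the project) to keep it aligned with `MODEL_PATH["llm_model"]`:

```python
# Hypothetical consistency check: flag vLLM entries that have no counterpart
# in MODEL_PATH["llm_model"], so the two dictionaries do not drift apart.
from configs.model_config import MODEL_PATH, VLLM_MODEL_DICT

missing = sorted(set(VLLM_MODEL_DICT) - set(MODEL_PATH["llm_model"]))
if missing:
    print("vLLM entries without a matching llm_model entry:", missing)
```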
"Qwen-1_8B": "Qwen/Qwen-1_8B", - "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", - "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8", - "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4", - - "Qwen-7B": "Qwen/Qwen-7B", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", - - "Qwen-14B": "Qwen/Qwen-14B", - "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", - "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", - "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", - - "Qwen-72B": "Qwen/Qwen-72B", - "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", - "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8", - "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4", - - "agentlm-7b": "THUDM/agentlm-7b", - "agentlm-13b": "THUDM/agentlm-13b", - "agentlm-70b": "THUDM/agentlm-70b", - } -# 你认为支持Agent能力的模型,可以在这里添加,添加后不会出现可视化界面的警告 -# 经过我们测试,原生支持Agent的模型仅有以下几个 SUPPORT_AGENT_MODEL = [ "azure-api", "openai-api", "qwen-api", "Qwen", "chatglm3", - "xinghuo-api", ] diff --git a/requirements.txt b/requirements.txt index 5ab39657..db1cbe28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ sentence_transformers==2.2.2 langchain==0.0.354 langchain-experimental==0.0.47 pydantic==1.10.13 -fschat==0.2.34 +fschat==0.2.35 openai~=1.7.1 fastapi~=0.108.0 sse_starlette==1.8.2 @@ -48,8 +48,7 @@ beautifulsoup4~=4.12.2 # for .mhtml files pysrt~=1.1.2 # Online api libs dependencies - -zhipuai==1.0.7 # zhipu +# zhipuAI sdk is not supported on our platform, so use http instead dashscope==1.13.6 # qwen # volcengine>=1.0.119 # fangzhou diff --git a/requirements_api.txt b/requirements_api.txt index 0e2a9eee..60b884eb 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -8,7 +8,7 @@ sentence_transformers==2.2.2 langchain==0.0.354 langchain-experimental==0.0.47 pydantic==1.10.13 -fschat==0.2.34 +fschat==0.2.35 openai~=1.7.1 fastapi~=0.108.0 sse_starlette==1.8.2 @@ -49,7 +49,7 @@ pysrt~=1.1.2 # Online api libs dependencies -zhipuai==1.0.7 # zhipu +# zhipuAI sdk is not supported on our platform, so use http instead dashscope==1.13.6 # qwen # volcengine>=1.0.119 # fangzhou diff --git a/requirements_lite.txt b/requirements_lite.txt index db57273a..4274b649 100644 --- a/requirements_lite.txt +++ b/requirements_lite.txt @@ -3,7 +3,7 @@ langchain==0.0.354 langchain-experimental==0.0.47 pydantic==1.10.13 -fschat==0.2.34 +fschat==0.2.35 openai~=1.7.1 fastapi~=0.108.0 sse_starlette==1.8.2 @@ -36,7 +36,7 @@ pytest # Online api libs dependencies -zhipuai==1.0.7 +# zhipuAI sdk is not supported on our platform, so use http instead dashscope==1.13.6 # volcengine>=1.0.119 diff --git a/server/agent/tools/weather_check.py b/server/agent/tools/weather_check.py index 7e55c7cb..8e9f3c6b 100644 --- a/server/agent/tools/weather_check.py +++ b/server/agent/tools/weather_check.py @@ -20,6 +20,6 @@ def weather(location: str, api_key: str): def weathercheck(location: str): - return weather(location, "S8vrB4U_-c5mvAMiK") + return weather(location, "your keys") class WeatherInput(BaseModel): - location: str = Field(description="City name,include city and county,like '厦门'") + location: str = Field(description="City name,include city and county") diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f08b62b8..32382929 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -18,13 +18,10 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) - # def save_vector_store(self): - # if self.milvus.col: - # 
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index f08b62b8..32382929 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -18,13 +18,10 @@ class MilvusKBService(KBService):
         from pymilvus import Collection
         return Collection(milvus_name)
 
-    # def save_vector_store(self):
-    #     if self.milvus.col:
-    #         self.milvus.col.flush()
-
     def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
         result = []
         if self.milvus.col:
+            # ids = [int(id) for id in ids]  # for milvus if needed #pr 2725
             data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
             for data in data_list:
                 text = data.pop("text")
diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py
index 5d00a49e..753225a0 100644
--- a/server/knowledge_base/kb_service/zilliz_kb_service.py
+++ b/server/knowledge_base/kb_service/zilliz_kb_service.py
@@ -16,13 +16,10 @@ class ZillizKBService(KBService):
         from pymilvus import Collection
         return Collection(zilliz_name)
 
-    # def save_vector_store(self):
-    #     if self.zilliz.col:
-    #         self.zilliz.col.flush()
-
     def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
         result = []
         if self.zilliz.col:
+            # ids = [int(id) for id in ids]  # for zilliz if needed #pr 2725
             data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
             for data in data_list:
                 text = data.pop("text")
@@ -50,8 +47,7 @@ class ZillizKBService(KBService):
     def _load_zilliz(self):
         zilliz_args = kbs_config.get("zilliz")
         self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
-                             collection_name=self.kb_name, connection_args=zilliz_args)
-
+                             collection_name=self.kb_name, connection_args=zilliz_args)
 
     def do_init(self):
         self._load_zilliz()
@@ -95,9 +91,7 @@ class ZillizKBService(KBService):
 
 if __name__ == '__main__':
-
     from server.db.base import Base, engine
 
     Base.metadata.create_all(bind=engine)
     zillizService = ZillizKBService("test")
-
diff --git a/server/model_workers/gemini.py b/server/model_workers/gemini.py
index 46130212..0cd8e159 100644
--- a/server/model_workers/gemini.py
+++ b/server/model_workers/gemini.py
@@ -18,7 +18,7 @@ class GeminiWorker(ApiModelWorker):
         **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 4096)  #TODO 16K模型需要改成16384
+        kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)
 
     def create_gemini_messages(self,messages) -> json:
@@ -47,10 +47,10 @@ class GeminiWorker(ApiModelWorker):
         params.load_config(self.model_names[0])
         data = self.create_gemini_messages(messages=params.messages)
         generationConfig=dict(
-            temperature = params.temperature,
-            topK = 1,
-            topP = 1,
-            maxOutputTokens = 4096,
+            temperature=params.temperature,
+            topK=1,
+            topP=1,
+            maxOutputTokens=4096,
             stopSequences=[]
         )
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 2bcce94e..7dd3a355 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):
 
     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
-        # import qianfan
-
-        # comp = qianfan.ChatCompletion(model=params.version,
-        #                               endpoint=params.version_url,
-        #                               ak=params.api_key,
-        #                               sk=params.secret_key,)
-        # text = ""
-        # for resp in comp.do(messages=params.messages,
-        #                     temperature=params.temperature,
-        #                     top_p=params.top_p,
-        #                     stream=True):
-        #     if resp.code == 200:
-        #         if chunk := resp.body.get("result"):
-        #             text += chunk
-        #             yield {
-        #                 "error_code": 0,
-        #                 "text": text
-        #             }
-        #     else:
-        #         yield {
-        #             "error_code": resp.code,
-        #             "text": str(resp.body),
-        #         }
-
         BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
                    '/{model_version}?access_token={access_token}'
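With the commented-out `qianfan` SDK path deleted, the worker keeps the plain-HTTP route through `BASE_URL`, whose `access_token` placeholder comes from Baidu's OAuth client-credentials exchange. A sketch of that exchange, assuming the standard Qianfan token endpoint (this is not necessarily the exact helper used elsewhere in the file):

```python
# Sketch: exchange the Qianfan api_key / secret_key for the access_token
# expected by BASE_URL above (OAuth client_credentials grant).
import requests


def get_baidu_access_token(api_key: str, secret_key: str) -> str:
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": secret_key,
    }
    return requests.post(url, params=params).json().get("access_token", "")
```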
@@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
         i = 0
         batch_size = 10
         while i < len(params.texts):
-            texts = params.texts[i:i+batch_size]
+            texts = params.texts[i:i + batch_size]
             resp = client.post(url, json={"input": texts}).json()
             if "error_code" in resp:
                 data = {
-                            "code": resp["error_code"],
-                            "msg": resp["error_msg"],
-                            "error": {
-                                "message": resp["error_msg"],
-                                "type": "invalid_request_error",
-                                "param": None,
-                                "code": None,
-                            }
-                        }
+                    "code": resp["error_code"],
+                    "msg": resp["error_msg"],
+                    "error": {
+                        "message": resp["error_msg"],
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": None,
+                    }
+                }
                 self.logger.error(f"请求千帆 API 时发生错误:{data}")
                 return data
             else:
diff --git a/server/model_workers/tiangong.py b/server/model_workers/tiangong.py
index 85a763fe..e127ea55 100644
--- a/server/model_workers/tiangong.py
+++ b/server/model_workers/tiangong.py
@@ -11,16 +11,15 @@ from typing import List, Literal, Dict
 import requests
 
 
-
 class TianGongWorker(ApiModelWorker):
     def __init__(
-            self,
-            *,
-            controller_addr: str = None,
-            worker_addr: str = None,
-            model_names: List[str] = ["tiangong-api"],
-            version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
-            **kwargs,
+        self,
+        *,
+        controller_addr: str = None,
+        worker_addr: str = None,
+        model_names: List[str] = ["tiangong-api"],
+        version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
+        **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
         kwargs.setdefault("context_len", 32768)
@@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker):
         data = {
             "messages": params.messages,
             "model": "SkyChat-MegaVerse"
-            }
-        timestamp = str(int(time.time()))
-        sign_content = params.api_key + params.secret_key + timestamp
-        sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
-        headers={
+        }
+        timestamp = str(int(time.time()))
+        sign_content = params.api_key + params.secret_key + timestamp
+        sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
+        headers = {
             "app_key": params.api_key,
             "timestamp": timestamp,
             "sign": sign_result,
             "Content-Type": "application/json",
-            "stream": "true" # or change to "false" 不处理流式返回内容
+            "stream": "true"  # or change to "false" 不处理流式返回内容
         }
-
+
         # 发起请求并获取响应
         response = requests.post(url, headers=headers, json=data, stream=True)
@@ -56,17 +55,17 @@ class TianGongWorker(ApiModelWorker):
                 # 处理接收到的数据
                 # print(line.decode('utf-8'))
                 resp = json.loads(line)
-                if resp["code"] == 200:
+                if resp["code"] == 200:
                     text += resp['resp_data']['reply']
                     yield {
                         "error_code": 0,
                         "text": text
-                        }
+                    }
                 else:
                     data = {
                         "error_code": resp["code"],
                         "text": resp["code_msg"]
-                        }
+                    }
                     self.logger.error(f"请求天工 API 时出错:{data}")
                     yield data
@@ -85,5 +84,3 @@ class TianGongWorker(ApiModelWorker):
             sep="\n### ",
             stop_str="###",
         )
-
-
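For reference, the TianGong request signing that `do_chat` builds inline (an MD5 over `app_key + secret_key + timestamp`) can be read as a small standalone helper. The sketch below is derived from the code above and is only illustrative:

```python
# Sketch of the same signing scheme used by TianGongWorker.do_chat:
# sign = md5(app_key + secret_key + timestamp), sent alongside the raw fields.
import hashlib
import time


def tiangong_headers(api_key: str, secret_key: str) -> dict:
    timestamp = str(int(time.time()))
    sign = hashlib.md5((api_key + secret_key + timestamp).encode("utf-8")).hexdigest()
    return {
        "app_key": api_key,
        "timestamp": timestamp,
        "sign": sign,
        "Content-Type": "application/json",
        "stream": "true",
    }
```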
diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py
index 72db7389..1e772a33 100644
--- a/server/model_workers/xinghuo.py
+++ b/server/model_workers/xinghuo.py
@@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
         **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 8000)  # TODO: V1模型的最大长度为4000,需要自行修改
+        kwargs.setdefault("context_len", 8000)
         super().__init__(**kwargs)
         self.version = version
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 0005c7d3..552b67cc 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -4,93 +4,86 @@ from fastchat import conversation as conv
 import sys
 from typing import List, Dict, Iterator, Literal
 from configs import logger, log_verbose
+import requests
+import jwt
+import time
+import json
+
+
+def generate_token(apikey: str, exp_seconds: int):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
 
 
 class ChatGLMWorker(ApiModelWorker):
-    DEFAULT_EMBED_MODEL = "text_embedding"
-
     def __init__(
-            self,
-            *,
-            model_names: List[str] = ["zhipu-api"],
-            controller_addr: str = None,
-            worker_addr: str = None,
-            version: Literal["chatglm_turbo"] = "chatglm_turbo",
-            **kwargs,
+        self,
+        *,
+        model_names: List[str] = ["zhipu-api"],
+        controller_addr: str = None,
+        worker_addr: str = None,
+        version: Literal["chatglm_turbo"] = "chatglm_turbo",
+        **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 32768)
+        kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)
         self.version = version
 
     def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
-        # TODO: 维护request_id
-        import zhipuai
-
         params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        if log_verbose:
-            logger.info(f'{self.__class__.__name__}:params: {params}')
-
-        response = zhipuai.model_api.sse_invoke(
-            model=params.version,
-            prompt=params.messages,
-            temperature=params.temperature,
-            top_p=params.top_p,
-            incremental=False,
-        )
-        for e in response.events():
-            if e.event == "add":
-                yield {"error_code": 0, "text": e.data}
-            elif e.event in ["error", "interrupted"]:
-                data = {
-                    "error_code": 500,
-                    "text": e.data,
-                    "error": {
-                        "message": e.data,
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None,
-                    }
-                }
-                self.logger.error(f"请求智谱 API 时发生错误:{data}")
-                yield data
-
-    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
-        import zhipuai
-
-        params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        embeddings = []
-        try:
-            for t in params.texts:
-                response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
-                if response["code"] == 200:
-                    embeddings.append(response["data"]["embedding"])
-                else:
-                    self.logger.error(f"请求智谱 API 时发生错误:{response}")
-                    return response  # dict with code & msg
-        except Exception as e:
-            self.logger.error(f"请求智谱 API 时发生错误:{data}")
-            data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
-            return data
-
-        return {"code": 200, "data": embeddings}
+        token = generate_token(params.api_key, 60)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {token}"
+        }
+        data = {
+            "model": params.version,
+            "messages": params.messages,
+            "max_tokens": params.max_tokens,
+            "temperature": params.temperature,
+            "stream": True
+        }
+        url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
+        response = requests.post(url, headers=headers, json=data, stream=True)
+        for chunk in response.iter_lines():
+            if chunk:
+                chunk_str = chunk.decode('utf-8')
+                json_start_pos = chunk_str.find('{"id"')
+                if json_start_pos != -1:
+                    json_str = chunk_str[json_start_pos:]
+                    json_data = json.loads(json_str)
+                    for choice in json_data.get('choices', []):
+                        delta = choice.get('delta', {})
+                        content = delta.get('content', '')
+                        yield {"error_code": 0, "text": content}
 
     def get_embeddings(self, params):
-        # TODO: 支持embeddings
+        # 临时解决方案,不支持embedding
         print("embedding")
-        # print(params)
+        print(params)
 
     def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # 这里的是chatglm api的模板,其它API的conv_template需要定制
         return conv.Conversation(
             name=self.model_names[0],
-            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+            system_message="你是智谱AI小助手,请根据用户的提示来完成任务",
             messages=[],
-            roles=["Human", "Assistant", "System"],
+            roles=["user", "assistant", "system"],
             sep="\n###",
             stop_str="###",
         )
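The zhipu worker now authenticates with a short-lived JWT from `generate_token` (which requires the PyJWT package, imported as `jwt`) and streams from the GLM-4 `chat/completions` endpoint; note that `do_chat` yields each delta on its own rather than an accumulated string. A minimal standalone client mirroring that flow — the `glm4_chat` helper below is illustrative, not part of the worker:

```python
# Sketch of the new flow: JWT auth via generate_token(), then a streaming POST
# to the GLM-4 endpoint. Accumulating deltas into `text` is for illustration.
import json

import requests

from server.model_workers.zhipu import generate_token


def glm4_chat(api_key: str, messages: list) -> str:
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {generate_token(api_key, 60)}",
    }
    data = {"model": "glm-4", "messages": messages, "stream": True}
    url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    text = ""
    with requests.post(url, headers=headers, json=data, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            chunk = line.decode("utf-8")
            pos = chunk.find('{"id"')
            if pos != -1:
                for choice in json.loads(chunk[pos:]).get("choices", []):
                    text += choice.get("delta", {}).get("content", "")
    return text
```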
diff --git a/server/utils.py b/server/utils.py
index 270c5158..26ef967e 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -503,16 +503,12 @@ def set_httpx_config(
             no_proxy.append(host)
     os.environ["NO_PROXY"] = ",".join(no_proxy)
 
-    # TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
-
     # patch requests to use custom proxies instead of system settings
     def _get_proxies():
         return proxies
 
     import urllib.request
     urllib.request.getproxies = _get_proxies
 
-    # 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
-
 
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try: