From e5b4bb41d884abdbfbc395394fcc3c2471c68ca7 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sun, 21 Jan 2024 11:48:44 +0800
Subject: [PATCH] Update GLM: temporary workaround to support GLM-4; not
 backward compatible, known bugs remain
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py.example  |  12 ++-
 server/model_workers/qianfan.py  |  44 +++--------
 server/model_workers/tiangong.py |  37 ++++-----
 server/model_workers/xinghuo.py  |   2 +-
 server/model_workers/zhipu.py    | 129 +++++++++++++++----------------
 5 files changed, 94 insertions(+), 130 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 824a7d28..25fcf12b 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -29,6 +29,7 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
 LLM_MODELS = ["zhipu-api"]
 Agent_MODEL = None
 
+# Device to run LLM models on. "auto" detects the device automatically (with a warning); it can also be set manually to one of "cuda", "mps", "cpu", or "xpu".
 LLM_DEVICE = "cuda"
 
 HISTORY_LEN = 3
@@ -45,10 +46,10 @@ ONLINE_LLM_MODEL = {
         "openai_proxy": "",
     },
 
-    # Zhipu AI API (GLM-4 not supported; incompatible with this version, planned for 0.3.x). To register and obtain an API key, visit http://open.bigmodel.cn
+    # Zhipu AI API. To register and obtain an API key, visit http://open.bigmodel.cn
     "zhipu-api": {
         "api_key": "",
-        "version": "chatglm_turbo",
+        "version": "glm-4",
         "provider": "ChatGLMWorker",
     },
@@ -119,7 +120,7 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "TianGongWorker",
     },
-    # Gemini API (untested by the dev team, community-contributed; only the pro model is supported)
+    # Gemini API (untested by the dev team, community-contributed; only the pro model is supported). Keys via https://makersuite.google.com/ or Google Cloud. Confirm network connectivity before use; if a proxy is needed, set the https_proxy environment variable in the environment the project is started from (python startup.py -a).
     "gemini-api": {
         "api_key": "",
         "provider": "GeminiWorker",
     },
@@ -196,7 +197,7 @@ MODEL_PATH = {
     "agentlm-70b": "THUDM/agentlm-70b",
 
     "falcon-7b": "tiiuae/falcon-7b",
-    "falcon-40b": "tiiuae/falcon-40,b",
+    "falcon-40b": "tiiuae/falcon-40b",
     "falcon-rw-7b": "tiiuae/falcon-rw-7b",
 
     "aquila-7b": "BAAI/Aquila-7B",
@@ -291,7 +292,4 @@ SUPPORT_AGENT_MODEL = [
     "qwen-api",
     "Qwen",
     "chatglm3",
-    "xinghuo-api",
-    "internlm2-chat-7b",
-    "internlm2-chat-20b"
 ]
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 2bcce94e..7dd3a355 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):
 
     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
-        # import qianfan
-
-        # comp = qianfan.ChatCompletion(model=params.version,
-        #                               endpoint=params.version_url,
-        #                               ak=params.api_key,
-        #                               sk=params.secret_key,)
-        # text = ""
-        # for resp in comp.do(messages=params.messages,
-        #                     temperature=params.temperature,
-        #                     top_p=params.top_p,
-        #                     stream=True):
-        #     if resp.code == 200:
-        #         if chunk := resp.body.get("result"):
-        #             text += chunk
-        #             yield {
-        #                 "error_code": 0,
-        #                 "text": text
-        #             }
-        #     else:
-        #         yield {
-        #             "error_code": resp.code,
-        #             "text": str(resp.body),
-        #         }
-
         BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
                    '/{model_version}?access_token={access_token}'
@@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
         i = 0
         batch_size = 10
         while i < len(params.texts):
-            texts = params.texts[i:i+batch_size]
+            texts = params.texts[i:i + batch_size]
             resp = client.post(url, json={"input": texts}).json()
             if "error_code" in resp:
                 data = {
-                "code": resp["error_code"],
-                "msg": resp["error_msg"],
-                "error": {
-                    "message": resp["error_msg"],
-                    "type": "invalid_request_error",
-                    "param": None,
-                    "code": None,
-                }
-                }
+                    "code": resp["error_code"],
+                    "msg": resp["error_msg"],
+                    "error": {
+                        "message": resp["error_msg"],
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": None,
+                    }
+                }
                 self.logger.error(f"Error occurred while calling the Qianfan API: {data}")
                 return data
             else:
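The re-indented error envelope above sits inside QianFanWorker.do_embeddings' batching loop. A minimal self-contained sketch of that pattern follows for reference; it assumes `client` is an httpx.Client as elsewhere in this module, and the success response shape ({"data": [{"embedding": [...]}, ...]}) is an assumption, since the hunk ends before the extraction code:

from typing import Dict, List

import httpx


def embed_in_batches(client: httpx.Client, url: str, texts: List[str], batch_size: int = 10) -> Dict:
    # Send texts in groups of `batch_size`, mirroring the loop in the hunk above.
    embeddings = []
    i = 0
    while i < len(texts):
        resp = client.post(url, json={"input": texts[i:i + batch_size]}).json()
        if "error_code" in resp:
            # Surface Qianfan errors in the OpenAI-style envelope shown above.
            return {
                "code": resp["error_code"],
                "msg": resp["error_msg"],
                "error": {
                    "message": resp["error_msg"],
                    "type": "invalid_request_error",
                    "param": None,
                    "code": None,
                },
            }
        # Assumed success shape: {"data": [{"embedding": [...]}, ...]}
        embeddings.extend(d["embedding"] for d in resp.get("data", []))
        i += batch_size
    return {"code": 200, "data": embeddings}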
"error": { - "message": resp["error_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } + "code": resp["error_code"], + "msg": resp["error_msg"], + "error": { + "message": resp["error_msg"], + "type": "invalid_request_error", + "param": None, + "code": None, + } + } self.logger.error(f"请求千帆 API 时发生错误:{data}") return data else: diff --git a/server/model_workers/tiangong.py b/server/model_workers/tiangong.py index 85a763fe..e127ea55 100644 --- a/server/model_workers/tiangong.py +++ b/server/model_workers/tiangong.py @@ -11,16 +11,15 @@ from typing import List, Literal, Dict import requests - class TianGongWorker(ApiModelWorker): def __init__( - self, - *, - controller_addr: str = None, - worker_addr: str = None, - model_names: List[str] = ["tiangong-api"], - version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse", - **kwargs, + self, + *, + controller_addr: str = None, + worker_addr: str = None, + model_names: List[str] = ["tiangong-api"], + version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse", + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 32768) @@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker): data = { "messages": params.messages, "model": "SkyChat-MegaVerse" - } - timestamp = str(int(time.time())) - sign_content = params.api_key + params.secret_key + timestamp - sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest() - headers={ + } + timestamp = str(int(time.time())) + sign_content = params.api_key + params.secret_key + timestamp + sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest() + headers = { "app_key": params.api_key, "timestamp": timestamp, "sign": sign_result, "Content-Type": "application/json", - "stream": "true" # or change to "false" 不处理流式返回内容 + "stream": "true" # or change to "false" 不处理流式返回内容 } - + # 发起请求并获取响应 response = requests.post(url, headers=headers, json=data, stream=True) @@ -56,17 +55,17 @@ class TianGongWorker(ApiModelWorker): # 处理接收到的数据 # print(line.decode('utf-8')) resp = json.loads(line) - if resp["code"] == 200: + if resp["code"] == 200: text += resp['resp_data']['reply'] yield { "error_code": 0, "text": text - } + } else: data = { "error_code": resp["code"], "text": resp["code_msg"] - } + } self.logger.error(f"请求天工 API 时出错:{data}") yield data @@ -85,5 +84,3 @@ class TianGongWorker(ApiModelWorker): sep="\n### ", stop_str="###", ) - - diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 72db7389..1e772a33 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker): **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000,需要自行修改 + kwargs.setdefault("context_len", 8000) super().__init__(**kwargs) self.version = version diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 0005c7d3..552b67cc 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -4,93 +4,86 @@ from fastchat import conversation as conv import sys from typing import List, Dict, Iterator, Literal from configs import logger, log_verbose +import requests +import jwt +import time +import json + + +def generate_token(apikey: str, exp_seconds: int): + try: + id, secret = apikey.split(".") + except Exception as e: + raise Exception("invalid apikey", e) + + 
diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py
index 72db7389..1e772a33 100644
--- a/server/model_workers/xinghuo.py
+++ b/server/model_workers/xinghuo.py
@@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
         **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 8000)  # TODO: the V1 model's maximum length is 4000; adjust this yourself if needed
+        kwargs.setdefault("context_len", 8000)
         super().__init__(**kwargs)
         self.version = version
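The zhipu.py rewrite that follows drops the zhipuai SDK entirely: it signs its own JWT from the "id.secret" API key and posts directly to the OpenAI-compatible v4 endpoint. A self-contained sketch of that flow (requires PyJWT; the key and prompt are placeholders, and the non-streaming response shape is assumed to follow the OpenAI format):

import time

import jwt  # PyJWT
import requests

api_key = "id.secret"  # placeholder; Zhipu keys have the form "<id>.<secret>"
key_id, secret = api_key.split(".")
now_ms = int(round(time.time() * 1000))
token = jwt.encode(
    {"api_key": key_id, "exp": now_ms + 60 * 1000, "timestamp": now_ms},
    secret,
    algorithm="HS256",
    headers={"alg": "HS256", "sign_type": "SIGN"},
)
resp = requests.post(
    "https://open.bigmodel.cn/api/paas/v4/chat/completions",
    headers={"Content-Type": "application/json", "Authorization": f"Bearer {token}"},
    json={"model": "glm-4", "messages": [{"role": "user", "content": "Hello"}], "stream": False},
)
print(resp.json()["choices"][0]["message"]["content"])  # assumed OpenAI-style response shape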
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 0005c7d3..552b67cc 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -4,93 +4,86 @@ from fastchat import conversation as conv
 import sys
 from typing import List, Dict, Iterator, Literal
 from configs import logger, log_verbose
+import requests
+import jwt
+import time
+import json
+
+
+def generate_token(apikey: str, exp_seconds: int):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
 
 
 class ChatGLMWorker(ApiModelWorker):
-    DEFAULT_EMBED_MODEL = "text_embedding"
-
     def __init__(
-        self,
-        *,
-        model_names: List[str] = ["zhipu-api"],
-        controller_addr: str = None,
-        worker_addr: str = None,
-        version: Literal["chatglm_turbo"] = "chatglm_turbo",
-        **kwargs,
+            self,
+            *,
+            model_names: List[str] = ["zhipu-api"],
+            controller_addr: str = None,
+            worker_addr: str = None,
+            version: Literal["chatglm_turbo"] = "chatglm_turbo",
+            **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 32768)
+        kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)
         self.version = version
 
     def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
-        # TODO: keep track of request_id
-        import zhipuai
-
         params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        if log_verbose:
-            logger.info(f'{self.__class__.__name__}:params: {params}')
-
-        response = zhipuai.model_api.sse_invoke(
-            model=params.version,
-            prompt=params.messages,
-            temperature=params.temperature,
-            top_p=params.top_p,
-            incremental=False,
-        )
-        for e in response.events():
-            if e.event == "add":
-                yield {"error_code": 0, "text": e.data}
-            elif e.event in ["error", "interrupted"]:
-                data = {
-                    "error_code": 500,
-                    "text": e.data,
-                    "error": {
-                        "message": e.data,
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None,
-                    }
-                }
-                self.logger.error(f"Error occurred while calling the Zhipu API: {data}")
-                yield data
-
-    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
-        import zhipuai
-
-        params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        embeddings = []
-        try:
-            for t in params.texts:
-                response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
-                if response["code"] == 200:
-                    embeddings.append(response["data"]["embedding"])
-                else:
-                    self.logger.error(f"Error occurred while calling the Zhipu API: {response}")
-                    return response  # dict with code & msg
-        except Exception as e:
-            self.logger.error(f"Error occurred while calling the Zhipu API: {data}")
-            data = {"code": 500, "msg": f"Error while embedding texts: {e}"}
-            return data
-
-        return {"code": 200, "data": embeddings}
+        token = generate_token(params.api_key, 60)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {token}"
+        }
+        data = {
+            "model": params.version,
+            "messages": params.messages,
+            "max_tokens": params.max_tokens,
+            "temperature": params.temperature,
+            "stream": True
+        }
+        url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
+        response = requests.post(url, headers=headers, json=data, stream=True)
+        for chunk in response.iter_lines():
+            if chunk:
+                chunk_str = chunk.decode('utf-8')
+                json_start_pos = chunk_str.find('{"id"')
+                if json_start_pos != -1:
+                    json_str = chunk_str[json_start_pos:]
+                    json_data = json.loads(json_str)
+                    for choice in json_data.get('choices', []):
+                        delta = choice.get('delta', {})
+                        content = delta.get('content', '')
+                        yield {"error_code": 0, "text": content}
 
     def get_embeddings(self, params):
-        # TODO: support embeddings
+        # Temporary workaround: embeddings are not supported
         print("embedding")
-        # print(params)
+        print(params)
 
     def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # This is the ChatGLM API conversation template; other APIs need their own conv_template
         return conv.Conversation(
             name=self.model_names[0],
-            system_message="You are a clever assistant; please complete tasks according to the user's prompts",
+            system_message="You are the Zhipu AI assistant; please complete tasks according to the user's prompts",
             messages=[],
-            roles=["Human", "Assistant", "System"],
+            roles=["user", "assistant", "system"],
             sep="\n###",
             stop_str="###",
         )
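A caveat on the new do_chat above: unlike the other workers in this patch (e.g. TianGongWorker, which accumulates text += ... before yielding), it yields only each chunk's delta, and it locates the JSON by scanning for '{"id"' instead of stripping the standard "data: " SSE prefix, which is consistent with the known bugs noted in the commit message. A hedged sketch of a more conventional parse, assuming the v4 endpoint emits OpenAI-style SSE lines terminated by "data: [DONE]":

import json


def iter_sse_text(response):
    # `response` is a requests.Response obtained with stream=True, as in do_chat above.
    text = ""
    for line in response.iter_lines():
        if not line:
            continue
        chunk = line.decode("utf-8")
        if not chunk.startswith("data:"):
            continue
        payload = chunk[len("data:"):].strip()
        if payload == "[DONE]":  # assumed end-of-stream sentinel (OpenAI-compatible SSE)
            break
        for choice in json.loads(payload).get("choices", []):
            text += choice.get("delta", {}).get("content", "") or ""
            yield {"error_code": 0, "text": text}  # accumulated text, matching the other workers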