更新GLM 临时解决方案,支持GLM4,版本不兼容,会有bug

This commit is contained in:
zR 2024-01-21 11:48:44 +08:00
parent 0cf65d5933
commit e5b4bb41d8
5 changed files with 94 additions and 130 deletions

View File

@ -29,6 +29,7 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
LLM_MODELS = ["zhipu-api"] LLM_MODELS = ["zhipu-api"]
Agent_MODEL = None Agent_MODEL = None
# LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
LLM_DEVICE = "cuda" LLM_DEVICE = "cuda"
HISTORY_LEN = 3 HISTORY_LEN = 3
@ -45,10 +46,10 @@ ONLINE_LLM_MODEL = {
"openai_proxy": "", "openai_proxy": "",
}, },
# 智谱AI API不支持GLM4本版本无法兼容敬请期待0.3.x具体注册及api key获取请前往 http://open.bigmodel.cn # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": { "zhipu-api": {
"api_key": "", "api_key": "",
"version": "chatglm_turbo", "version": "glm-4",
"provider": "ChatGLMWorker", "provider": "ChatGLMWorker",
}, },
@ -119,7 +120,7 @@ ONLINE_LLM_MODEL = {
"secret_key": "", "secret_key": "",
"provider": "TianGongWorker", "provider": "TianGongWorker",
}, },
# Gemini API (开发组未测试由社群提供只支持pro # Gemini API (开发组未测试由社群提供只支持prohttps://makersuite.google.com/或者google cloud使用前先确认网络正常使用代理请在项目启动python startup.py -a)环境内设置https_proxy环境变量
"gemini-api": { "gemini-api": {
"api_key": "", "api_key": "",
"provider": "GeminiWorker", "provider": "GeminiWorker",
@ -196,7 +197,7 @@ MODEL_PATH = {
"agentlm-70b": "THUDM/agentlm-70b", "agentlm-70b": "THUDM/agentlm-70b",
"falcon-7b": "tiiuae/falcon-7b", "falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40,b", "falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b", "falcon-rw-7b": "tiiuae/falcon-rw-7b",
"aquila-7b": "BAAI/Aquila-7B", "aquila-7b": "BAAI/Aquila-7B",
@ -291,7 +292,4 @@ SUPPORT_AGENT_MODEL = [
"qwen-api", "qwen-api",
"Qwen", "Qwen",
"chatglm3", "chatglm3",
"xinghuo-api",
"internlm2-chat-7b",
"internlm2-chat-20b"
] ]

View File

@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):
def do_chat(self, params: ApiChatParams) -> Dict: def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0]) params.load_config(self.model_names[0])
# import qianfan
# comp = qianfan.ChatCompletion(model=params.version,
# endpoint=params.version_url,
# ak=params.api_key,
# sk=params.secret_key,)
# text = ""
# for resp in comp.do(messages=params.messages,
# temperature=params.temperature,
# top_p=params.top_p,
# stream=True):
# if resp.code == 200:
# if chunk := resp.body.get("result"):
# text += chunk
# yield {
# "error_code": 0,
# "text": text
# }
# else:
# yield {
# "error_code": resp.code,
# "text": str(resp.body),
# }
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \ BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}' '/{model_version}?access_token={access_token}'
@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
i = 0 i = 0
batch_size = 10 batch_size = 10
while i < len(params.texts): while i < len(params.texts):
texts = params.texts[i:i+batch_size] texts = params.texts[i:i + batch_size]
resp = client.post(url, json={"input": texts}).json() resp = client.post(url, json={"input": texts}).json()
if "error_code" in resp: if "error_code" in resp:
data = { data = {
"code": resp["error_code"], "code": resp["error_code"],
"msg": resp["error_msg"], "msg": resp["error_msg"],
"error": { "error": {
"message": resp["error_msg"], "message": resp["error_msg"],
"type": "invalid_request_error", "type": "invalid_request_error",
"param": None, "param": None,
"code": None, "code": None,
} }
} }
self.logger.error(f"请求千帆 API 时发生错误:{data}") self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data return data
else: else:

View File

@ -11,16 +11,15 @@ from typing import List, Literal, Dict
import requests import requests
class TianGongWorker(ApiModelWorker): class TianGongWorker(ApiModelWorker):
def __init__( def __init__(
self, self,
*, *,
controller_addr: str = None, controller_addr: str = None,
worker_addr: str = None, worker_addr: str = None,
model_names: List[str] = ["tiangong-api"], model_names: List[str] = ["tiangong-api"],
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse", version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768) kwargs.setdefault("context_len", 32768)
@ -38,12 +37,12 @@ class TianGongWorker(ApiModelWorker):
timestamp = str(int(time.time())) timestamp = str(int(time.time()))
sign_content = params.api_key + params.secret_key + timestamp sign_content = params.api_key + params.secret_key + timestamp
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest() sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
headers={ headers = {
"app_key": params.api_key, "app_key": params.api_key,
"timestamp": timestamp, "timestamp": timestamp,
"sign": sign_result, "sign": sign_result,
"Content-Type": "application/json", "Content-Type": "application/json",
"stream": "true" # or change to "false" 不处理流式返回内容 "stream": "true" # or change to "false" 不处理流式返回内容
} }
# 发起请求并获取响应 # 发起请求并获取响应
@ -61,12 +60,12 @@ class TianGongWorker(ApiModelWorker):
yield { yield {
"error_code": 0, "error_code": 0,
"text": text "text": text
} }
else: else:
data = { data = {
"error_code": resp["code"], "error_code": resp["code"],
"text": resp["code_msg"] "text": resp["code_msg"]
} }
self.logger.error(f"请求天工 API 时出错:{data}") self.logger.error(f"请求天工 API 时出错:{data}")
yield data yield data
@ -85,5 +84,3 @@ class TianGongWorker(ApiModelWorker):
sep="\n### ", sep="\n### ",
stop_str="###", stop_str="###",
) )

View File

@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000需要自行修改 kwargs.setdefault("context_len", 8000)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version

View File

@ -4,93 +4,86 @@ from fastchat import conversation as conv
import sys import sys
from typing import List, Dict, Iterator, Literal from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose from configs import logger, log_verbose
import requests
import jwt
import time
import json
def generate_token(apikey: str, exp_seconds: int):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)
payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}
return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
class ChatGLMWorker(ApiModelWorker): class ChatGLMWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text_embedding"
def __init__( def __init__(
self, self,
*, *,
model_names: List[str] = ["zhipu-api"], model_names: List[str] = ["zhipu-api"],
controller_addr: str = None, controller_addr: str = None,
worker_addr: str = None, worker_addr: str = None,
version: Literal["chatglm_turbo"] = "chatglm_turbo", version: Literal["chatglm_turbo"] = "chatglm_turbo",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768) kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]: def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
# TODO: 维护request_id
import zhipuai
params.load_config(self.model_names[0]) params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key token = generate_token(params.api_key, 60)
headers = {
if log_verbose: "Content-Type": "application/json",
logger.info(f'{self.__class__.__name__}:params: {params}') "Authorization": f"Bearer {token}"
}
response = zhipuai.model_api.sse_invoke( data = {
model=params.version, "model": params.version,
prompt=params.messages, "messages": params.messages,
temperature=params.temperature, "max_tokens": params.max_tokens,
top_p=params.top_p, "temperature": params.temperature,
incremental=False, "stream": True
) }
for e in response.events(): url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
if e.event == "add": response = requests.post(url, headers=headers, json=data, stream=True)
yield {"error_code": 0, "text": e.data} for chunk in response.iter_lines():
elif e.event in ["error", "interrupted"]: if chunk:
data = { chunk_str = chunk.decode('utf-8')
"error_code": 500, json_start_pos = chunk_str.find('{"id"')
"text": e.data, if json_start_pos != -1:
"error": { json_str = chunk_str[json_start_pos:]
"message": e.data, json_data = json.loads(json_str)
"type": "invalid_request_error", for choice in json_data.get('choices', []):
"param": None, delta = choice.get('delta', {})
"code": None, content = delta.get('content', '')
} yield {"error_code": 0, "text": content}
}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
embeddings = []
try:
for t in params.texts:
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg
except Exception as e:
self.logger.error(f"请求智谱 API 时发生错误:{data}")
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return data
return {"code": 200, "data": embeddings}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # 临时解决方案不支持embedding
print("embedding") print("embedding")
# print(params) print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# 这里的是chatglm api的模板其它API的conv_template需要定制
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务", system_message="你是智谱AI小助手请根据用户的提示来完成任务",
messages=[], messages=[],
roles=["Human", "Assistant", "System"], roles=["user", "assistant", "system"],
sep="\n###", sep="\n###",
stop_str="###", stop_str="###",
) )