Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-19 13:23:16 +08:00
Update GLM temporary workaround: supports GLM-4. The versions are incompatible, so there will be bugs.
This commit is contained in: parent 0cf65d5933, commit e5b4bb41d8
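In outline, the workaround in this commit drops the zhipuai SDK call and talks to the open.bigmodel.cn v4 REST endpoint directly, authenticating with a short-lived JWT derived from the "id.secret" API key. A condensed, self-contained sketch of that flow, using a placeholder key (the real logic lives in the zhipu.py change below):

    import time
    import jwt  # PyJWT
    import requests

    def generate_token(apikey: str, exp_seconds: int) -> str:
        # Zhipu API keys have the form "id.secret"; the JWT is signed with the secret half.
        key_id, secret = apikey.split(".")
        now_ms = int(round(time.time() * 1000))
        payload = {"api_key": key_id, "exp": now_ms + exp_seconds * 1000, "timestamp": now_ms}
        return jwt.encode(payload, secret, algorithm="HS256",
                          headers={"alg": "HS256", "sign_type": "SIGN"})

    token = generate_token("your-id.your-secret", 60)  # placeholder key
    resp = requests.post(
        "https://open.bigmodel.cn/api/paas/v4/chat/completions",
        headers={"Content-Type": "application/json", "Authorization": f"Bearer {token}"},
        json={"model": "glm-4", "messages": [{"role": "user", "content": "你好"}], "stream": False},
    )
    print(resp.json())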
@@ -29,6 +29,7 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
LLM_MODELS = ["zhipu-api"]
Agent_MODEL = None

# Device on which the LLM runs. "auto" detects it automatically (with a warning); it can also be set manually to one of "cuda", "mps", "cpu", or "xpu".
LLM_DEVICE = "cuda"

HISTORY_LEN = 3
@@ -45,10 +46,10 @@ ONLINE_LLM_MODEL = {
        "openai_proxy": "",
    },

-    # Zhipu AI API (GLM-4 is not supported; this version cannot be made compatible, please wait for 0.3.x). For registration and API keys see http://open.bigmodel.cn
+    # Zhipu AI API. For registration and API keys see http://open.bigmodel.cn
    "zhipu-api": {
        "api_key": "",
-        "version": "chatglm_turbo",
+        "version": "glm-4",
        "provider": "ChatGLMWorker",
    },
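For illustration, a filled-in entry might look like the following; the key is a placeholder, and real keys issued at open.bigmodel.cn have the "id.secret" form that the JWT helper in zhipu.py splits on the dot:

    ONLINE_LLM_MODEL = {
        # ...
        "zhipu-api": {
            "api_key": "your-id.your-secret",  # placeholder; obtain a real key at http://open.bigmodel.cn
            "version": "glm-4",
            "provider": "ChatGLMWorker",
        },
    }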
@@ -119,7 +120,7 @@ ONLINE_LLM_MODEL = {
        "secret_key": "",
        "provider": "TianGongWorker",
    },
-    # Gemini API (untested by the dev team, community-contributed; only supports pro)
+    # Gemini API (untested by the dev team, community-contributed; only supports pro). See https://makersuite.google.com/ or Google Cloud. Confirm your network works before use; if you need a proxy, set the https_proxy environment variable in the environment where the project is started (python startup.py -a).
    "gemini-api": {
        "api_key": "",
        "provider": "GeminiWorker",
@@ -196,7 +197,7 @@ MODEL_PATH = {
    "agentlm-70b": "THUDM/agentlm-70b",

    "falcon-7b": "tiiuae/falcon-7b",
-    "falcon-40b": "tiiuae/falcon-40,b",
+    "falcon-40b": "tiiuae/falcon-40b",
    "falcon-rw-7b": "tiiuae/falcon-rw-7b",

    "aquila-7b": "BAAI/Aquila-7B",
@@ -291,7 +292,4 @@ SUPPORT_AGENT_MODEL = [
    "qwen-api",
    "Qwen",
    "chatglm3",
    "xinghuo-api",
    "internlm2-chat-7b",
    "internlm2-chat-20b"
]
@@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])
-        # import qianfan
-
-        # comp = qianfan.ChatCompletion(model=params.version,
-        #                               endpoint=params.version_url,
-        #                               ak=params.api_key,
-        #                               sk=params.secret_key,)
-        # text = ""
-        # for resp in comp.do(messages=params.messages,
-        #                     temperature=params.temperature,
-        #                     top_p=params.top_p,
-        #                     stream=True):
-        #     if resp.code == 200:
-        #         if chunk := resp.body.get("result"):
-        #             text += chunk
-        #             yield {
-        #                 "error_code": 0,
-        #                 "text": text
-        #             }
-        #     else:
-        #         yield {
-        #             "error_code": resp.code,
-        #             "text": str(resp.body),
-        #         }

        BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
                   '/{model_version}?access_token={access_token}'
@@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
        i = 0
        batch_size = 10
        while i < len(params.texts):
-            texts = params.texts[i:i+batch_size]
+            texts = params.texts[i:i + batch_size]
            resp = client.post(url, json={"input": texts}).json()
            if "error_code" in resp:
                data = {
                    "code": resp["error_code"],
                    "msg": resp["error_msg"],
                    "error": {
                        "message": resp["error_msg"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千帆 API 时发生错误:{data}")
                return data
            else:
@@ -11,16 +11,15 @@ from typing import List, Literal, Dict
import requests


class TianGongWorker(ApiModelWorker):
    def __init__(
        self,
        *,
        controller_addr: str = None,
        worker_addr: str = None,
        model_names: List[str] = ["tiangong-api"],
        version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 32768)
@@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker):
        data = {
            "messages": params.messages,
            "model": "SkyChat-MegaVerse"
        }
        timestamp = str(int(time.time()))
        sign_content = params.api_key + params.secret_key + timestamp
        sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
-        headers={
+        headers = {
            "app_key": params.api_key,
            "timestamp": timestamp,
            "sign": sign_result,
            "Content-Type": "application/json",
            "stream": "true"  # or change to "false" to not handle streaming responses
        }

        # Send the request and get the response
        response = requests.post(url, headers=headers, json=data, stream=True)
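The signing scheme above can be reproduced in isolation when debugging authentication failures; a minimal sketch with placeholder credentials (the real values come from the tiangong-api entry in ONLINE_LLM_MODEL):

    import hashlib
    import time

    app_key = "your-app-key"        # placeholder
    secret_key = "your-secret-key"  # placeholder

    timestamp = str(int(time.time()))
    # The API expects sign = md5(app_key + secret_key + timestamp), hex-encoded.
    sign = hashlib.md5((app_key + secret_key + timestamp).encode("utf-8")).hexdigest()
    headers = {
        "app_key": app_key,
        "timestamp": timestamp,
        "sign": sign,
        "Content-Type": "application/json",
        "stream": "true",
    }
    print(headers)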
@@ -56,17 +55,17 @@ class TianGongWorker(ApiModelWorker):
                # Process the received data
                # print(line.decode('utf-8'))
                resp = json.loads(line)
                if resp["code"] == 200:
                    text += resp['resp_data']['reply']
                    yield {
                        "error_code": 0,
                        "text": text
                    }
                else:
                    data = {
                        "error_code": resp["code"],
                        "text": resp["code_msg"]
                    }
                    self.logger.error(f"请求天工 API 时出错:{data}")
                    yield data
@@ -85,5 +84,3 @@ class TianGongWorker(ApiModelWorker):
        sep="\n### ",
        stop_str="###",
    )
@@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 8000)  # TODO: the V1 model's maximum length is 4000; change this yourself if needed
+        kwargs.setdefault("context_len", 8000)
        super().__init__(**kwargs)
        self.version = version
@@ -4,93 +4,86 @@ from fastchat import conversation as conv
import sys
from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose
+import requests
+import jwt
+import time
+import json
+
+
+def generate_token(apikey: str, exp_seconds: int):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )


class ChatGLMWorker(ApiModelWorker):
    DEFAULT_EMBED_MODEL = "text_embedding"

    def __init__(
        self,
        *,
        model_names: List[str] = ["zhipu-api"],
        controller_addr: str = None,
        worker_addr: str = None,
        version: Literal["chatglm_turbo"] = "chatglm_turbo",
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 32768)
+        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
-        # TODO: maintain request_id
-        import zhipuai
-
        params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        if log_verbose:
-            logger.info(f'{self.__class__.__name__}:params: {params}')
-
-        response = zhipuai.model_api.sse_invoke(
-            model=params.version,
-            prompt=params.messages,
-            temperature=params.temperature,
-            top_p=params.top_p,
-            incremental=False,
-        )
-        for e in response.events():
-            if e.event == "add":
-                yield {"error_code": 0, "text": e.data}
-            elif e.event in ["error", "interrupted"]:
-                data = {
-                    "error_code": 500,
-                    "text": e.data,
-                    "error": {
-                        "message": e.data,
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None,
-                    }
-                }
-                self.logger.error(f"请求智谱 API 时发生错误:{data}")
-                yield data
-
-    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
-        import zhipuai
-
-        params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        embeddings = []
-        try:
-            for t in params.texts:
-                response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
-                if response["code"] == 200:
-                    embeddings.append(response["data"]["embedding"])
-                else:
-                    self.logger.error(f"请求智谱 API 时发生错误:{response}")
-                    return response  # dict with code & msg
-        except Exception as e:
-            self.logger.error(f"请求智谱 API 时发生错误:{data}")
-            data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
-            return data
-
-        return {"code": 200, "data": embeddings}
+        token = generate_token(params.api_key, 60)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {token}"
+        }
+        data = {
+            "model": params.version,
+            "messages": params.messages,
+            "max_tokens": params.max_tokens,
+            "temperature": params.temperature,
+            "stream": True
+        }
+        url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
+        response = requests.post(url, headers=headers, json=data, stream=True)
+        for chunk in response.iter_lines():
+            if chunk:
+                chunk_str = chunk.decode('utf-8')
+                json_start_pos = chunk_str.find('{"id"')
+                if json_start_pos != -1:
+                    json_str = chunk_str[json_start_pos:]
+                    json_data = json.loads(json_str)
+                    for choice in json_data.get('choices', []):
+                        delta = choice.get('delta', {})
+                        content = delta.get('content', '')
+                        yield {"error_code": 0, "text": content}

    def get_embeddings(self, params):
        # TODO: support embeddings
+        # Temporary workaround: embeddings are not supported
        print("embedding")
-        # print(params)
+        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # This is the conversation template for the chatglm API; other APIs need a custom conv_template
        return conv.Conversation(
            name=self.model_names[0],
-            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+            system_message="你是智谱AI小助手,请根据用户的提示来完成任务",
            messages=[],
-            roles=["Human", "Assistant", "System"],
+            roles=["user", "assistant", "system"],
            sep="\n###",
            stop_str="###",
        )
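The chunk handling above searches each line for a literal '{"id"' and ignores everything else, so the SSE framing and the terminating "[DONE]" sentinel are dropped silently; this is part of why the commit message warns about bugs. A sketch of a more conventional parse, assuming the v4 endpoint emits OpenAI-style "data: <json>" lines:

    import json
    import requests

    def stream_glm4_text(url: str, headers: dict, data: dict):
        # Assumes OpenAI-style SSE framing: each event line is "data: <json>"
        # and the stream ends with "data: [DONE]".
        with requests.post(url, headers=headers, json=data, stream=True) as response:
            for raw in response.iter_lines():
                if not raw:
                    continue
                line = raw.decode("utf-8")
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload.strip() == "[DONE]":
                    break
                for choice in json.loads(payload).get("choices", []):
                    yield choice.get("delta", {}).get("content", "")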