mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
Remove local fschat workers
This commit is contained in:
parent 175db6710e
commit a7b28adc09

server/model_workers/SparkApi.py
@@ -1,79 +0,0 @@
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time


class Ws_Param(object):
    # Initialization
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # Generate the signed request URL
    def create_url(self):
        # Generate an RFC 1123 formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # Assemble the string to sign
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # Sign it with HMAC-SHA256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Collect the auth parameters into a dict
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # Append the auth parameters to build the final URL
        url = self.Spark_url + '?' + urlencode(v)
        # When following this demo, print the URL built here and compare it,
        # with identical parameters, against the URL your own code generates.
        return url


def gen_params(appid, domain, question, temperature, max_token):
    """
    Build the request payload from the appid and the user's question.
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": max_token,
                "auditing": "default",
                "temperature": temperature,
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data
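
For reference, a minimal sketch of how the two helpers above fit together. The credentials are placeholders, and the module path follows the "from server.model_workers import SparkApi" import used by xinghuo.py further down in this diff; the endpoint and domain values come from xinghuo.py's version_mapping:

# Usage sketch with dummy credentials.
from server.model_workers.SparkApi import Ws_Param, gen_params

ws_param = Ws_Param("my-appid", "my-api-key", "my-api-secret",
                    "ws://spark-api.xf-yun.com/v3.1/chat")
url = ws_param.create_url()  # signed URL; only valid near the Date it embeds
payload = gen_params("my-appid", "generalv3",
                     [{"role": "user", "content": "hello"}],
                     temperature=0.5, max_token=2048)
# xinghuo.py's request() coroutine opens `url` as a websocket and sends
# `payload` as the first frame.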

server/model_workers/__init__.py
@@ -1,11 +0,0 @@
from .base import *
from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
from .gemini import GeminiWorker

server/model_workers/azure.py
@@ -1,95 +0,0 @@
import sys
import json
from typing import List, Dict

from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers.base import *
from server.utils import get_httpx_client
from configs import logger, log_verbose


class AzureWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["azure-api"],
            version: str = "gpt-35-turbo",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])

        data = dict(
            messages=params.messages,
            temperature=params.temperature,
            max_tokens=params.max_tokens if params.max_tokens else None,
            stream=True,
        )
        url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
               .format(params.resource_name, params.deployment_name, params.api_version))
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'api-key': params.api_key,
        }

        text = ""
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        with get_httpx_client() as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                for line in response.iter_lines():
                    if not line.strip() or "[DONE]" in line:
                        continue
                    if line.startswith("data: "):
                        line = line[6:]
                    resp = json.loads(line)
                    if choices := resp["choices"]:
                        if chunk := choices[0].get("delta", {}).get("content"):
                            text += chunk
                            yield {
                                "error_code": 0,
                                "text": text
                            }
                    else:
                        self.logger.error(f"An error occurred while requesting the Azure API: {resp}")

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful, respectful and honest assistant.",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    worker = AzureWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21008",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21008)
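
The streaming loop above accumulates delta.content fragments from SSE-style lines. A self-contained sketch of that parsing logic, run against canned lines instead of a live Azure response:

import json

lines = [
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    "data: [DONE]",
]
text = ""
for line in lines:
    if not line.strip() or "[DONE]" in line:
        continue
    if line.startswith("data: "):
        line = line[6:]
    if chunk := json.loads(line)["choices"][0].get("delta", {}).get("content"):
        text += chunk  # each yield would carry the accumulated text so far
print(text)  # -> Hello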

server/model_workers/baichuan.py
@@ -1,117 +0,0 @@
import json
import sys
import time
import hashlib
from typing import List, Literal, Dict

from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers.base import *
from server.utils import get_httpx_client
from configs import logger, log_verbose


def calculate_md5(input_string):
    md5 = hashlib.md5()
    md5.update(input_string.encode('utf-8'))
    encrypted = md5.hexdigest()
    return encrypted


class BaiChuanWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["baichuan-api"],
            version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 32768)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])

        url = "https://api.baichuan-ai.com/v1/stream/chat"
        data = {
            "model": params.version,
            "messages": params.messages,
            "parameters": {"temperature": params.temperature}
        }

        json_data = json.dumps(data)
        time_stamp = int(time.time())
        signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + params.api_key,
            "X-BC-Request-Id": "your requestId",
            "X-BC-Timestamp": str(time_stamp),
            "X-BC-Signature": signature,
            "X-BC-Sign-Algo": "MD5",
        }

        text = ""
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')

        with get_httpx_client() as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                for line in response.iter_lines():
                    if not line.strip():
                        continue
                    resp = json.loads(line)
                    if resp["code"] == 0:
                        text += resp["data"]["messages"][-1]["content"]
                        yield {
                            "error_code": resp["code"],
                            "text": text
                        }
                    else:
                        data = {
                            "error_code": resp["code"],
                            "text": resp["msg"],
                            "error": {
                                "message": resp["msg"],
                                "type": "invalid_request_error",
                                "param": None,
                                "code": None,
                            }
                        }
                        self.logger.error(f"An error occurred while requesting the Baichuan API: {data}")
                        yield data

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = BaiChuanWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21007",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21007)
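
Baichuan authenticates each call by signing md5(secret_key + request body + timestamp). A tiny self-contained illustration with dummy values:

import hashlib
import json
import time

secret_key = "dummy-secret"  # placeholder, not a real key
json_data = json.dumps({"model": "Baichuan2-53B", "messages": []})
time_stamp = int(time.time())
signature = hashlib.md5(
    (secret_key + json_data + str(time_stamp)).encode("utf-8")
).hexdigest()
print(signature)  # hex digest sent as the X-BC-Signature header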

server/model_workers/base.py
@@ -1,248 +0,0 @@
from fastchat.conversation import Conversation
from configs import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel, root_validator
import asyncio
from server.utils import get_model_worker_config
from typing import Dict, List, Optional


__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]


class ApiConfigParams(BaseModel):
    '''
    Configuration parameters for online APIs. Values not provided here are
    read automatically from model_config.ONLINE_LLM_MODEL.
    '''
    api_base_url: Optional[str] = None
    api_proxy: Optional[str] = None
    api_key: Optional[str] = None
    secret_key: Optional[str] = None
    group_id: Optional[str] = None  # for minimax
    is_pro: bool = False  # for minimax

    APPID: Optional[str] = None  # for xinghuo
    APISecret: Optional[str] = None  # for xinghuo
    is_v2: bool = False  # for xinghuo

    worker_name: Optional[str] = None

    class Config:
        extra = "allow"

    @root_validator(pre=True)
    def validate_config(cls, v: Dict) -> Dict:
        if config := get_model_worker_config(v.get("worker_name")):
            for n in cls.__fields__:
                if n in config:
                    v[n] = config[n]
        return v

    def load_config(self, worker_name: str):
        self.worker_name = worker_name
        if config := get_model_worker_config(worker_name):
            for n in self.__fields__:
                if n in config:
                    setattr(self, n, config[n])
        return self


class ApiModelParams(ApiConfigParams):
    '''
    Model parameters.
    '''
    version: Optional[str] = None
    version_url: Optional[str] = None
    api_version: Optional[str] = None  # for azure
    deployment_name: Optional[str] = None  # for azure
    resource_name: Optional[str] = None  # for azure

    temperature: float = 0.9
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0


class ApiChatParams(ApiModelParams):
    '''
    Parameters for chat requests.
    '''
    messages: List[Dict[str, str]]
    system_message: Optional[str] = None  # for minimax
    role_meta: Dict = {}  # for minimax


class ApiCompletionParams(ApiModelParams):
    prompt: str


class ApiEmbeddingsParams(ApiConfigParams):
    texts: List[str]
    embed_model: Optional[str] = None
    to_query: bool = False  # for minimax


class ApiModelWorker(BaseModelWorker):
    DEFAULT_EMBED_MODEL: str = None  # None means embeddings are not supported

    def __init__(
            self,
            model_names: List[str],
            controller_addr: str = None,
            worker_addr: str = None,
            context_len: int = 2048,
            no_register: bool = False,
            **kwargs,
    ):
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                         controller_addr=controller_addr,
                         worker_addr=worker_addr,
                         **kwargs)
        import fastchat.serve.base_model_worker
        self.logger = fastchat.serve.base_model_worker.logger
        # Restore the standard streams that fastchat redirects
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()

    def count_token(self, params):
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        self.call_ct += 1

        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else:  # approximate continuation via chat; message history is not supported
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            for resp in self.do_chat(p):
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]} got an error while requesting the API: {e}"})

    def generate_gate(self, params):
        try:
            for x in self.generate_stream_gate(params):
                ...
            return json.loads(x[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}

    # methods that subclasses are expected to override

    def do_chat(self, params: ApiChatParams) -> Dict:
        '''
        Perform a chat request; by default this delegates to the module's chat function.
        Expected return shape: {"error_code": int, "text": str}
        '''
        return {"error_code": 500, "text": f"{self.model_names[0]} does not implement chat"}

    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     Perform a completion request; by default this delegates to the module's completion function.
    #     Expected return shape: {"error_code": int, "text": str}
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]} does not implement completion"}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        Compute embeddings; by default this delegates to the module's embed_documents function.
        Expected return shape: {"code": int, "data": List[List[float]], "msg": str}
        '''
        return {"code": 500, "msg": f"{self.model_names[0]} does not implement embeddings"}

    def get_embeddings(self, params):
        # fastchat is quite restrictive about LLM embeddings and appears to
        # support only the openai style; requests issued from the frontend via
        # OpenAIEmbeddings fail outright and never reach this worker.
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        Some APIs require a special message format; override this function to
        replace the default messages. It is kept separate from prompt_to_messages
        because the two serve different scenarios and take different parameters.
        '''
        return messages

    # helper methods

    @property
    def user_role(self):
        return self.conv.roles[0]

    @property
    def ai_role(self):
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> bytes:
        '''
        Encode a chat result in the format expected by fastchat's openai-api-server.
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def _is_chat(self, prompt: str) -> bool:
        '''
        Check whether the prompt was assembled from chat messages.
        TODO: this can misfire; passing the original messages straight
        through from fastchat might be a better approach.
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        Split a prompt string back into messages.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result

    @classmethod
    def can_embedding(cls):
        return cls.DEFAULT_EMBED_MODEL is not None
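
To see what _is_chat and prompt_to_messages expect, here is the split logic mirrored on a hand-built fastchat-style prompt, assuming user/assistant roles and the "\n### " separator used by the workers in this commit:

sep = "\n### "
prompt = sep.join([
    "system prompt",
    "user: What is 2 + 2?",
    "assistant: 4",
    "user: And 3 + 3?",
    "assistant:",  # empty slot left open for the reply
])
# split(sep)[1:-1] drops the system prefix and the empty tail,
# just as ApiModelWorker.prompt_to_messages does
for msg in prompt.split(sep)[1:-1]:
    role, _, content = msg.partition(":")
    if content.strip():
        print({"role": role, "content": content.strip()})
# -> {'role': 'user', 'content': 'What is 2 + 2?'}
#    {'role': 'assistant', 'content': '4'}
#    {'role': 'user', 'content': 'And 3 + 3?'}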

server/model_workers/fangzhou.py
@@ -1,105 +0,0 @@
import sys
from typing import List, Literal, Dict

from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers.base import *
from configs import logger, log_verbose


class FangZhouWorker(ApiModelWorker):
    """
    Volcano Engine Ark (FangZhou)
    """

    def __init__(
            self,
            *,
            model_names: List[str] = ["fangzhou-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        from volcengine.maas import MaasService

        params.load_config(self.model_names[0])
        maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
        maas.set_ak(params.api_key)
        maas.set_sk(params.secret_key)

        # documentation: https://www.volcengine.com/docs/82379/1099475
        req = {
            "model": {
                "name": params.version,
            },
            "parameters": {
                # These parameters are examples only; see each model's API
                # documentation for the parameters it actually supports.
                "max_new_tokens": params.max_tokens,
                "temperature": params.temperature,
            },
            "messages": params.messages,
        }

        text = ""
        if log_verbose:
            self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
        for resp in maas.stream_chat(req):
            if error := resp.error:
                if error.code_n > 0:
                    data = {
                        "error_code": error.code_n,
                        "text": error.message,
                        "error": {
                            "message": error.message,
                            "type": "invalid_request_error",
                            "param": None,
                            "code": None,
                        }
                    }
                    self.logger.error(f"An error occurred while requesting the Ark API: {data}")
                    yield data
            elif chunk := resp.choice.message.content:
                text += chunk
                yield {"error_code": 0, "text": text}
            else:
                data = {
                    "error_code": 500,
                    "text": f"An unknown error occurred while requesting the Ark API: {resp}"
                }
                self.logger.error(data)
                yield data
                break

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = FangZhouWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21005",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21005)
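
Every worker in this commit reports failures through the same OpenAI-style error envelope. A helper distilling that shape (a pattern summary, not a function that exists in the repo):

from typing import Dict

def error_envelope(code: int, message: str) -> Dict:
    """The error shape these workers yield back to fastchat."""
    return {
        "error_code": code,
        "text": message,
        "error": {
            "message": message,
            "type": "invalid_request_error",
            "param": None,
            "code": None,
        },
    }

print(error_envelope(500, "Unknown error from the Ark API"))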

server/model_workers/minimax.py
@@ -1,170 +0,0 @@
import sys
import json
from typing import List, Dict

from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import get_httpx_client
from configs import logger, log_verbose


class MiniMaxWorker(ApiModelWorker):
    DEFAULT_EMBED_MODEL = "embo-01"

    def __init__(
            self,
            *,
            model_names: List[str] = ["minimax-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: str = "abab5.5-chat",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
        self.version = version

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        # Map both this worker's internal role names (USER/BOT) and
        # OpenAI-style ones (user/assistant) to MiniMax sender types.
        role_maps = {
            "USER": self.user_role,
            "user": self.user_role,
            "BOT": self.ai_role,
            "assistant": self.ai_role,
            "system": "system",
        }
        messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
        return messages

    def do_chat(self, params: ApiChatParams) -> Dict:
        # Call the abab5.5 model directly, as the official docs recommend.
        params.load_config(self.model_names[0])

        url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
        pro = "_pro" if params.is_pro else ""
        headers = {
            "Authorization": f"Bearer {params.api_key}",
            "Content-Type": "application/json",
        }
        messages = self.validate_messages(params.messages)
        data = {
            "model": params.version,
            "stream": True,
            "mask_sensitive_info": True,
            "messages": messages,
            "temperature": params.temperature,
            "top_p": params.top_p,
            "tokens_to_generate": params.max_tokens or 1024,
            # The parameters below are MiniMax-specific; passing empty values causes errors.
            # "prompt": params.system_message or self.conv.system_message,
            # "bot_setting": [],
            # "role_meta": params.role_meta,
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:data: {data}')
            logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')

        with get_httpx_client() as client:
            response = client.stream("POST",
                                     url.format(pro=pro, group_id=params.group_id),
                                     headers=headers,
                                     json=data)
            with response as r:
                text = ""
                for e in r.iter_text():
                    if not e.startswith("data: "):
                        data = {
                            "error_code": 500,
                            "text": f"MiniMax returned an unexpected result: {e}",
                            "error": {
                                "message": f"MiniMax returned an unexpected result: {e}",
                                "type": "invalid_request_error",
                                "param": None,
                                "code": None,
                            }
                        }
                        self.logger.error(f"An error occurred while requesting the MiniMax API: {data}")
                        yield data
                        continue

                    data = json.loads(e[6:])
                    if data.get("usage"):
                        break

                    if choices := data.get("choices"):
                        if chunk := choices[0].get("delta", ""):
                            text += chunk
                            yield {"error_code": 0, "text": text}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        params.load_config(self.model_names[0])
        url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"

        headers = {
            "Authorization": f"Bearer {params.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "model": params.embed_model or self.DEFAULT_EMBED_MODEL,
            "texts": [],
            "type": "query" if params.to_query else "db",
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:data: {data}')
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')

        with get_httpx_client() as client:
            result = []
            i = 0
            batch_size = 10
            while i < len(params.texts):
                texts = params.texts[i:i+batch_size]
                data["texts"] = texts
                r = client.post(url, headers=headers, json=data).json()
                if embeddings := r.get("vectors"):
                    result += embeddings
                elif error := r.get("base_resp"):
                    data = {
                        "code": error["status_code"],
                        "msg": error["status_msg"],
                        "error": {
                            "message": error["status_msg"],
                            "type": "invalid_request_error",
                            "param": None,
                            "code": None,
                        }
                    }
                    self.logger.error(f"An error occurred while requesting the MiniMax API: {data}")
                    return data
                i += batch_size
            return {"code": 200, "data": result}

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。",
            messages=[],
            roles=["USER", "BOT"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = MiniMaxWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21002",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21002)
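
The MiniMax payload uses sender_type/text rather than role/content. A standalone sketch of the conversion performed by validate_messages above, with role names taken from the worker's conversation template:

role_maps = {"USER": "USER", "BOT": "BOT", "assistant": "BOT", "system": "system"}
messages = [
    {"role": "USER", "content": "Hi"},
    {"role": "BOT", "content": "Hello!"},
]
converted = [{"sender_type": role_maps[m["role"]], "text": m["content"]}
             for m in messages]
print(converted)
# -> [{'sender_type': 'USER', 'text': 'Hi'}, {'sender_type': 'BOT', 'text': 'Hello!'}]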

server/model_workers/qianfan.py
@@ -1,216 +0,0 @@
import sys
import json
from typing import List, Literal, Dict

from fastchat.conversation import Conversation
from fastchat import conversation as conv
from cachetools import cached, TTLCache

from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import get_httpx_client
from configs import logger, log_verbose


MODEL_VERSIONS = {
    "ernie-bot-4": "completions_pro",
    "ernie-bot": "completions",
    "ernie-bot-turbo": "eb-instant",
    "bloomz-7b": "bloomz_7b1",
    "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
    "llama2-7b-chat": "llama_2_7b",
    "llama2-13b-chat": "llama_2_13b",
    "llama2-70b-chat": "llama_2_70b",
    "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
    "chatglm2-6b-32k": "chatglm2_6b_32k",
    "aquilachat-7b": "aquilachat_7b",
    # "linly-llama2-ch-7b": "",  # not yet released
    # "linly-llama2-ch-13b": "",  # not yet released
    # "chatglm2-6b": "",  # not yet released
    # "chatglm2-6b-int4": "",  # not yet released
    # "falcon-7b": "",  # not yet released
    # "falcon-180b-chat": "",  # not yet released
    # "falcon-40b": "",  # not yet released
    # "rwkv4-world": "",  # not yet released
    # "rwkv5-world": "",  # not yet released
    # "rwkv4-pile-14b": "",  # not yet released
    # "rwkv4-raven-14b": "",  # not yet released
    # "open-llama-7b": "",  # not yet released
    # "dolly-12b": "",  # not yet released
    # "mpt-7b-instruct": "",  # not yet released
    # "mpt-30b-instruct": "",  # not yet released
    # "OA-Pythia-12B-SFT-4": "",  # not yet released
    # "xverse-13b": "",  # not yet released

    # # The following are for enterprise testing and require a separate application
    # "flan-ul2": "",
    # "Cerebras-GPT-6.7B": ""
    # "Pythia-6.9B": ""
}


@cached(TTLCache(1, 1800))  # tested: the cached token keeps working; refresh it every 30 minutes
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
    """
    Generate an auth token (access_token) from the AK/SK.
    :return: access_token, or None on error
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
    try:
        with get_httpx_client() as client:
            return client.get(url, params=params).json().get("access_token")
    except Exception as e:
        print(f"failed to get token from baidu: {e}")


class QianFanWorker(ApiModelWorker):
    """
    Baidu QianFan
    """
    DEFAULT_EMBED_MODEL = "embedding-v1"

    def __init__(
            self,
            *,
            version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
            model_names: List[str] = ["qianfan-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])
        BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
                   '/{model_version}?access_token={access_token}'

        access_token = get_baidu_access_token(params.api_key, params.secret_key)
        if not access_token:
            yield {
                "error_code": 403,
                "text": "failed to get access token. have you set the correct api_key and secret key?",
            }
            return  # without a token there is no URL to call

        url = BASE_URL.format(
            model_version=params.version_url or MODEL_VERSIONS[params.version.lower()],
            access_token=access_token,
        )
        payload = {
            "messages": params.messages,
            "temperature": params.temperature,
            "stream": True
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }

        text = ""
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:data: {payload}')
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')

        with get_httpx_client() as client:
            with client.stream("POST", url, headers=headers, json=payload) as response:
                for line in response.iter_lines():
                    if not line.strip():
                        continue
                    if line.startswith("data: "):
                        line = line[6:]
                    resp = json.loads(line)

                    if "result" in resp.keys():
                        text += resp["result"]
                        yield {
                            "error_code": 0,
                            "text": text
                        }
                    else:
                        data = {
                            "error_code": resp["error_code"],
                            "text": resp["error_msg"],
                            "error": {
                                "message": resp["error_msg"],
                                "type": "invalid_request_error",
                                "param": None,
                                "code": None,
                            }
                        }
                        self.logger.error(f"An error occurred while requesting the QianFan API: {data}")
                        yield data

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        params.load_config(self.model_names[0])
        # import qianfan

        # embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
        # resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
        # if resp.code == 200:
        #     embeddings = [x.embedding for x in resp.body.get("data", [])]
        #     return {"code": 200, "embeddings": embeddings}
        # else:
        #     return {"code": resp.code, "msg": str(resp.body)}

        embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
        access_token = get_baidu_access_token(params.api_key, params.secret_key)
        url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:url: {url}')

        with get_httpx_client() as client:
            result = []
            i = 0
            batch_size = 10
            while i < len(params.texts):
                texts = params.texts[i:i + batch_size]
                resp = client.post(url, json={"input": texts}).json()
                if "error_code" in resp:
                    data = {
                        "code": resp["error_code"],
                        "msg": resp["error_msg"],
                        "error": {
                            "message": resp["error_msg"],
                            "type": "invalid_request_error",
                            "param": None,
                            "code": None,
                        }
                    }
                    self.logger.error(f"An error occurred while requesting the QianFan API: {data}")
                    return data
                else:
                    embeddings = [x["embedding"] for x in resp.get("data", [])]
                    result += embeddings
                i += batch_size
            return {"code": 200, "data": result}

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = QianFanWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21004"
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21004)
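
get_baidu_access_token is wrapped in @cached(TTLCache(1, 1800)), so at most one token is kept and it expires after 30 minutes. The caching behaviour in isolation, with a stand-in fetcher:

from cachetools import cached, TTLCache

calls = 0

@cached(TTLCache(maxsize=1, ttl=1800))
def fake_token(api_key: str, secret_key: str) -> str:
    global calls
    calls += 1
    return f"token-{calls}"

print(fake_token("ak", "sk"))  # token-1 (fetched)
print(fake_token("ak", "sk"))  # token-1 (served from cache)
print(calls)                   # 1: the underlying fetch ran once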

server/model_workers/qwen.py
@@ -1,123 +0,0 @@
import sys
from typing import List, Literal, Dict

from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from configs import logger, log_verbose


class QwenWorker(ApiModelWorker):
    DEFAULT_EMBED_MODEL = "text-embedding-v1"

    def __init__(
            self,
            *,
            version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
            model_names: List[str] = ["qwen-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')

        gen = dashscope.Generation()
        responses = gen.call(
            model=params.version,
            temperature=params.temperature,
            api_key=params.api_key,
            messages=params.messages,
            result_format='message',  # return results in message format
            stream=True,
        )

        for resp in responses:
            if resp["status_code"] == 200:
                if choices := resp["output"]["choices"]:
                    yield {
                        "error_code": 0,
                        "text": choices[0]["message"]["content"],
                    }
            else:
                data = {
                    "error_code": resp["status_code"],
                    "text": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"An error occurred while requesting the Qwen API: {data}")
                yield data

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')
        result = []
        i = 0
        while i < len(params.texts):
            texts = params.texts[i:i+25]
            resp = dashscope.TextEmbedding.call(
                model=params.embed_model or self.DEFAULT_EMBED_MODEL,
                input=texts,  # at most 25 items per request
                api_key=params.api_key,
            )
            if resp["status_code"] != 200:
                data = {
                    "code": resp["status_code"],
                    "msg": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"An error occurred while requesting the Qwen API: {data}")
                return data
            else:
                embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
                result += embeddings
            i += 25
        return {"code": 200, "data": result}

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = QwenWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:20007",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=20007)
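
do_embeddings above walks params.texts in slices of 25 because dashscope's TextEmbedding accepts at most 25 inputs per call. The slicing pattern on its own:

texts = [f"doc {n}" for n in range(60)]
batch_size = 25  # dashscope's per-request input limit
result = []
i = 0
while i < len(texts):
    batch = texts[i:i + batch_size]
    # each batch would go to dashscope.TextEmbedding.call(input=batch, ...)
    result.append(len(batch))
    i += batch_size
print(result)  # -> [25, 25, 10]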

server/model_workers/tiangong.py
@@ -1,84 +0,0 @@
import json
import time
import hashlib
from typing import List, Literal, Dict

import requests
from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers.base import *


class TianGongWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["tiangong-api"],
            version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 32768)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])

        url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate'
        data = {
            "messages": params.messages,
            "model": "SkyChat-MegaVerse"
        }
        timestamp = str(int(time.time()))
        sign_content = params.api_key + params.secret_key + timestamp
        sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
        headers = {
            "app_key": params.api_key,
            "timestamp": timestamp,
            "sign": sign_result,
            "Content-Type": "application/json",
            "stream": "true"  # set to "false" to skip streamed responses
        }

        # Send the request and stream the response
        response = requests.post(url, headers=headers, json=data, stream=True)

        text = ""
        # Process the response stream
        for line in response.iter_lines(chunk_size=None, decode_unicode=True):
            if line:
                resp = json.loads(line)
                if resp["code"] == 200:
                    text += resp['resp_data']['reply']
                    yield {
                        "error_code": 0,
                        "text": text
                    }
                else:
                    data = {
                        "error_code": resp["code"],
                        "text": resp["code_msg"]
                    }
                    self.logger.error(f"An error occurred while requesting the TianGong API: {data}")
                    yield data

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "system"],
            sep="\n### ",
            stop_str="###",
        )
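
TianGong signs each request with md5(app_key + secret_key + timestamp) and sends the digest alongside the key and timestamp. The header construction in isolation, with dummy credentials:

import hashlib
import time

app_key = "dummy-app-key"        # placeholder
secret_key = "dummy-app-secret"  # placeholder
timestamp = str(int(time.time()))
sign = hashlib.md5((app_key + secret_key + timestamp).encode("utf-8")).hexdigest()
headers = {
    "app_key": app_key,
    "timestamp": timestamp,
    "sign": sign,
    "Content-Type": "application/json",
    "stream": "true",
}
print(headers["sign"])  # changes whenever the timestamp does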

server/model_workers/xinghuo.py
@@ -1,100 +0,0 @@
import sys
import json
from typing import List, Dict

import websockets
from fastchat.conversation import Conversation
from fastchat import conversation as conv

from server.model_workers import SparkApi
from server.model_workers.base import *
from server.utils import iter_over_async, asyncio


async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
    wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
    wsUrl = wsParam.create_url()
    data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
    async with websockets.connect(wsUrl) as ws:
        await ws.send(json.dumps(data, ensure_ascii=False))
        finish = False
        while not finish:
            chunk = await ws.recv()
            response = json.loads(chunk)
            if response.get("header", {}).get("status") == 2:
                finish = True
            if text := response.get("payload", {}).get("choices", {}).get("text"):
                yield text[0]["content"]


class XingHuoWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            model_names: List[str] = ["xinghuo-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: str = None,
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8000)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])

        version_mapping = {
            "v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat", "max_tokens": 4000},
            "v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat", "max_tokens": 8000},
            "v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat", "max_tokens": 8000},
            "v3.5": {"domain": "generalv3.5", "url": "ws://spark-api.xf-yun.com/v3.5/chat", "max_tokens": 16000},
        }

        def get_version_details(version_key):
            return version_mapping.get(version_key, {"domain": None, "url": None})

        details = get_version_details(params.version)
        domain = details["domain"]
        Spark_url = details["url"]
        text = ""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:  # no event loop in this thread
            loop = asyncio.new_event_loop()
        # cap max_tokens at the model limit, falling back to the limit when unset
        params.max_tokens = min(details["max_tokens"], params.max_tokens or details["max_tokens"])
        for chunk in iter_over_async(
                request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
                        params.temperature, params.max_tokens),
                loop=loop,
        ):
            if chunk:
                text += chunk
                yield {"error_code": 0, "text": text}

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = XingHuoWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21003",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21003)
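
XingHuoWorker bridges the async websocket generator into fastchat's synchronous streaming interface via iter_over_async from server.utils, which is not shown in this diff. A plausible minimal implementation (an assumption, not the repo's actual code) drains the async generator one step at a time on the supplied loop:

import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")

def iter_over_async(ait: AsyncIterator[T], loop: asyncio.AbstractEventLoop) -> Iterator[T]:
    """Drive an async iterator to completion from synchronous code."""
    ait = ait.__aiter__()
    while True:
        try:
            # run one step of the async generator on the given loop
            yield loop.run_until_complete(ait.__anext__())
        except StopAsyncIteration:
            return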

server/model_workers/zhipu.py
@@ -1,114 +0,0 @@
import sys
import time
from contextlib import contextmanager
from typing import List, Dict, Iterator, Literal, Any

import httpx
import jwt
from fastchat.conversation import Conversation
from fastchat import conversation as conv
from httpx_sse import EventSource

from server.model_workers.base import *


@contextmanager
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any):
    with client.stream(method, url, **kwargs) as response:
        yield EventSource(response)


def generate_token(apikey: str, exp_seconds: int):
    try:
        id, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid apikey", e)

    payload = {
        "api_key": id,
        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )


class ChatGLMWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            model_names: List[str] = ["zhipu-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: Literal["glm-4"] = "glm-4",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
        params.load_config(self.model_names[0])
        token = generate_token(params.api_key, 60)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}"
        }
        data = {
            "model": params.version,
            "messages": params.messages,
            "max_tokens": params.max_tokens,
            "temperature": params.temperature,
            "stream": False
        }

        url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
        with httpx.Client(headers=headers) as client:
            response = client.post(url, json=data)
            response.raise_for_status()
            chunk = response.json()
            yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]}

            # with connect_sse(client, "POST", url, json=data) as event_source:
            #     for sse in event_source.iter_sse():
            #         chunk = json.loads(sse.data)
            #         if len(chunk["choices"]) != 0:
            #             text += chunk["choices"][0]["delta"]["content"]
            #             yield {"error_code": 0, "text": text}

    def get_embeddings(self, params):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是智谱AI小助手,请根据用户的提示来完成任务",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n###",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = ChatGLMWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21001",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21001)
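
Zhipu API keys take the form id.secret, and generate_token turns them into a short-lived HS256 JWT. A round trip with a dummy key (pyjwt) shows the claims the API sees:

import time
import jwt  # pyjwt

api_key = "my-id.my-secret"  # placeholder in the documented id.secret format
key_id, secret = api_key.split(".")
now_ms = int(round(time.time() * 1000))
payload = {
    "api_key": key_id,
    "exp": now_ms + 60 * 1000,  # milliseconds, matching generate_token above
    "timestamp": now_ms,
}
token = jwt.encode(payload, secret, algorithm="HS256",
                   headers={"alg": "HS256", "sign_type": "SIGN"})
claims = jwt.decode(token, secret, algorithms=["HS256"])
print(claims["api_key"])  # -> my-id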