删除本地fschat workers

This commit is contained in:
glide-the 2024-01-21 18:30:21 +08:00 committed by liunux4odoo
parent 175db6710e
commit a7b28adc09
12 changed files with 0 additions and 1462 deletions

View File

@ -1,79 +0,0 @@
import base64
import datetime
import hashlib
import hmac
from urllib.parse import urlparse
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
class Ws_Param(object):
# 初始化
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# 拼接鉴权参数生成url
url = self.Spark_url + '?' + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def gen_params(appid, domain, question, temperature, max_token):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": max_token,
"auditing": "default",
"temperature": temperature,
}
},
"payload": {
"message": {
"text": question
}
}
}
return data

View File

@ -1,11 +0,0 @@
from .base import *
from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
from .gemini import GeminiWorker

View File

@ -1,95 +0,0 @@
import sys
import os
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json
from typing import List, Dict
from configs import logger, log_verbose
class AzureWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["azure-api"],
version: str = "gpt-35-turbo",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = dict(
messages=params.messages,
temperature=params.temperature,
max_tokens=params.max_tokens if params.max_tokens else None,
stream=True,
)
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
.format(params.resource_name, params.deployment_name, params.api_version))
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'api-key': params.api_key,
}
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
logger.info(f'{self.__class__.__name__}:data: {data}')
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
print(data)
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
if choices := resp["choices"]:
if chunk := choices[0].get("delta", {}).get("content"):
text += chunk
yield {
"error_code": 0,
"text": text
}
print(text)
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app
worker = AzureWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21008",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21008)

View File

@ -1,117 +0,0 @@
import json
import time
import hashlib
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal, Dict
from configs import logger, log_verbose
def calculate_md5(input_string):
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted
class BaiChuanWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["baichuan-api"],
version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
url = "https://api.baichuan-ai.com/v1/stream/chat"
data = {
"model": params.version,
"messages": params.messages,
"parameters": {"temperature": params.temperature}
}
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + params.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip():
continue
resp = json.loads(line)
if resp["code"] == 0:
text += resp["data"]["messages"][-1]["content"]
yield {
"error_code": resp["code"],
"text": text
}
else:
data = {
"error_code": resp["code"],
"text": resp["msg"],
"error": {
"message": resp["msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = BaiChuanWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21007)
# do_request()

View File

@ -1,248 +0,0 @@
from fastchat.conversation import Conversation
from configs import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel, root_validator
import fastchat
import asyncio
from server.utils import get_model_worker_config
from typing import Dict, List, Optional
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
class ApiConfigParams(BaseModel):
'''
在线API配置参数未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取
'''
api_base_url: Optional[str] = None
api_proxy: Optional[str] = None
api_key: Optional[str] = None
secret_key: Optional[str] = None
group_id: Optional[str] = None # for minimax
is_pro: bool = False # for minimax
APPID: Optional[str] = None # for xinghuo
APISecret: Optional[str] = None # for xinghuo
is_v2: bool = False # for xinghuo
worker_name: Optional[str] = None
class Config:
extra = "allow"
@root_validator(pre=True)
def validate_config(cls, v: Dict) -> Dict:
if config := get_model_worker_config(v.get("worker_name")):
for n in cls.__fields__:
if n in config:
v[n] = config[n]
return v
def load_config(self, worker_name: str):
self.worker_name = worker_name
if config := get_model_worker_config(worker_name):
for n in self.__fields__:
if n in config:
setattr(self, n, config[n])
return self
class ApiModelParams(ApiConfigParams):
'''
模型配置参数
'''
version: Optional[str] = None
version_url: Optional[str] = None
api_version: Optional[str] = None # for azure
deployment_name: Optional[str] = None # for azure
resource_name: Optional[str] = None # for azure
temperature: float = 0.9
max_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
class ApiChatParams(ApiModelParams):
'''
chat请求参数
'''
messages: List[Dict[str, str]]
system_message: Optional[str] = None # for minimax
role_meta: Dict = {} # for minimax
class ApiCompletionParams(ApiModelParams):
prompt: str
class ApiEmbeddingsParams(ApiConfigParams):
texts: List[str]
embed_model: Optional[str] = None
to_query: bool = False # for minimax
class ApiModelWorker(BaseModelWorker):
DEFAULT_EMBED_MODEL: str = None # None means not support embedding
def __init__(
self,
model_names: List[str],
controller_addr: str = None,
worker_addr: str = None,
context_len: int = 2048,
no_register: bool = False,
**kwargs,
):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
kwargs.setdefault("model_path", "")
kwargs.setdefault("limit_worker_concurrency", 5)
super().__init__(model_names=model_names,
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
import fastchat.serve.base_model_worker
import sys
self.logger = fastchat.serve.base_model_worker.logger
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
self.context_len = context_len
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
self.version = None
if not no_register and self.controller_addr:
self.init_heart_beat()
def count_token(self, params):
prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params: Dict):
self.call_ct += 1
try:
prompt = params["prompt"]
if self._is_chat(prompt):
messages = self.prompt_to_messages(prompt)
messages = self.validate_messages(messages)
else: # 使用chat模仿续写功能不支持历史消息
messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
p = ApiChatParams(
messages=messages,
temperature=params.get("temperature"),
top_p=params.get("top_p"),
max_tokens=params.get("max_new_tokens"),
version=self.version,
)
for resp in self.do_chat(p):
yield self._jsonify(resp)
except Exception as e:
yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误{e}"})
def generate_gate(self, params):
try:
for x in self.generate_stream_gate(params):
...
return json.loads(x[:-1].decode())
except Exception as e:
return {"error_code": 500, "text": str(e)}
# 需要用户自定义的方法
def do_chat(self, params: ApiChatParams) -> Dict:
'''
执行Chat的方法默认使用模块里面的chat函数
要求返回形式{"error_code": int, "text": str}
'''
return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
# def do_completion(self, p: ApiCompletionParams) -> Dict:
# '''
# 执行Completion的方法默认使用模块里面的completion函数。
# 要求返回形式:{"error_code": int, "text": str}
# '''
# return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
'''
执行Embeddings的方法默认使用模块里面的embed_documents函数
要求返回形式{"code": int, "data": List[List[float]], "msg": str}
'''
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
def get_embeddings(self, params):
# fastchat对LLM做Embeddings限制很大似乎只能使用openai的。
# 在前端通过OpenAIEmbeddings发起的请求直接出错无法请求过来。
print("get_embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
raise NotImplementedError
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
'''
有些API对mesages有特殊格式可以重写该函数替换默认的messages
之所以跟prompt_to_messages分开是因为他们应用场景不同参数不同
'''
return messages
# help methods
@property
def user_role(self):
return self.conv.roles[0]
@property
def ai_role(self):
return self.conv.roles[1]
def _jsonify(self, data: Dict) -> str:
'''
将chat函数返回的结果按照fastchat openai-api-server的格式返回
'''
return json.dumps(data, ensure_ascii=False).encode() + b"\0"
def _is_chat(self, prompt: str) -> bool:
'''
检查prompt是否由chat messages拼接而来
TODO: 存在误判的可能也许从fastchat直接传入原始messages是更好的做法
'''
key = f"{self.conv.sep}{self.user_role}:"
return key in prompt
def prompt_to_messages(self, prompt: str) -> List[Dict]:
'''
将prompt字符串拆分成messages.
'''
result = []
user_role = self.user_role
ai_role = self.ai_role
user_start = user_role + ":"
ai_start = ai_role + ":"
for msg in prompt.split(self.conv.sep)[1:-1]:
if msg.startswith(user_start):
if content := msg[len(user_start):].strip():
result.append({"role": user_role, "content": content})
elif msg.startswith(ai_start):
if content := msg[len(ai_start):].strip():
result.append({"role": ai_role, "content": content})
else:
raise RuntimeError(f"unknown role in msg: {msg}")
return result
@classmethod
def can_embedding(cls):
return cls.DEFAULT_EMBED_MODEL is not None

View File

@ -1,105 +0,0 @@
from fastchat.conversation import Conversation
from server.model_workers.base import *
from fastchat import conversation as conv
import sys
from typing import List, Literal, Dict
from configs import logger, log_verbose
class FangZhouWorker(ApiModelWorker):
"""
火山方舟
"""
def __init__(
self,
*,
model_names: List[str] = ["fangzhou-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
from volcengine.maas import MaasService
params.load_config(self.model_names[0])
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
maas.set_ak(params.api_key)
maas.set_sk(params.secret_key)
# document: "https://www.volcengine.com/docs/82379/1099475"
req = {
"model": {
"name": params.version,
},
"parameters": {
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
"max_new_tokens": params.max_tokens,
"temperature": params.temperature,
},
"messages": params.messages,
}
text = ""
if log_verbose:
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
for resp in maas.stream_chat(req):
if error := resp.error:
if error.code_n > 0:
data = {
"error_code": error.code_n,
"text": error.message,
"error": {
"message": error.message,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求方舟 API 时发生错误:{data}")
yield data
elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
else:
data = {
"error_code": 500,
"text": f"请求方舟 API 时发生未知的错误: {resp}"
}
self.logger.error(data)
yield data
break
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = FangZhouWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21005",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21005)

View File

@ -1,170 +0,0 @@
from fastchat.conversation import Conversation
from server.model_workers.base import *
from fastchat import conversation as conv
import sys
import json
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import get_httpx_client
from typing import List, Dict
from configs import logger, log_verbose
class MiniMaxWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "embo-01"
def __init__(
self,
*,
model_names: List[str] = ["minimax-api"],
controller_addr: str = None,
worker_addr: str = None,
version: str = "abab5.5-chat",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
role_maps = {
"USER": self.user_role,
"assistant": self.ai_role,
"system": "system",
}
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
return messages
def do_chat(self, params: ApiChatParams) -> Dict:
# 按照官网推荐直接调用abab 5.5模型
params.load_config(self.model_names[0])
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
pro = "_pro" if params.is_pro else ""
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
messages = self.validate_messages(params.messages)
data = {
"model": params.version,
"stream": True,
"mask_sensitive_info": True,
"messages": messages,
"temperature": params.temperature,
"top_p": params.top_p,
"tokens_to_generate": params.max_tokens or 1024,
# 以下参数为minimax特有传入空值会出错。
# "prompt": params.system_message or self.conv.system_message,
# "bot_setting": [],
# "role_meta": params.role_meta,
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
response = client.stream("POST",
url.format(pro=pro, group_id=params.group_id),
headers=headers,
json=data)
with response as r:
text = ""
for e in r.iter_text():
if not e.startswith("data: "):
data = {
"error_code": 500,
"text": f"minimax返回错误的结果{e}",
"error": {
"message": f"minimax返回错误的结果{e}",
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
yield data
continue
data = json.loads(e[6:])
if data.get("usage"):
break
if choices := data.get("choices"):
if chunk := choices[0].get("delta", ""):
text += chunk
yield {"error_code": 0, "text": text}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
data = {
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
"texts": [],
"type": "query" if params.to_query else "db",
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
result = []
i = 0
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
data["texts"] = texts
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
result += embeddings
elif error := r.get("base_resp"):
data = {
"code": error["status_code"],
"msg": error["status_msg"],
"error": {
"message": error["status_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += batch_size
return {"code": 200, "data": result}
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是MiniMax自主研发的大型语言模型回答问题简洁有条理。",
messages=[],
roles=["USER", "BOT"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = MiniMaxWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21002",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21002)

View File

@ -1,216 +0,0 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from cachetools import cached, TTLCache
import json
from fastchat import conversation as conv
import sys
from server.model_workers.base import ApiEmbeddingsParams
from typing import List, Literal, Dict
from configs import logger, log_verbose
MODEL_VERSIONS = {
"ernie-bot-4": "completions_pro",
"ernie-bot": "completions",
"ernie-bot-turbo": "eb-instant",
"bloomz-7b": "bloomz_7b1",
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
"llama2-7b-chat": "llama_2_7b",
"llama2-13b-chat": "llama_2_13b",
"llama2-70b-chat": "llama_2_70b",
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
"chatglm2-6b-32k": "chatglm2_6b_32k",
"aquilachat-7b": "aquilachat_7b",
# "linly-llama2-ch-7b": "", # 暂未发布
# "linly-llama2-ch-13b": "", # 暂未发布
# "chatglm2-6b": "", # 暂未发布
# "chatglm2-6b-int4": "", # 暂未发布
# "falcon-7b": "", # 暂未发布
# "falcon-180b-chat": "", # 暂未发布
# "falcon-40b": "", # 暂未发布
# "rwkv4-world": "", # 暂未发布
# "rwkv5-world": "", # 暂未发布
# "rwkv4-pile-14b": "", # 暂未发布
# "rwkv4-raven-14b": "", # 暂未发布
# "open-llama-7b": "", # 暂未发布
# "dolly-12b": "", # 暂未发布
# "mpt-7b-instruct": "", # 暂未发布
# "mpt-30b-instruct": "", # 暂未发布
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
# "xverse-13b": "", # 暂未发布
# # 以下为企业测试,需要单独申请
# "flan-ul2": "",
# "Cerebras-GPT-6.7B": ""
# "Pythia-6.9B": ""
}
@cached(TTLCache(1, 1800)) # 经过测试缓存的token可以使用目前每30分钟刷新一次
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
"""
使用 AKSK 生成鉴权签名Access Token
:return: access_token或是None(如果错误)
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try:
with get_httpx_client() as client:
return client.get(url, params=params).json().get("access_token")
except Exception as e:
print(f"failed to get token from baidu: {e}")
class QianFanWorker(ApiModelWorker):
"""
百度千帆
"""
DEFAULT_EMBED_MODEL = "embedding-v1"
def __init__(
self,
*,
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
model_names: List[str] = ["qianfan-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}'
access_token = get_baidu_access_token(params.api_key, params.secret_key)
if not access_token:
yield {
"error_code": 403,
"text": f"failed to get access token. have you set the correct api_key and secret key?",
}
url = BASE_URL.format(
model_version=params.version_url or MODEL_VERSIONS[params.version.lower()],
access_token=access_token,
)
payload = {
"messages": params.messages,
"temperature": params.temperature,
"stream": True
}
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
}
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {payload}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
if "result" in resp.keys():
text += resp["result"]
yield {
"error_code": 0,
"text": text
}
else:
data = {
"error_code": resp["error_code"],
"text": resp["error_msg"],
"error": {
"message": resp["error_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
# import qianfan
# embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
# resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
# if resp.code == 200:
# embeddings = [x.embedding for x in resp.body.get("data", [])]
# return {"code": 200, "embeddings": embeddings}
# else:
# return {"code": resp.code, "msg": str(resp.body)}
embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
access_token = get_baidu_access_token(params.api_key, params.secret_key)
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
with get_httpx_client() as client:
result = []
i = 0
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i + batch_size]
resp = client.post(url, json={"input": texts}).json()
if "error_code" in resp:
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
"error": {
"message": resp["error_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
i += batch_size
return {"code": 200, "data": result}
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QianFanWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21004"
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21004)

View File

@ -1,123 +0,0 @@
import sys
from fastchat.conversation import Conversation
from typing import List, Literal, Dict
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from configs import logger, log_verbose
class QwenWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text-embedding-v1"
def __init__(
self,
*,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
gen = dashscope.Generation()
responses = gen.call(
model=params.version,
temperature=params.temperature,
api_key=params.api_key,
messages=params.messages,
result_format='message', # set the result is message format.
stream=True,
)
for resp in responses:
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"error_code": 0,
"text": choices[0]["message"]["content"],
}
else:
data = {
"error_code": resp["status_code"],
"text": resp["message"],
"error": {
"message": resp["message"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
result = []
i = 0
while i < len(params.texts):
texts = params.texts[i:i+25]
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=texts, # 最大25行
api_key=params.api_key,
)
if resp["status_code"] != 200:
data = {
"code": resp["status_code"],
"msg": resp.message,
"error": {
"message": resp["message"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings
i += 25
return {"code": 200, "data": result}
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)

View File

@ -1,84 +0,0 @@
import json
import time
import hashlib
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json
from typing import List, Literal, Dict
import requests
class TianGongWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["tiangong-api"],
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate'
data = {
"messages": params.messages,
"model": "SkyChat-MegaVerse"
}
timestamp = str(int(time.time()))
sign_content = params.api_key + params.secret_key + timestamp
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
headers = {
"app_key": params.api_key,
"timestamp": timestamp,
"sign": sign_result,
"Content-Type": "application/json",
"stream": "true" # or change to "false" 不处理流式返回内容
}
# 发起请求并获取响应
response = requests.post(url, headers=headers, json=data, stream=True)
text = ""
# 处理响应流
for line in response.iter_lines(chunk_size=None, decode_unicode=True):
if line:
# 处理接收到的数据
# print(line.decode('utf-8'))
resp = json.loads(line)
if resp["code"] == 200:
text += resp['resp_data']['reply']
yield {
"error_code": 0,
"text": text
}
else:
data = {
"error_code": resp["code"],
"text": resp["code_msg"]
}
self.logger.error(f"请求天工 API 时出错:{data}")
yield data
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "system"],
sep="\n### ",
stop_str="###",
)

View File

@ -1,100 +0,0 @@
from fastchat.conversation import Conversation
from server.model_workers.base import *
from fastchat import conversation as conv
import sys
import json
from server.model_workers import SparkApi
import websockets
from server.utils import iter_over_async, asyncio
from typing import List, Dict
async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
wsUrl = wsParam.create_url()
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
async with websockets.connect(wsUrl) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
while not finish:
chunk = await ws.recv()
response = json.loads(chunk)
if response.get("header", {}).get("status") == 2:
finish = True
if text := response.get("payload", {}).get("choices", {}).get("text"):
yield text[0]["content"]
class XingHuoWorker(ApiModelWorker):
def __init__(
self,
*,
model_names: List[str] = ["xinghuo-api"],
controller_addr: str = None,
worker_addr: str = None,
version: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
version_mapping = {
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat", "max_tokens": 4000},
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat", "max_tokens": 8000},
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat", "max_tokens": 8000},
"v3.5": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.5/chat", "max_tokens": 16000},
}
def get_version_details(version_key):
return version_mapping.get(version_key, {"domain": None, "url": None})
details = get_version_details(params.version)
domain = details["domain"]
Spark_url = details["url"]
text = ""
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
for chunk in iter_over_async(
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
params.temperature, params.max_tokens),
loop=loop,
):
if chunk:
text += chunk
yield {"error_code": 0, "text": text}
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = XingHuoWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21003)

View File

@ -1,114 +0,0 @@
from contextlib import contextmanager
import httpx
from fastchat.conversation import Conversation
from httpx_sse import EventSource
from server.model_workers.base import *
from fastchat import conversation as conv
import sys
from typing import List, Dict, Iterator, Literal, Any
import jwt
import time
@contextmanager
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any):
with client.stream(method, url, **kwargs) as response:
yield EventSource(response)
def generate_token(apikey: str, exp_seconds: int):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)
payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}
return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
class ChatGLMWorker(ApiModelWorker):
def __init__(
self,
*,
model_names: List[str] = ["zhipu-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["glm-4"] = "glm-4",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
params.load_config(self.model_names[0])
token = generate_token(params.api_key, 60)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}"
}
data = {
"model": params.version,
"messages": params.messages,
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"stream": False
}
url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
with httpx.Client(headers=headers) as client:
response = client.post(url, json=data)
response.raise_for_status()
chunk = response.json()
print(chunk)
yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]}
# with connect_sse(client, "POST", url, json=data) as event_source:
# for sse in event_source.iter_sse():
# chunk = json.loads(sse.data)
# if len(chunk["choices"]) != 0:
# text += chunk["choices"][0]["delta"]["content"]
# yield {"error_code": 0, "text": text}
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是智谱AI小助手请根据用户的提示来完成任务",
messages=[],
roles=["user", "assistant", "system"],
sep="\n###",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21001",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21001)