mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 15:38:27 +08:00
删除本地fschat workers
This commit is contained in:
parent
175db6710e
commit
a7b28adc09
@ -1,79 +0,0 @@
|
|||||||
import base64
|
|
||||||
import datetime
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from datetime import datetime
|
|
||||||
from time import mktime
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
from wsgiref.handlers import format_date_time
|
|
||||||
|
|
||||||
|
|
||||||
class Ws_Param(object):
|
|
||||||
# 初始化
|
|
||||||
def __init__(self, APPID, APIKey, APISecret, Spark_url):
|
|
||||||
self.APPID = APPID
|
|
||||||
self.APIKey = APIKey
|
|
||||||
self.APISecret = APISecret
|
|
||||||
self.host = urlparse(Spark_url).netloc
|
|
||||||
self.path = urlparse(Spark_url).path
|
|
||||||
self.Spark_url = Spark_url
|
|
||||||
|
|
||||||
# 生成url
|
|
||||||
def create_url(self):
|
|
||||||
# 生成RFC1123格式的时间戳
|
|
||||||
now = datetime.now()
|
|
||||||
date = format_date_time(mktime(now.timetuple()))
|
|
||||||
|
|
||||||
# 拼接字符串
|
|
||||||
signature_origin = "host: " + self.host + "\n"
|
|
||||||
signature_origin += "date: " + date + "\n"
|
|
||||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
|
||||||
|
|
||||||
# 进行hmac-sha256进行加密
|
|
||||||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
|
||||||
digestmod=hashlib.sha256).digest()
|
|
||||||
|
|
||||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
|
||||||
|
|
||||||
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
|
||||||
|
|
||||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
|
||||||
|
|
||||||
# 将请求的鉴权参数组合为字典
|
|
||||||
v = {
|
|
||||||
"authorization": authorization,
|
|
||||||
"date": date,
|
|
||||||
"host": self.host
|
|
||||||
}
|
|
||||||
# 拼接鉴权参数,生成url
|
|
||||||
url = self.Spark_url + '?' + urlencode(v)
|
|
||||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
def gen_params(appid, domain, question, temperature, max_token):
|
|
||||||
"""
|
|
||||||
通过appid和用户的提问来生成请参数
|
|
||||||
"""
|
|
||||||
data = {
|
|
||||||
"header": {
|
|
||||||
"app_id": appid,
|
|
||||||
"uid": "1234"
|
|
||||||
},
|
|
||||||
"parameter": {
|
|
||||||
"chat": {
|
|
||||||
"domain": domain,
|
|
||||||
"random_threshold": 0.5,
|
|
||||||
"max_tokens": max_token,
|
|
||||||
"auditing": "default",
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"payload": {
|
|
||||||
"message": {
|
|
||||||
"text": question
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
@ -1,11 +0,0 @@
|
|||||||
from .base import *
|
|
||||||
from .zhipu import ChatGLMWorker
|
|
||||||
from .minimax import MiniMaxWorker
|
|
||||||
from .xinghuo import XingHuoWorker
|
|
||||||
from .qianfan import QianFanWorker
|
|
||||||
from .fangzhou import FangZhouWorker
|
|
||||||
from .qwen import QwenWorker
|
|
||||||
from .baichuan import BaiChuanWorker
|
|
||||||
from .azure import AzureWorker
|
|
||||||
from .tiangong import TianGongWorker
|
|
||||||
from .gemini import GeminiWorker
|
|
||||||
@ -1,95 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from server.utils import get_httpx_client
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import json
|
|
||||||
from typing import List, Dict
|
|
||||||
from configs import logger, log_verbose
|
|
||||||
|
|
||||||
|
|
||||||
class AzureWorker(ApiModelWorker):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
model_names: List[str] = ["azure-api"],
|
|
||||||
version: str = "gpt-35-turbo",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
|
|
||||||
data = dict(
|
|
||||||
messages=params.messages,
|
|
||||||
temperature=params.temperature,
|
|
||||||
max_tokens=params.max_tokens if params.max_tokens else None,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
|
|
||||||
.format(params.resource_name, params.deployment_name, params.api_version))
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json',
|
|
||||||
'api-key': params.api_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
|
||||||
print(data)
|
|
||||||
for line in response.iter_lines():
|
|
||||||
if not line.strip() or "[DONE]" in line:
|
|
||||||
continue
|
|
||||||
if line.startswith("data: "):
|
|
||||||
line = line[6:]
|
|
||||||
resp = json.loads(line)
|
|
||||||
if choices := resp["choices"]:
|
|
||||||
if chunk := choices[0].get("delta", {}).get("content"):
|
|
||||||
text += chunk
|
|
||||||
yield {
|
|
||||||
"error_code": 0,
|
|
||||||
"text": text
|
|
||||||
}
|
|
||||||
print(text)
|
|
||||||
else:
|
|
||||||
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="You are a helpful, respectful and honest assistant.",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.base_model_worker import app
|
|
||||||
|
|
||||||
worker = AzureWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21008",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21008)
|
|
||||||
@ -1,117 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from server.utils import get_httpx_client
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
from typing import List, Literal, Dict
|
|
||||||
from configs import logger, log_verbose
|
|
||||||
|
|
||||||
def calculate_md5(input_string):
|
|
||||||
md5 = hashlib.md5()
|
|
||||||
md5.update(input_string.encode('utf-8'))
|
|
||||||
encrypted = md5.hexdigest()
|
|
||||||
return encrypted
|
|
||||||
|
|
||||||
|
|
||||||
class BaiChuanWorker(ApiModelWorker):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
model_names: List[str] = ["baichuan-api"],
|
|
||||||
version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 32768)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
|
|
||||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
|
||||||
data = {
|
|
||||||
"model": params.version,
|
|
||||||
"messages": params.messages,
|
|
||||||
"parameters": {"temperature": params.temperature}
|
|
||||||
}
|
|
||||||
|
|
||||||
json_data = json.dumps(data)
|
|
||||||
time_stamp = int(time.time())
|
|
||||||
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": "Bearer " + params.api_key,
|
|
||||||
"X-BC-Request-Id": "your requestId",
|
|
||||||
"X-BC-Timestamp": str(time_stamp),
|
|
||||||
"X-BC-Signature": signature,
|
|
||||||
"X-BC-Sign-Algo": "MD5",
|
|
||||||
}
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
|
||||||
for line in response.iter_lines():
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
resp = json.loads(line)
|
|
||||||
if resp["code"] == 0:
|
|
||||||
text += resp["data"]["messages"][-1]["content"]
|
|
||||||
yield {
|
|
||||||
"error_code": resp["code"],
|
|
||||||
"text": text
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
data = {
|
|
||||||
"error_code": resp["code"],
|
|
||||||
"text": resp["msg"],
|
|
||||||
"error": {
|
|
||||||
"message": resp["msg"],
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求百川 API 时发生错误:{data}")
|
|
||||||
yield data
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = BaiChuanWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21007",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21007)
|
|
||||||
# do_request()
|
|
||||||
@ -1,248 +0,0 @@
|
|||||||
from fastchat.conversation import Conversation
|
|
||||||
from configs import LOG_PATH
|
|
||||||
import fastchat.constants
|
|
||||||
fastchat.constants.LOGDIR = LOG_PATH
|
|
||||||
from fastchat.serve.base_model_worker import BaseModelWorker
|
|
||||||
import uuid
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pydantic import BaseModel, root_validator
|
|
||||||
import fastchat
|
|
||||||
import asyncio
|
|
||||||
from server.utils import get_model_worker_config
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
|
|
||||||
|
|
||||||
|
|
||||||
class ApiConfigParams(BaseModel):
|
|
||||||
'''
|
|
||||||
在线API配置参数,未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取
|
|
||||||
'''
|
|
||||||
api_base_url: Optional[str] = None
|
|
||||||
api_proxy: Optional[str] = None
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
secret_key: Optional[str] = None
|
|
||||||
group_id: Optional[str] = None # for minimax
|
|
||||||
is_pro: bool = False # for minimax
|
|
||||||
|
|
||||||
APPID: Optional[str] = None # for xinghuo
|
|
||||||
APISecret: Optional[str] = None # for xinghuo
|
|
||||||
is_v2: bool = False # for xinghuo
|
|
||||||
|
|
||||||
worker_name: Optional[str] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
extra = "allow"
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def validate_config(cls, v: Dict) -> Dict:
|
|
||||||
if config := get_model_worker_config(v.get("worker_name")):
|
|
||||||
for n in cls.__fields__:
|
|
||||||
if n in config:
|
|
||||||
v[n] = config[n]
|
|
||||||
return v
|
|
||||||
|
|
||||||
def load_config(self, worker_name: str):
|
|
||||||
self.worker_name = worker_name
|
|
||||||
if config := get_model_worker_config(worker_name):
|
|
||||||
for n in self.__fields__:
|
|
||||||
if n in config:
|
|
||||||
setattr(self, n, config[n])
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class ApiModelParams(ApiConfigParams):
|
|
||||||
'''
|
|
||||||
模型配置参数
|
|
||||||
'''
|
|
||||||
version: Optional[str] = None
|
|
||||||
version_url: Optional[str] = None
|
|
||||||
api_version: Optional[str] = None # for azure
|
|
||||||
deployment_name: Optional[str] = None # for azure
|
|
||||||
resource_name: Optional[str] = None # for azure
|
|
||||||
|
|
||||||
temperature: float = 0.9
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
top_p: Optional[float] = 1.0
|
|
||||||
|
|
||||||
|
|
||||||
class ApiChatParams(ApiModelParams):
|
|
||||||
'''
|
|
||||||
chat请求参数
|
|
||||||
'''
|
|
||||||
messages: List[Dict[str, str]]
|
|
||||||
system_message: Optional[str] = None # for minimax
|
|
||||||
role_meta: Dict = {} # for minimax
|
|
||||||
|
|
||||||
|
|
||||||
class ApiCompletionParams(ApiModelParams):
|
|
||||||
prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
class ApiEmbeddingsParams(ApiConfigParams):
|
|
||||||
texts: List[str]
|
|
||||||
embed_model: Optional[str] = None
|
|
||||||
to_query: bool = False # for minimax
|
|
||||||
|
|
||||||
|
|
||||||
class ApiModelWorker(BaseModelWorker):
|
|
||||||
DEFAULT_EMBED_MODEL: str = None # None means not support embedding
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_names: List[str],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
context_len: int = 2048,
|
|
||||||
no_register: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
|
|
||||||
kwargs.setdefault("model_path", "")
|
|
||||||
kwargs.setdefault("limit_worker_concurrency", 5)
|
|
||||||
super().__init__(model_names=model_names,
|
|
||||||
controller_addr=controller_addr,
|
|
||||||
worker_addr=worker_addr,
|
|
||||||
**kwargs)
|
|
||||||
import fastchat.serve.base_model_worker
|
|
||||||
import sys
|
|
||||||
self.logger = fastchat.serve.base_model_worker.logger
|
|
||||||
# 恢复被fastchat覆盖的标准输出
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
|
|
||||||
new_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(new_loop)
|
|
||||||
|
|
||||||
self.context_len = context_len
|
|
||||||
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
|
|
||||||
self.version = None
|
|
||||||
|
|
||||||
if not no_register and self.controller_addr:
|
|
||||||
self.init_heart_beat()
|
|
||||||
|
|
||||||
|
|
||||||
def count_token(self, params):
|
|
||||||
prompt = params["prompt"]
|
|
||||||
return {"count": len(str(prompt)), "error_code": 0}
|
|
||||||
|
|
||||||
def generate_stream_gate(self, params: Dict):
|
|
||||||
self.call_ct += 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
prompt = params["prompt"]
|
|
||||||
if self._is_chat(prompt):
|
|
||||||
messages = self.prompt_to_messages(prompt)
|
|
||||||
messages = self.validate_messages(messages)
|
|
||||||
else: # 使用chat模仿续写功能,不支持历史消息
|
|
||||||
messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
|
|
||||||
|
|
||||||
p = ApiChatParams(
|
|
||||||
messages=messages,
|
|
||||||
temperature=params.get("temperature"),
|
|
||||||
top_p=params.get("top_p"),
|
|
||||||
max_tokens=params.get("max_new_tokens"),
|
|
||||||
version=self.version,
|
|
||||||
)
|
|
||||||
for resp in self.do_chat(p):
|
|
||||||
yield self._jsonify(resp)
|
|
||||||
except Exception as e:
|
|
||||||
yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})
|
|
||||||
|
|
||||||
def generate_gate(self, params):
|
|
||||||
try:
|
|
||||||
for x in self.generate_stream_gate(params):
|
|
||||||
...
|
|
||||||
return json.loads(x[:-1].decode())
|
|
||||||
except Exception as e:
|
|
||||||
return {"error_code": 500, "text": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
# 需要用户自定义的方法
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
'''
|
|
||||||
执行Chat的方法,默认使用模块里面的chat函数。
|
|
||||||
要求返回形式:{"error_code": int, "text": str}
|
|
||||||
'''
|
|
||||||
return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
|
|
||||||
|
|
||||||
# def do_completion(self, p: ApiCompletionParams) -> Dict:
|
|
||||||
# '''
|
|
||||||
# 执行Completion的方法,默认使用模块里面的completion函数。
|
|
||||||
# 要求返回形式:{"error_code": int, "text": str}
|
|
||||||
# '''
|
|
||||||
# return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
|
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
|
||||||
'''
|
|
||||||
执行Embeddings的方法,默认使用模块里面的embed_documents函数。
|
|
||||||
要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
|
|
||||||
'''
|
|
||||||
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
# fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。
|
|
||||||
# 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。
|
|
||||||
print("get_embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
|
|
||||||
'''
|
|
||||||
有些API对mesages有特殊格式,可以重写该函数替换默认的messages。
|
|
||||||
之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同
|
|
||||||
'''
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
# help methods
|
|
||||||
@property
|
|
||||||
def user_role(self):
|
|
||||||
return self.conv.roles[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ai_role(self):
|
|
||||||
return self.conv.roles[1]
|
|
||||||
|
|
||||||
def _jsonify(self, data: Dict) -> str:
|
|
||||||
'''
|
|
||||||
将chat函数返回的结果按照fastchat openai-api-server的格式返回
|
|
||||||
'''
|
|
||||||
return json.dumps(data, ensure_ascii=False).encode() + b"\0"
|
|
||||||
|
|
||||||
def _is_chat(self, prompt: str) -> bool:
|
|
||||||
'''
|
|
||||||
检查prompt是否由chat messages拼接而来
|
|
||||||
TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法
|
|
||||||
'''
|
|
||||||
key = f"{self.conv.sep}{self.user_role}:"
|
|
||||||
return key in prompt
|
|
||||||
|
|
||||||
def prompt_to_messages(self, prompt: str) -> List[Dict]:
|
|
||||||
'''
|
|
||||||
将prompt字符串拆分成messages.
|
|
||||||
'''
|
|
||||||
result = []
|
|
||||||
user_role = self.user_role
|
|
||||||
ai_role = self.ai_role
|
|
||||||
user_start = user_role + ":"
|
|
||||||
ai_start = ai_role + ":"
|
|
||||||
for msg in prompt.split(self.conv.sep)[1:-1]:
|
|
||||||
if msg.startswith(user_start):
|
|
||||||
if content := msg[len(user_start):].strip():
|
|
||||||
result.append({"role": user_role, "content": content})
|
|
||||||
elif msg.startswith(ai_start):
|
|
||||||
if content := msg[len(ai_start):].strip():
|
|
||||||
result.append({"role": ai_role, "content": content})
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"unknown role in msg: {msg}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def can_embedding(cls):
|
|
||||||
return cls.DEFAULT_EMBED_MODEL is not None
|
|
||||||
@ -1,105 +0,0 @@
|
|||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import sys
|
|
||||||
from typing import List, Literal, Dict
|
|
||||||
from configs import logger, log_verbose
|
|
||||||
|
|
||||||
|
|
||||||
class FangZhouWorker(ApiModelWorker):
|
|
||||||
"""
|
|
||||||
火山方舟
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model_names: List[str] = ["fangzhou-api"],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 16384)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
from volcengine.maas import MaasService
|
|
||||||
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
|
|
||||||
maas.set_ak(params.api_key)
|
|
||||||
maas.set_sk(params.secret_key)
|
|
||||||
|
|
||||||
# document: "https://www.volcengine.com/docs/82379/1099475"
|
|
||||||
req = {
|
|
||||||
"model": {
|
|
||||||
"name": params.version,
|
|
||||||
},
|
|
||||||
"parameters": {
|
|
||||||
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
|
|
||||||
"max_new_tokens": params.max_tokens,
|
|
||||||
"temperature": params.temperature,
|
|
||||||
},
|
|
||||||
"messages": params.messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
if log_verbose:
|
|
||||||
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
|
||||||
for resp in maas.stream_chat(req):
|
|
||||||
if error := resp.error:
|
|
||||||
if error.code_n > 0:
|
|
||||||
data = {
|
|
||||||
"error_code": error.code_n,
|
|
||||||
"text": error.message,
|
|
||||||
"error": {
|
|
||||||
"message": error.message,
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求方舟 API 时发生错误:{data}")
|
|
||||||
yield data
|
|
||||||
elif chunk := resp.choice.message.content:
|
|
||||||
text += chunk
|
|
||||||
yield {"error_code": 0, "text": text}
|
|
||||||
else:
|
|
||||||
data = {
|
|
||||||
"error_code": 500,
|
|
||||||
"text": f"请求方舟 API 时发生未知的错误: {resp}"
|
|
||||||
}
|
|
||||||
self.logger.error(data)
|
|
||||||
yield data
|
|
||||||
break
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant", "system"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = FangZhouWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21005",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21005)
|
|
||||||
@ -1,170 +0,0 @@
|
|||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
from server.model_workers.base import ApiEmbeddingsParams
|
|
||||||
from server.utils import get_httpx_client
|
|
||||||
from typing import List, Dict
|
|
||||||
from configs import logger, log_verbose
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxWorker(ApiModelWorker):
|
|
||||||
DEFAULT_EMBED_MODEL = "embo-01"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model_names: List[str] = ["minimax-api"],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
version: str = "abab5.5-chat",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 16384)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
|
|
||||||
role_maps = {
|
|
||||||
"USER": self.user_role,
|
|
||||||
"assistant": self.ai_role,
|
|
||||||
"system": "system",
|
|
||||||
}
|
|
||||||
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
# 按照官网推荐,直接调用abab 5.5模型
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
|
|
||||||
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
|
|
||||||
pro = "_pro" if params.is_pro else ""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {params.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
messages = self.validate_messages(params.messages)
|
|
||||||
data = {
|
|
||||||
"model": params.version,
|
|
||||||
"stream": True,
|
|
||||||
"mask_sensitive_info": True,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": params.temperature,
|
|
||||||
"top_p": params.top_p,
|
|
||||||
"tokens_to_generate": params.max_tokens or 1024,
|
|
||||||
# 以下参数为minimax特有,传入空值会出错。
|
|
||||||
# "prompt": params.system_message or self.conv.system_message,
|
|
||||||
# "bot_setting": [],
|
|
||||||
# "role_meta": params.role_meta,
|
|
||||||
}
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
response = client.stream("POST",
|
|
||||||
url.format(pro=pro, group_id=params.group_id),
|
|
||||||
headers=headers,
|
|
||||||
json=data)
|
|
||||||
with response as r:
|
|
||||||
text = ""
|
|
||||||
for e in r.iter_text():
|
|
||||||
if not e.startswith("data: "):
|
|
||||||
data = {
|
|
||||||
"error_code": 500,
|
|
||||||
"text": f"minimax返回错误的结果:{e}",
|
|
||||||
"error": {
|
|
||||||
"message": f"minimax返回错误的结果:{e}",
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
|
||||||
yield data
|
|
||||||
continue
|
|
||||||
|
|
||||||
data = json.loads(e[6:])
|
|
||||||
if data.get("usage"):
|
|
||||||
break
|
|
||||||
|
|
||||||
if choices := data.get("choices"):
|
|
||||||
if chunk := choices[0].get("delta", ""):
|
|
||||||
text += chunk
|
|
||||||
yield {"error_code": 0, "text": text}
|
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {params.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
|
|
||||||
"texts": [],
|
|
||||||
"type": "query" if params.to_query else "db",
|
|
||||||
}
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
result = []
|
|
||||||
i = 0
|
|
||||||
batch_size = 10
|
|
||||||
while i < len(params.texts):
|
|
||||||
texts = params.texts[i:i+batch_size]
|
|
||||||
data["texts"] = texts
|
|
||||||
r = client.post(url, headers=headers, json=data).json()
|
|
||||||
if embeddings := r.get("vectors"):
|
|
||||||
result += embeddings
|
|
||||||
elif error := r.get("base_resp"):
|
|
||||||
data = {
|
|
||||||
"code": error["status_code"],
|
|
||||||
"msg": error["status_msg"],
|
|
||||||
"error": {
|
|
||||||
"message": error["status_msg"],
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
|
||||||
return data
|
|
||||||
i += batch_size
|
|
||||||
return {"code": 200, "data": result}
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。",
|
|
||||||
messages=[],
|
|
||||||
roles=["USER", "BOT"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = MiniMaxWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21002",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21002)
|
|
||||||
@ -1,216 +0,0 @@
|
|||||||
import sys
|
|
||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from server.utils import get_httpx_client
|
|
||||||
from cachetools import cached, TTLCache
|
|
||||||
import json
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import sys
|
|
||||||
from server.model_workers.base import ApiEmbeddingsParams
|
|
||||||
from typing import List, Literal, Dict
|
|
||||||
from configs import logger, log_verbose
|
|
||||||
|
|
||||||
MODEL_VERSIONS = {
|
|
||||||
"ernie-bot-4": "completions_pro",
|
|
||||||
"ernie-bot": "completions",
|
|
||||||
"ernie-bot-turbo": "eb-instant",
|
|
||||||
"bloomz-7b": "bloomz_7b1",
|
|
||||||
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
|
||||||
"llama2-7b-chat": "llama_2_7b",
|
|
||||||
"llama2-13b-chat": "llama_2_13b",
|
|
||||||
"llama2-70b-chat": "llama_2_70b",
|
|
||||||
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
|
|
||||||
"chatglm2-6b-32k": "chatglm2_6b_32k",
|
|
||||||
"aquilachat-7b": "aquilachat_7b",
|
|
||||||
# "linly-llama2-ch-7b": "", # 暂未发布
|
|
||||||
# "linly-llama2-ch-13b": "", # 暂未发布
|
|
||||||
# "chatglm2-6b": "", # 暂未发布
|
|
||||||
# "chatglm2-6b-int4": "", # 暂未发布
|
|
||||||
# "falcon-7b": "", # 暂未发布
|
|
||||||
# "falcon-180b-chat": "", # 暂未发布
|
|
||||||
# "falcon-40b": "", # 暂未发布
|
|
||||||
# "rwkv4-world": "", # 暂未发布
|
|
||||||
# "rwkv5-world": "", # 暂未发布
|
|
||||||
# "rwkv4-pile-14b": "", # 暂未发布
|
|
||||||
# "rwkv4-raven-14b": "", # 暂未发布
|
|
||||||
# "open-llama-7b": "", # 暂未发布
|
|
||||||
# "dolly-12b": "", # 暂未发布
|
|
||||||
# "mpt-7b-instruct": "", # 暂未发布
|
|
||||||
# "mpt-30b-instruct": "", # 暂未发布
|
|
||||||
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
|
|
||||||
# "xverse-13b": "", # 暂未发布
|
|
||||||
|
|
||||||
# # 以下为企业测试,需要单独申请
|
|
||||||
# "flan-ul2": "",
|
|
||||||
# "Cerebras-GPT-6.7B": ""
|
|
||||||
# "Pythia-6.9B": ""
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
|
||||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
|
||||||
"""
|
|
||||||
使用 AK,SK 生成鉴权签名(Access Token)
|
|
||||||
:return: access_token,或是None(如果错误)
|
|
||||||
"""
|
|
||||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
|
||||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
|
||||||
try:
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
return client.get(url, params=params).json().get("access_token")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"failed to get token from baidu: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class QianFanWorker(ApiModelWorker):
|
|
||||||
"""
|
|
||||||
百度千帆
|
|
||||||
"""
|
|
||||||
DEFAULT_EMBED_MODEL = "embedding-v1"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
|
||||||
model_names: List[str] = ["qianfan-api"],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 16384)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
|
||||||
'/{model_version}?access_token={access_token}'
|
|
||||||
|
|
||||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
|
||||||
if not access_token:
|
|
||||||
yield {
|
|
||||||
"error_code": 403,
|
|
||||||
"text": f"failed to get access token. have you set the correct api_key and secret key?",
|
|
||||||
}
|
|
||||||
|
|
||||||
url = BASE_URL.format(
|
|
||||||
model_version=params.version_url or MODEL_VERSIONS[params.version.lower()],
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
payload = {
|
|
||||||
"messages": params.messages,
|
|
||||||
"temperature": params.temperature,
|
|
||||||
"stream": True
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json',
|
|
||||||
}
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:data: {payload}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
|
||||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
with client.stream("POST", url, headers=headers, json=payload) as response:
|
|
||||||
for line in response.iter_lines():
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
if line.startswith("data: "):
|
|
||||||
line = line[6:]
|
|
||||||
resp = json.loads(line)
|
|
||||||
|
|
||||||
if "result" in resp.keys():
|
|
||||||
text += resp["result"]
|
|
||||||
yield {
|
|
||||||
"error_code": 0,
|
|
||||||
"text": text
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
data = {
|
|
||||||
"error_code": resp["error_code"],
|
|
||||||
"text": resp["error_msg"],
|
|
||||||
"error": {
|
|
||||||
"message": resp["error_msg"],
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
|
||||||
yield data
|
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
# import qianfan
|
|
||||||
|
|
||||||
# embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
|
|
||||||
# resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
|
|
||||||
# if resp.code == 200:
|
|
||||||
# embeddings = [x.embedding for x in resp.body.get("data", [])]
|
|
||||||
# return {"code": 200, "embeddings": embeddings}
|
|
||||||
# else:
|
|
||||||
# return {"code": resp.code, "msg": str(resp.body)}
|
|
||||||
|
|
||||||
embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
|
|
||||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
|
||||||
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
|
||||||
result = []
|
|
||||||
i = 0
|
|
||||||
batch_size = 10
|
|
||||||
while i < len(params.texts):
|
|
||||||
texts = params.texts[i:i + batch_size]
|
|
||||||
resp = client.post(url, json={"input": texts}).json()
|
|
||||||
if "error_code" in resp:
|
|
||||||
data = {
|
|
||||||
"code": resp["error_code"],
|
|
||||||
"msg": resp["error_msg"],
|
|
||||||
"error": {
|
|
||||||
"message": resp["error_msg"],
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
|
||||||
return data
|
|
||||||
else:
|
|
||||||
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
|
||||||
result += embeddings
|
|
||||||
i += batch_size
|
|
||||||
return {"code": 200, "data": result}
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = QianFanWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21004"
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21004)
|
|
||||||
@ -1,123 +0,0 @@
|
|||||||
import sys
|
|
||||||
from fastchat.conversation import Conversation
|
|
||||||
from typing import List, Literal, Dict
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from server.model_workers.base import ApiEmbeddingsParams
|
|
||||||
from configs import logger, log_verbose
|
|
||||||
|
|
||||||
|
|
||||||
class QwenWorker(ApiModelWorker):
|
|
||||||
DEFAULT_EMBED_MODEL = "text-embedding-v1"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
|
|
||||||
model_names: List[str] = ["qwen-api"],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 16384)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
import dashscope
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
|
||||||
|
|
||||||
gen = dashscope.Generation()
|
|
||||||
responses = gen.call(
|
|
||||||
model=params.version,
|
|
||||||
temperature=params.temperature,
|
|
||||||
api_key=params.api_key,
|
|
||||||
messages=params.messages,
|
|
||||||
result_format='message', # set the result is message format.
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for resp in responses:
|
|
||||||
if resp["status_code"] == 200:
|
|
||||||
if choices := resp["output"]["choices"]:
|
|
||||||
yield {
|
|
||||||
"error_code": 0,
|
|
||||||
"text": choices[0]["message"]["content"],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
data = {
|
|
||||||
"error_code": resp["status_code"],
|
|
||||||
"text": resp["message"],
|
|
||||||
"error": {
|
|
||||||
"message": resp["message"],
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
|
||||||
yield data
|
|
||||||
|
|
||||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
|
||||||
import dashscope
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
if log_verbose:
|
|
||||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
|
||||||
result = []
|
|
||||||
i = 0
|
|
||||||
while i < len(params.texts):
|
|
||||||
texts = params.texts[i:i+25]
|
|
||||||
resp = dashscope.TextEmbedding.call(
|
|
||||||
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
|
|
||||||
input=texts, # 最大25行
|
|
||||||
api_key=params.api_key,
|
|
||||||
)
|
|
||||||
if resp["status_code"] != 200:
|
|
||||||
data = {
|
|
||||||
"code": resp["status_code"],
|
|
||||||
"msg": resp.message,
|
|
||||||
"error": {
|
|
||||||
"message": resp["message"],
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
"param": None,
|
|
||||||
"code": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
|
||||||
return data
|
|
||||||
else:
|
|
||||||
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
|
||||||
result += embeddings
|
|
||||||
i += 25
|
|
||||||
return {"code": 200, "data": result}
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant", "system"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = QwenWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:20007",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=20007)
|
|
||||||
@ -1,84 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from server.utils import get_httpx_client
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import json
|
|
||||||
from typing import List, Literal, Dict
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
class TianGongWorker(ApiModelWorker):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
model_names: List[str] = ["tiangong-api"],
|
|
||||||
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 32768)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
|
|
||||||
url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate'
|
|
||||||
data = {
|
|
||||||
"messages": params.messages,
|
|
||||||
"model": "SkyChat-MegaVerse"
|
|
||||||
}
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
sign_content = params.api_key + params.secret_key + timestamp
|
|
||||||
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
|
|
||||||
headers = {
|
|
||||||
"app_key": params.api_key,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"sign": sign_result,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"stream": "true" # or change to "false" 不处理流式返回内容
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发起请求并获取响应
|
|
||||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
# 处理响应流
|
|
||||||
for line in response.iter_lines(chunk_size=None, decode_unicode=True):
|
|
||||||
if line:
|
|
||||||
# 处理接收到的数据
|
|
||||||
# print(line.decode('utf-8'))
|
|
||||||
resp = json.loads(line)
|
|
||||||
if resp["code"] == 200:
|
|
||||||
text += resp['resp_data']['reply']
|
|
||||||
yield {
|
|
||||||
"error_code": 0,
|
|
||||||
"text": text
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
data = {
|
|
||||||
"error_code": resp["code"],
|
|
||||||
"text": resp["code_msg"]
|
|
||||||
}
|
|
||||||
self.logger.error(f"请求天工 API 时出错:{data}")
|
|
||||||
yield data
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "system"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
@ -1,100 +0,0 @@
|
|||||||
from fastchat.conversation import Conversation
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
from server.model_workers import SparkApi
|
|
||||||
import websockets
|
|
||||||
from server.utils import iter_over_async, asyncio
|
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
|
|
||||||
async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
|
|
||||||
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
|
|
||||||
wsUrl = wsParam.create_url()
|
|
||||||
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
|
|
||||||
async with websockets.connect(wsUrl) as ws:
|
|
||||||
await ws.send(json.dumps(data, ensure_ascii=False))
|
|
||||||
finish = False
|
|
||||||
while not finish:
|
|
||||||
chunk = await ws.recv()
|
|
||||||
response = json.loads(chunk)
|
|
||||||
if response.get("header", {}).get("status") == 2:
|
|
||||||
finish = True
|
|
||||||
if text := response.get("payload", {}).get("choices", {}).get("text"):
|
|
||||||
yield text[0]["content"]
|
|
||||||
|
|
||||||
|
|
||||||
class XingHuoWorker(ApiModelWorker):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model_names: List[str] = ["xinghuo-api"],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
version: str = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 8000)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
|
|
||||||
version_mapping = {
|
|
||||||
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat", "max_tokens": 4000},
|
|
||||||
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat", "max_tokens": 8000},
|
|
||||||
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat", "max_tokens": 8000},
|
|
||||||
"v3.5": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.5/chat", "max_tokens": 16000},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_version_details(version_key):
|
|
||||||
return version_mapping.get(version_key, {"domain": None, "url": None})
|
|
||||||
|
|
||||||
details = get_version_details(params.version)
|
|
||||||
domain = details["domain"]
|
|
||||||
Spark_url = details["url"]
|
|
||||||
text = ""
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
except:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
|
|
||||||
for chunk in iter_over_async(
|
|
||||||
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
|
|
||||||
params.temperature, params.max_tokens),
|
|
||||||
loop=loop,
|
|
||||||
):
|
|
||||||
if chunk:
|
|
||||||
text += chunk
|
|
||||||
yield {"error_code": 0, "text": text}
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant"],
|
|
||||||
sep="\n### ",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = XingHuoWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21003",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21003)
|
|
||||||
@ -1,114 +0,0 @@
|
|||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from fastchat.conversation import Conversation
|
|
||||||
from httpx_sse import EventSource
|
|
||||||
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from fastchat import conversation as conv
|
|
||||||
import sys
|
|
||||||
from typing import List, Dict, Iterator, Literal, Any
|
|
||||||
import jwt
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any):
|
|
||||||
with client.stream(method, url, **kwargs) as response:
|
|
||||||
yield EventSource(response)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_token(apikey: str, exp_seconds: int):
|
|
||||||
try:
|
|
||||||
id, secret = apikey.split(".")
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception("invalid apikey", e)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"api_key": id,
|
|
||||||
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
|
|
||||||
"timestamp": int(round(time.time() * 1000)),
|
|
||||||
}
|
|
||||||
|
|
||||||
return jwt.encode(
|
|
||||||
payload,
|
|
||||||
secret,
|
|
||||||
algorithm="HS256",
|
|
||||||
headers={"alg": "HS256", "sign_type": "SIGN"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGLMWorker(ApiModelWorker):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model_names: List[str] = ["zhipu-api"],
|
|
||||||
controller_addr: str = None,
|
|
||||||
worker_addr: str = None,
|
|
||||||
version: Literal["glm-4"] = "glm-4",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
|
||||||
kwargs.setdefault("context_len", 4096)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.version = version
|
|
||||||
|
|
||||||
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
|
|
||||||
params.load_config(self.model_names[0])
|
|
||||||
token = generate_token(params.api_key, 60)
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {token}"
|
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"model": params.version,
|
|
||||||
"messages": params.messages,
|
|
||||||
"max_tokens": params.max_tokens,
|
|
||||||
"temperature": params.temperature,
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
|
|
||||||
url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
|
|
||||||
with httpx.Client(headers=headers) as client:
|
|
||||||
response = client.post(url, json=data)
|
|
||||||
response.raise_for_status()
|
|
||||||
chunk = response.json()
|
|
||||||
print(chunk)
|
|
||||||
yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]}
|
|
||||||
|
|
||||||
# with connect_sse(client, "POST", url, json=data) as event_source:
|
|
||||||
# for sse in event_source.iter_sse():
|
|
||||||
# chunk = json.loads(sse.data)
|
|
||||||
# if len(chunk["choices"]) != 0:
|
|
||||||
# text += chunk["choices"][0]["delta"]["content"]
|
|
||||||
# yield {"error_code": 0, "text": text}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
|
||||||
print("embedding")
|
|
||||||
print(params)
|
|
||||||
|
|
||||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
|
||||||
return conv.Conversation(
|
|
||||||
name=self.model_names[0],
|
|
||||||
system_message="你是智谱AI小助手,请根据用户的提示来完成任务",
|
|
||||||
messages=[],
|
|
||||||
roles=["user", "assistant", "system"],
|
|
||||||
sep="\n###",
|
|
||||||
stop_str="###",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
from server.utils import MakeFastAPIOffline
|
|
||||||
from fastchat.serve.model_worker import app
|
|
||||||
|
|
||||||
worker = ChatGLMWorker(
|
|
||||||
controller_addr="http://127.0.0.1:20001",
|
|
||||||
worker_addr="http://127.0.0.1:21001",
|
|
||||||
)
|
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
|
||||||
MakeFastAPIOffline(app)
|
|
||||||
uvicorn.run(app, port=21001)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user