From a7b28adc09ffe8bf86948d0f89104452bd453f41 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 21 Jan 2024 18:30:21 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=9C=AC=E5=9C=B0fschat=20wo?= =?UTF-8?q?rkers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/model_workers/SparkApi.py | 79 ---------- server/model_workers/__init__.py | 11 -- server/model_workers/azure.py | 95 ------------ server/model_workers/baichuan.py | 117 --------------- server/model_workers/base.py | 248 ------------------------------- server/model_workers/fangzhou.py | 105 ------------- server/model_workers/minimax.py | 170 --------------------- server/model_workers/qianfan.py | 216 --------------------------- server/model_workers/qwen.py | 123 --------------- server/model_workers/tiangong.py | 84 ----------- server/model_workers/xinghuo.py | 100 ------------- server/model_workers/zhipu.py | 114 -------------- 12 files changed, 1462 deletions(-) delete mode 100644 server/model_workers/SparkApi.py delete mode 100644 server/model_workers/__init__.py delete mode 100644 server/model_workers/azure.py delete mode 100644 server/model_workers/baichuan.py delete mode 100644 server/model_workers/base.py delete mode 100644 server/model_workers/fangzhou.py delete mode 100644 server/model_workers/minimax.py delete mode 100644 server/model_workers/qianfan.py delete mode 100644 server/model_workers/qwen.py delete mode 100644 server/model_workers/tiangong.py delete mode 100644 server/model_workers/xinghuo.py delete mode 100644 server/model_workers/zhipu.py diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py deleted file mode 100644 index 795b1f73..00000000 --- a/server/model_workers/SparkApi.py +++ /dev/null @@ -1,79 +0,0 @@ -import base64 -import datetime -import hashlib -import hmac -from urllib.parse import urlparse -from datetime import datetime -from time import mktime -from urllib.parse import urlencode -from wsgiref.handlers import format_date_time - - -class Ws_Param(object): - # 初始化 - def __init__(self, APPID, APIKey, APISecret, Spark_url): - self.APPID = APPID - self.APIKey = APIKey - self.APISecret = APISecret - self.host = urlparse(Spark_url).netloc - self.path = urlparse(Spark_url).path - self.Spark_url = Spark_url - - # 生成url - def create_url(self): - # 生成RFC1123格式的时间戳 - now = datetime.now() - date = format_date_time(mktime(now.timetuple())) - - # 拼接字符串 - signature_origin = "host: " + self.host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + self.path + " HTTP/1.1" - - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - - # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": self.host - } - # 拼接鉴权参数,生成url - url = self.Spark_url + '?' + urlencode(v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - return url - - -def gen_params(appid, domain, question, temperature, max_token): - """ - 通过appid和用户的提问来生成请参数 - """ - data = { - "header": { - "app_id": appid, - "uid": "1234" - }, - "parameter": { - "chat": { - "domain": domain, - "random_threshold": 0.5, - "max_tokens": max_token, - "auditing": "default", - "temperature": temperature, - } - }, - "payload": { - "message": { - "text": question - } - } - } - return data diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py deleted file mode 100644 index d0320f41..00000000 --- a/server/model_workers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .base import * -from .zhipu import ChatGLMWorker -from .minimax import MiniMaxWorker -from .xinghuo import XingHuoWorker -from .qianfan import QianFanWorker -from .fangzhou import FangZhouWorker -from .qwen import QwenWorker -from .baichuan import BaiChuanWorker -from .azure import AzureWorker -from .tiangong import TianGongWorker -from .gemini import GeminiWorker \ No newline at end of file diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py deleted file mode 100644 index f0835ae1..00000000 --- a/server/model_workers/azure.py +++ /dev/null @@ -1,95 +0,0 @@ -import sys -import os -from fastchat.conversation import Conversation -from server.model_workers.base import * -from server.utils import get_httpx_client -from fastchat import conversation as conv -import json -from typing import List, Dict -from configs import logger, log_verbose - - -class AzureWorker(ApiModelWorker): - def __init__( - self, - *, - controller_addr: str = None, - worker_addr: str = None, - model_names: List[str] = ["azure-api"], - version: str = "gpt-35-turbo", - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - params.load_config(self.model_names[0]) - - data = dict( - messages=params.messages, - temperature=params.temperature, - max_tokens=params.max_tokens if params.max_tokens else None, - stream=True, - ) - url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}" - .format(params.resource_name, params.deployment_name, params.api_version)) - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'api-key': params.api_key, - } - - text = "" - if log_verbose: - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - logger.info(f'{self.__class__.__name__}:data: {data}') - - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=data) as response: - print(data) - for line in response.iter_lines(): - if not line.strip() or "[DONE]" in line: - continue - if line.startswith("data: "): - line = line[6:] - resp = json.loads(line) - if choices := resp["choices"]: - if chunk := choices[0].get("delta", {}).get("content"): - text += chunk - yield { - "error_code": 0, - "text": text - } - print(text) - else: - self.logger.error(f"请求 Azure API 时发生错误:{resp}") - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="You are a helpful, respectful and honest assistant.", - messages=[], - roles=["user", "assistant"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.base_model_worker import app - - worker = AzureWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21008", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21008) \ No newline at end of file diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py deleted file mode 100644 index 75cfad4e..00000000 --- a/server/model_workers/baichuan.py +++ /dev/null @@ -1,117 +0,0 @@ -import json -import time -import hashlib - -from fastchat.conversation import Conversation -from server.model_workers.base import * -from server.utils import get_httpx_client -from fastchat import conversation as conv -import sys -import json -from typing import List, Literal, Dict -from configs import logger, log_verbose - -def calculate_md5(input_string): - md5 = hashlib.md5() - md5.update(input_string.encode('utf-8')) - encrypted = md5.hexdigest() - return encrypted - - -class BaiChuanWorker(ApiModelWorker): - def __init__( - self, - *, - controller_addr: str = None, - worker_addr: str = None, - model_names: List[str] = ["baichuan-api"], - version: Literal["Baichuan2-53B"] = "Baichuan2-53B", - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 32768) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - params.load_config(self.model_names[0]) - - url = "https://api.baichuan-ai.com/v1/stream/chat" - data = { - "model": params.version, - "messages": params.messages, - "parameters": {"temperature": params.temperature} - } - - json_data = json.dumps(data) - time_stamp = int(time.time()) - signature = calculate_md5(params.secret_key + json_data + str(time_stamp)) - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer " + params.api_key, - "X-BC-Request-Id": "your requestId", - "X-BC-Timestamp": str(time_stamp), - "X-BC-Signature": signature, - "X-BC-Sign-Algo": "MD5", - } - - text = "" - if log_verbose: - logger.info(f'{self.__class__.__name__}:json_data: {json_data}') - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=data) as response: - for line in response.iter_lines(): - if not line.strip(): - continue - resp = json.loads(line) - if resp["code"] == 0: - text += resp["data"]["messages"][-1]["content"] - yield { - "error_code": resp["code"], - "text": text - } - else: - data = { - "error_code": resp["code"], - "text": resp["msg"], - "error": { - "message": resp["msg"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求百川 API 时发生错误:{data}") - yield data - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="", - messages=[], - roles=["user", "assistant"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = BaiChuanWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21007", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21007) - # do_request() diff --git a/server/model_workers/base.py b/server/model_workers/base.py deleted file mode 100644 index b6e88d31..00000000 --- a/server/model_workers/base.py +++ /dev/null @@ -1,248 +0,0 @@ -from fastchat.conversation import Conversation -from configs import LOG_PATH -import fastchat.constants -fastchat.constants.LOGDIR = LOG_PATH -from fastchat.serve.base_model_worker import BaseModelWorker -import uuid -import json -import sys -from pydantic import BaseModel, root_validator -import fastchat -import asyncio -from server.utils import get_model_worker_config -from typing import Dict, List, Optional - - -__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"] - - -class ApiConfigParams(BaseModel): - ''' - 在线API配置参数,未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取 - ''' - api_base_url: Optional[str] = None - api_proxy: Optional[str] = None - api_key: Optional[str] = None - secret_key: Optional[str] = None - group_id: Optional[str] = None # for minimax - is_pro: bool = False # for minimax - - APPID: Optional[str] = None # for xinghuo - APISecret: Optional[str] = None # for xinghuo - is_v2: bool = False # for xinghuo - - worker_name: Optional[str] = None - - class Config: - extra = "allow" - - @root_validator(pre=True) - def validate_config(cls, v: Dict) -> Dict: - if config := get_model_worker_config(v.get("worker_name")): - for n in cls.__fields__: - if n in config: - v[n] = config[n] - return v - - def load_config(self, worker_name: str): - self.worker_name = worker_name - if config := get_model_worker_config(worker_name): - for n in self.__fields__: - if n in config: - setattr(self, n, config[n]) - return self - - -class ApiModelParams(ApiConfigParams): - ''' - 模型配置参数 - ''' - version: Optional[str] = None - version_url: Optional[str] = None - api_version: Optional[str] = None # for azure - deployment_name: Optional[str] = None # for azure - resource_name: Optional[str] = None # for azure - - temperature: float = 0.9 - max_tokens: Optional[int] = None - top_p: Optional[float] = 1.0 - - -class ApiChatParams(ApiModelParams): - ''' - chat请求参数 - ''' - messages: List[Dict[str, str]] - system_message: Optional[str] = None # for minimax - role_meta: Dict = {} # for minimax - - -class ApiCompletionParams(ApiModelParams): - prompt: str - - -class ApiEmbeddingsParams(ApiConfigParams): - texts: List[str] - embed_model: Optional[str] = None - to_query: bool = False # for minimax - - -class ApiModelWorker(BaseModelWorker): - DEFAULT_EMBED_MODEL: str = None # None means not support embedding - - def __init__( - self, - model_names: List[str], - controller_addr: str = None, - worker_addr: str = None, - context_len: int = 2048, - no_register: bool = False, - **kwargs, - ): - kwargs.setdefault("worker_id", uuid.uuid4().hex[:8]) - kwargs.setdefault("model_path", "") - kwargs.setdefault("limit_worker_concurrency", 5) - super().__init__(model_names=model_names, - controller_addr=controller_addr, - worker_addr=worker_addr, - **kwargs) - import fastchat.serve.base_model_worker - import sys - self.logger = fastchat.serve.base_model_worker.logger - # 恢复被fastchat覆盖的标准输出 - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - - self.context_len = context_len - self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency) - self.version = None - - if not no_register and self.controller_addr: - self.init_heart_beat() - - - def count_token(self, params): - prompt = params["prompt"] - return {"count": len(str(prompt)), "error_code": 0} - - def generate_stream_gate(self, params: Dict): - self.call_ct += 1 - - try: - prompt = params["prompt"] - if self._is_chat(prompt): - messages = self.prompt_to_messages(prompt) - messages = self.validate_messages(messages) - else: # 使用chat模仿续写功能,不支持历史消息 - messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}] - - p = ApiChatParams( - messages=messages, - temperature=params.get("temperature"), - top_p=params.get("top_p"), - max_tokens=params.get("max_new_tokens"), - version=self.version, - ) - for resp in self.do_chat(p): - yield self._jsonify(resp) - except Exception as e: - yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"}) - - def generate_gate(self, params): - try: - for x in self.generate_stream_gate(params): - ... - return json.loads(x[:-1].decode()) - except Exception as e: - return {"error_code": 500, "text": str(e)} - - - # 需要用户自定义的方法 - - def do_chat(self, params: ApiChatParams) -> Dict: - ''' - 执行Chat的方法,默认使用模块里面的chat函数。 - 要求返回形式:{"error_code": int, "text": str} - ''' - return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"} - - # def do_completion(self, p: ApiCompletionParams) -> Dict: - # ''' - # 执行Completion的方法,默认使用模块里面的completion函数。 - # 要求返回形式:{"error_code": int, "text": str} - # ''' - # return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"} - - def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: - ''' - 执行Embeddings的方法,默认使用模块里面的embed_documents函数。 - 要求返回形式:{"code": int, "data": List[List[float]], "msg": str} - ''' - return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"} - - def get_embeddings(self, params): - # fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。 - # 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。 - print("get_embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - raise NotImplementedError - - def validate_messages(self, messages: List[Dict]) -> List[Dict]: - ''' - 有些API对mesages有特殊格式,可以重写该函数替换默认的messages。 - 之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同 - ''' - return messages - - - # help methods - @property - def user_role(self): - return self.conv.roles[0] - - @property - def ai_role(self): - return self.conv.roles[1] - - def _jsonify(self, data: Dict) -> str: - ''' - 将chat函数返回的结果按照fastchat openai-api-server的格式返回 - ''' - return json.dumps(data, ensure_ascii=False).encode() + b"\0" - - def _is_chat(self, prompt: str) -> bool: - ''' - 检查prompt是否由chat messages拼接而来 - TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法 - ''' - key = f"{self.conv.sep}{self.user_role}:" - return key in prompt - - def prompt_to_messages(self, prompt: str) -> List[Dict]: - ''' - 将prompt字符串拆分成messages. - ''' - result = [] - user_role = self.user_role - ai_role = self.ai_role - user_start = user_role + ":" - ai_start = ai_role + ":" - for msg in prompt.split(self.conv.sep)[1:-1]: - if msg.startswith(user_start): - if content := msg[len(user_start):].strip(): - result.append({"role": user_role, "content": content}) - elif msg.startswith(ai_start): - if content := msg[len(ai_start):].strip(): - result.append({"role": ai_role, "content": content}) - else: - raise RuntimeError(f"unknown role in msg: {msg}") - return result - - @classmethod - def can_embedding(cls): - return cls.DEFAULT_EMBED_MODEL is not None diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py deleted file mode 100644 index fdb50a1c..00000000 --- a/server/model_workers/fangzhou.py +++ /dev/null @@ -1,105 +0,0 @@ -from fastchat.conversation import Conversation -from server.model_workers.base import * -from fastchat import conversation as conv -import sys -from typing import List, Literal, Dict -from configs import logger, log_verbose - - -class FangZhouWorker(ApiModelWorker): - """ - 火山方舟 - """ - - def __init__( - self, - *, - model_names: List[str] = ["fangzhou-api"], - controller_addr: str = None, - worker_addr: str = None, - version: Literal["chatglm-6b-model"] = "chatglm-6b-model", - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - from volcengine.maas import MaasService - - params.load_config(self.model_names[0]) - maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') - maas.set_ak(params.api_key) - maas.set_sk(params.secret_key) - - # document: "https://www.volcengine.com/docs/82379/1099475" - req = { - "model": { - "name": params.version, - }, - "parameters": { - # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 - "max_new_tokens": params.max_tokens, - "temperature": params.temperature, - }, - "messages": params.messages, - } - - text = "" - if log_verbose: - self.logger.info(f'{self.__class__.__name__}:maas: {maas}') - for resp in maas.stream_chat(req): - if error := resp.error: - if error.code_n > 0: - data = { - "error_code": error.code_n, - "text": error.message, - "error": { - "message": error.message, - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求方舟 API 时发生错误:{data}") - yield data - elif chunk := resp.choice.message.content: - text += chunk - yield {"error_code": 0, "text": text} - else: - data = { - "error_code": 500, - "text": f"请求方舟 API 时发生未知的错误: {resp}" - } - self.logger.error(data) - yield data - break - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", - messages=[], - roles=["user", "assistant", "system"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = FangZhouWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21005", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21005) diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py deleted file mode 100644 index 79d24514..00000000 --- a/server/model_workers/minimax.py +++ /dev/null @@ -1,170 +0,0 @@ -from fastchat.conversation import Conversation -from server.model_workers.base import * -from fastchat import conversation as conv -import sys -import json -from server.model_workers.base import ApiEmbeddingsParams -from server.utils import get_httpx_client -from typing import List, Dict -from configs import logger, log_verbose - - -class MiniMaxWorker(ApiModelWorker): - DEFAULT_EMBED_MODEL = "embo-01" - - def __init__( - self, - *, - model_names: List[str] = ["minimax-api"], - controller_addr: str = None, - worker_addr: str = None, - version: str = "abab5.5-chat", - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) - super().__init__(**kwargs) - self.version = version - - def validate_messages(self, messages: List[Dict]) -> List[Dict]: - role_maps = { - "USER": self.user_role, - "assistant": self.ai_role, - "system": "system", - } - messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages] - return messages - - def do_chat(self, params: ApiChatParams) -> Dict: - # 按照官网推荐,直接调用abab 5.5模型 - params.load_config(self.model_names[0]) - - url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' - pro = "_pro" if params.is_pro else "" - headers = { - "Authorization": f"Bearer {params.api_key}", - "Content-Type": "application/json", - } - messages = self.validate_messages(params.messages) - data = { - "model": params.version, - "stream": True, - "mask_sensitive_info": True, - "messages": messages, - "temperature": params.temperature, - "top_p": params.top_p, - "tokens_to_generate": params.max_tokens or 1024, - # 以下参数为minimax特有,传入空值会出错。 - # "prompt": params.system_message or self.conv.system_message, - # "bot_setting": [], - # "role_meta": params.role_meta, - } - if log_verbose: - logger.info(f'{self.__class__.__name__}:data: {data}') - logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - - with get_httpx_client() as client: - response = client.stream("POST", - url.format(pro=pro, group_id=params.group_id), - headers=headers, - json=data) - with response as r: - text = "" - for e in r.iter_text(): - if not e.startswith("data: "): - data = { - "error_code": 500, - "text": f"minimax返回错误的结果:{e}", - "error": { - "message": f"minimax返回错误的结果:{e}", - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求 MiniMax API 时发生错误:{data}") - yield data - continue - - data = json.loads(e[6:]) - if data.get("usage"): - break - - if choices := data.get("choices"): - if chunk := choices[0].get("delta", ""): - text += chunk - yield {"error_code": 0, "text": text} - - def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: - params.load_config(self.model_names[0]) - url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}" - - headers = { - "Authorization": f"Bearer {params.api_key}", - "Content-Type": "application/json", - } - - data = { - "model": params.embed_model or self.DEFAULT_EMBED_MODEL, - "texts": [], - "type": "query" if params.to_query else "db", - } - if log_verbose: - logger.info(f'{self.__class__.__name__}:data: {data}') - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - - with get_httpx_client() as client: - result = [] - i = 0 - batch_size = 10 - while i < len(params.texts): - texts = params.texts[i:i+batch_size] - data["texts"] = texts - r = client.post(url, headers=headers, json=data).json() - if embeddings := r.get("vectors"): - result += embeddings - elif error := r.get("base_resp"): - data = { - "code": error["status_code"], - "msg": error["status_msg"], - "error": { - "message": error["status_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求 MiniMax API 时发生错误:{data}") - return data - i += batch_size - return {"code": 200, "data": result} - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。", - messages=[], - roles=["USER", "BOT"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = MiniMaxWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21002", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21002) diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py deleted file mode 100644 index da362ec6..00000000 --- a/server/model_workers/qianfan.py +++ /dev/null @@ -1,216 +0,0 @@ -import sys -from fastchat.conversation import Conversation -from server.model_workers.base import * -from server.utils import get_httpx_client -from cachetools import cached, TTLCache -import json -from fastchat import conversation as conv -import sys -from server.model_workers.base import ApiEmbeddingsParams -from typing import List, Literal, Dict -from configs import logger, log_verbose - -MODEL_VERSIONS = { - "ernie-bot-4": "completions_pro", - "ernie-bot": "completions", - "ernie-bot-turbo": "eb-instant", - "bloomz-7b": "bloomz_7b1", - "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed", - "llama2-7b-chat": "llama_2_7b", - "llama2-13b-chat": "llama_2_13b", - "llama2-70b-chat": "llama_2_70b", - "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b", - "chatglm2-6b-32k": "chatglm2_6b_32k", - "aquilachat-7b": "aquilachat_7b", - # "linly-llama2-ch-7b": "", # 暂未发布 - # "linly-llama2-ch-13b": "", # 暂未发布 - # "chatglm2-6b": "", # 暂未发布 - # "chatglm2-6b-int4": "", # 暂未发布 - # "falcon-7b": "", # 暂未发布 - # "falcon-180b-chat": "", # 暂未发布 - # "falcon-40b": "", # 暂未发布 - # "rwkv4-world": "", # 暂未发布 - # "rwkv5-world": "", # 暂未发布 - # "rwkv4-pile-14b": "", # 暂未发布 - # "rwkv4-raven-14b": "", # 暂未发布 - # "open-llama-7b": "", # 暂未发布 - # "dolly-12b": "", # 暂未发布 - # "mpt-7b-instruct": "", # 暂未发布 - # "mpt-30b-instruct": "", # 暂未发布 - # "OA-Pythia-12B-SFT-4": "", # 暂未发布 - # "xverse-13b": "", # 暂未发布 - - # # 以下为企业测试,需要单独申请 - # "flan-ul2": "", - # "Cerebras-GPT-6.7B": "" - # "Pythia-6.9B": "" -} - - -@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 -def get_baidu_access_token(api_key: str, secret_key: str) -> str: - """ - 使用 AK,SK 生成鉴权签名(Access Token) - :return: access_token,或是None(如果错误) - """ - url = "https://aip.baidubce.com/oauth/2.0/token" - params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} - try: - with get_httpx_client() as client: - return client.get(url, params=params).json().get("access_token") - except Exception as e: - print(f"failed to get token from baidu: {e}") - - -class QianFanWorker(ApiModelWorker): - """ - 百度千帆 - """ - DEFAULT_EMBED_MODEL = "embedding-v1" - - def __init__( - self, - *, - version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", - model_names: List[str] = ["qianfan-api"], - controller_addr: str = None, - worker_addr: str = None, - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - params.load_config(self.model_names[0]) - BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \ - '/{model_version}?access_token={access_token}' - - access_token = get_baidu_access_token(params.api_key, params.secret_key) - if not access_token: - yield { - "error_code": 403, - "text": f"failed to get access token. have you set the correct api_key and secret key?", - } - - url = BASE_URL.format( - model_version=params.version_url or MODEL_VERSIONS[params.version.lower()], - access_token=access_token, - ) - payload = { - "messages": params.messages, - "temperature": params.temperature, - "stream": True - } - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - } - - text = "" - if log_verbose: - logger.info(f'{self.__class__.__name__}:data: {payload}') - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=payload) as response: - for line in response.iter_lines(): - if not line.strip(): - continue - if line.startswith("data: "): - line = line[6:] - resp = json.loads(line) - - if "result" in resp.keys(): - text += resp["result"] - yield { - "error_code": 0, - "text": text - } - else: - data = { - "error_code": resp["error_code"], - "text": resp["error_msg"], - "error": { - "message": resp["error_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求千帆 API 时发生错误:{data}") - yield data - - def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: - params.load_config(self.model_names[0]) - # import qianfan - - # embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key) - # resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL) - # if resp.code == 200: - # embeddings = [x.embedding for x in resp.body.get("data", [])] - # return {"code": 200, "embeddings": embeddings} - # else: - # return {"code": resp.code, "msg": str(resp.body)} - - embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL - access_token = get_baidu_access_token(params.api_key, params.secret_key) - url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}" - if log_verbose: - logger.info(f'{self.__class__.__name__}:url: {url}') - - with get_httpx_client() as client: - result = [] - i = 0 - batch_size = 10 - while i < len(params.texts): - texts = params.texts[i:i + batch_size] - resp = client.post(url, json={"input": texts}).json() - if "error_code" in resp: - data = { - "code": resp["error_code"], - "msg": resp["error_msg"], - "error": { - "message": resp["error_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求千帆 API 时发生错误:{data}") - return data - else: - embeddings = [x["embedding"] for x in resp.get("data", [])] - result += embeddings - i += batch_size - return {"code": 200, "data": result} - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="你是一个聪明的助手,请根据用户的提示来完成任务", - messages=[], - roles=["user", "assistant"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = QianFanWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21004" - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21004) diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py deleted file mode 100644 index f9ae6cb2..00000000 --- a/server/model_workers/qwen.py +++ /dev/null @@ -1,123 +0,0 @@ -import sys -from fastchat.conversation import Conversation -from typing import List, Literal, Dict -from fastchat import conversation as conv -from server.model_workers.base import * -from server.model_workers.base import ApiEmbeddingsParams -from configs import logger, log_verbose - - -class QwenWorker(ApiModelWorker): - DEFAULT_EMBED_MODEL = "text-embedding-v1" - - def __init__( - self, - *, - version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo", - model_names: List[str] = ["qwen-api"], - controller_addr: str = None, - worker_addr: str = None, - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - import dashscope - params.load_config(self.model_names[0]) - if log_verbose: - logger.info(f'{self.__class__.__name__}:params: {params}') - - gen = dashscope.Generation() - responses = gen.call( - model=params.version, - temperature=params.temperature, - api_key=params.api_key, - messages=params.messages, - result_format='message', # set the result is message format. - stream=True, - ) - - for resp in responses: - if resp["status_code"] == 200: - if choices := resp["output"]["choices"]: - yield { - "error_code": 0, - "text": choices[0]["message"]["content"], - } - else: - data = { - "error_code": resp["status_code"], - "text": resp["message"], - "error": { - "message": resp["message"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求千问 API 时发生错误:{data}") - yield data - - def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: - import dashscope - params.load_config(self.model_names[0]) - if log_verbose: - logger.info(f'{self.__class__.__name__}:params: {params}') - result = [] - i = 0 - while i < len(params.texts): - texts = params.texts[i:i+25] - resp = dashscope.TextEmbedding.call( - model=params.embed_model or self.DEFAULT_EMBED_MODEL, - input=texts, # 最大25行 - api_key=params.api_key, - ) - if resp["status_code"] != 200: - data = { - "code": resp["status_code"], - "msg": resp.message, - "error": { - "message": resp["message"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求千问 API 时发生错误:{data}") - return data - else: - embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] - result += embeddings - i += 25 - return {"code": 200, "data": result} - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", - messages=[], - roles=["user", "assistant", "system"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = QwenWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20007", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=20007) diff --git a/server/model_workers/tiangong.py b/server/model_workers/tiangong.py deleted file mode 100644 index 88010a15..00000000 --- a/server/model_workers/tiangong.py +++ /dev/null @@ -1,84 +0,0 @@ -import json -import time -import hashlib - -from fastchat.conversation import Conversation -from server.model_workers.base import * -from server.utils import get_httpx_client -from fastchat import conversation as conv -import json -from typing import List, Literal, Dict -import requests - - -class TianGongWorker(ApiModelWorker): - def __init__( - self, - *, - controller_addr: str = None, - worker_addr: str = None, - model_names: List[str] = ["tiangong-api"], - version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse", - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 32768) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - params.load_config(self.model_names[0]) - - url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate' - data = { - "messages": params.messages, - "model": "SkyChat-MegaVerse" - } - timestamp = str(int(time.time())) - sign_content = params.api_key + params.secret_key + timestamp - sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest() - headers = { - "app_key": params.api_key, - "timestamp": timestamp, - "sign": sign_result, - "Content-Type": "application/json", - "stream": "true" # or change to "false" 不处理流式返回内容 - } - - # 发起请求并获取响应 - response = requests.post(url, headers=headers, json=data, stream=True) - - text = "" - # 处理响应流 - for line in response.iter_lines(chunk_size=None, decode_unicode=True): - if line: - # 处理接收到的数据 - # print(line.decode('utf-8')) - resp = json.loads(line) - if resp["code"] == 200: - text += resp['resp_data']['reply'] - yield { - "error_code": 0, - "text": text - } - else: - data = { - "error_code": resp["code"], - "text": resp["code_msg"] - } - self.logger.error(f"请求天工 API 时出错:{data}") - yield data - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="", - messages=[], - roles=["user", "system"], - sep="\n### ", - stop_str="###", - ) diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py deleted file mode 100644 index b2a1cdc1..00000000 --- a/server/model_workers/xinghuo.py +++ /dev/null @@ -1,100 +0,0 @@ -from fastchat.conversation import Conversation -from server.model_workers.base import * -from fastchat import conversation as conv -import sys -import json -from server.model_workers import SparkApi -import websockets -from server.utils import iter_over_async, asyncio -from typing import List, Dict - - -async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token): - wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url) - wsUrl = wsParam.create_url() - data = SparkApi.gen_params(appid, domain, question, temperature, max_token) - async with websockets.connect(wsUrl) as ws: - await ws.send(json.dumps(data, ensure_ascii=False)) - finish = False - while not finish: - chunk = await ws.recv() - response = json.loads(chunk) - if response.get("header", {}).get("status") == 2: - finish = True - if text := response.get("payload", {}).get("choices", {}).get("text"): - yield text[0]["content"] - - -class XingHuoWorker(ApiModelWorker): - def __init__( - self, - *, - model_names: List[str] = ["xinghuo-api"], - controller_addr: str = None, - worker_addr: str = None, - version: str = None, - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 8000) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Dict: - params.load_config(self.model_names[0]) - - version_mapping = { - "v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat", "max_tokens": 4000}, - "v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat", "max_tokens": 8000}, - "v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat", "max_tokens": 8000}, - "v3.5": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.5/chat", "max_tokens": 16000}, - } - - def get_version_details(version_key): - return version_mapping.get(version_key, {"domain": None, "url": None}) - - details = get_version_details(params.version) - domain = details["domain"] - Spark_url = details["url"] - text = "" - try: - loop = asyncio.get_event_loop() - except: - loop = asyncio.new_event_loop() - params.max_tokens = min(details["max_tokens"], params.max_tokens or 0) - for chunk in iter_over_async( - request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, - params.temperature, params.max_tokens), - loop=loop, - ): - if chunk: - text += chunk - yield {"error_code": 0, "text": text} - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="你是一个聪明的助手,请根据用户的提示来完成任务", - messages=[], - roles=["user", "assistant"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = XingHuoWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21003", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21003) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py deleted file mode 100644 index 898427c8..00000000 --- a/server/model_workers/zhipu.py +++ /dev/null @@ -1,114 +0,0 @@ -from contextlib import contextmanager - -import httpx -from fastchat.conversation import Conversation -from httpx_sse import EventSource - -from server.model_workers.base import * -from fastchat import conversation as conv -import sys -from typing import List, Dict, Iterator, Literal, Any -import jwt -import time - - -@contextmanager -def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any): - with client.stream(method, url, **kwargs) as response: - yield EventSource(response) - - -def generate_token(apikey: str, exp_seconds: int): - try: - id, secret = apikey.split(".") - except Exception as e: - raise Exception("invalid apikey", e) - - payload = { - "api_key": id, - "exp": int(round(time.time() * 1000)) + exp_seconds * 1000, - "timestamp": int(round(time.time() * 1000)), - } - - return jwt.encode( - payload, - secret, - algorithm="HS256", - headers={"alg": "HS256", "sign_type": "SIGN"}, - ) - - -class ChatGLMWorker(ApiModelWorker): - def __init__( - self, - *, - model_names: List[str] = ["zhipu-api"], - controller_addr: str = None, - worker_addr: str = None, - version: Literal["glm-4"] = "glm-4", - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 4096) - super().__init__(**kwargs) - self.version = version - - def do_chat(self, params: ApiChatParams) -> Iterator[Dict]: - params.load_config(self.model_names[0]) - token = generate_token(params.api_key, 60) - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {token}" - } - data = { - "model": params.version, - "messages": params.messages, - "max_tokens": params.max_tokens, - "temperature": params.temperature, - "stream": False - } - - url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" - with httpx.Client(headers=headers) as client: - response = client.post(url, json=data) - response.raise_for_status() - chunk = response.json() - print(chunk) - yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]} - - # with connect_sse(client, "POST", url, json=data) as event_source: - # for sse in event_source.iter_sse(): - # chunk = json.loads(sse.data) - # if len(chunk["choices"]) != 0: - # text += chunk["choices"][0]["delta"]["content"] - # yield {"error_code": 0, "text": text} - - - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="你是智谱AI小助手,请根据用户的提示来完成任务", - messages=[], - roles=["user", "assistant", "system"], - sep="\n###", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.model_worker import app - - worker = ChatGLMWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21001", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21001)