diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index f67fbd4b..d5b624b4 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -88,6 +88,14 @@ llm_model_dict = {
         "provider": "ChatGLMWorker",
         "version": "chatglm_pro",  # options: "chatglm_lite", "chatglm_std", "chatglm_pro"
     },
+    "minimax-api": {
+        "api_base_url": "http://127.0.0.1:8888/v1",
+        "group_id": "",
+        "api_key": "",
+        "is_pro": False,
+        "provider": "MiniMaxWorker",
+    },
+
 }
 
 # LLM name
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index ad731e37..2157071f 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -66,6 +66,9 @@ FSCHAT_MODEL_WORKERS = {
     "chatglm-api": {  # set a different port for each online API
         "port": 20003,
     },
+    "minimax-api": {  # set a different port for each online API
+        "port": 20004,
+    },
 }
 
 # fastchat multi model worker server
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
index 932c2f3b..4ec62ddb 100644
--- a/server/model_workers/__init__.py
+++ b/server/model_workers/__init__.py
@@ -1 +1,2 @@
 from .zhipu import ChatGLMWorker
+from .minimax import MiniMaxWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index b72f6839..090cf722 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -69,3 +69,8 @@ class ApiModelWorker(BaseModelWorker):
             target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
         )
         self.heart_beat_thread.start()
+
+    # helper methods
+    def get_config(self):
+        from server.utils import get_model_worker_config
+        return get_model_worker_config(self.model_names[0])
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
new file mode 100644
index 00000000..a182ae34
--- /dev/null
+++ b/server/model_workers/minimax.py
@@ -0,0 +1,108 @@
+from server.model_workers.base import ApiModelWorker
+from fastchat import conversation as conv
+import sys
+import json
+import httpx
+from pprint import pprint
+from typing import List, Dict
+
+
+class MiniMaxWorker(ApiModelWorker):
+    BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
+
+    def __init__(
+        self,
+        *,
+        model_names: List[str] = ["minimax-api"],
+        controller_addr: str,
+        worker_addr: str,
+        **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 16384)
+        super().__init__(**kwargs)
+
+        # TODO: confirm whether this template needs changes
+        self.conv = conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["USER", "BOT"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+    def prompt_to_messages(self, prompt: str) -> List[Dict]:
+        result = []
+        user_start = self.conv.roles[0] + ":"
+        bot_start = self.conv.roles[1] + ":"
+        for msg in prompt.split(self.conv.sep)[1:-1]:
+            if msg.startswith(user_start):
+                result.append({"sender_type": "USER", "text": msg[len(user_start):].strip()})
+            elif msg.startswith(bot_start):
+                result.append({"sender_type": "BOT", "text": msg[len(bot_start):].strip()})
+            else:
+                raise RuntimeError(f"unknown role in msg: {msg}")
+        return result
+
+    def generate_stream_gate(self, params):
+        # Per the official recommendation, call the abab5.5 model directly.
+        # TODO: support chat history, reply constraints, and custom user/AI names
+
+        super().generate_stream_gate(params)
+        config = self.get_config()
+        group_id = config.get("group_id")
+        api_key = config.get("api_key")
+
+        pro = "_pro" if config.get("is_pro") else ""
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+        data = {
+            "model": "abab5.5-chat",
+            "stream": True,
+            "tokens_to_generate": 1024,  # TODO: 1024 is the official default
+            "mask_sensitive_info": True,
+            "messages": self.prompt_to_messages(params["prompt"]),
+            "temperature": params.get("temperature"),
+            "top_p": params.get("top_p"),
+            "bot_setting": [],
+        }
+        print("request data sent to minimax:")
+        pprint(data)
+        response = httpx.stream("POST",
+                                self.BASE_URL.format(pro=pro, group_id=group_id),
+                                headers=headers,
+                                json=data)
+        with response as r:
+            text = ""
+            for e in r.iter_text():
+                if e.startswith("data: "):  # truly an excellent response format
+                    data = json.loads(e[6:])
+                    if not data.get("usage"):
+                        if choices := data.get("choices"):
+                            chunk = choices[0].get("delta", "").strip()
+                            if chunk:
+                                print(chunk)
+                                text += chunk
+                                yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = MiniMaxWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:20004",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=20004)
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 82eab5da..63fd3b48 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -25,7 +25,7 @@ class ChatGLMWorker(ApiModelWorker):
 
         # This is the conv template for the chatglm api; other APIs need their own conv_template
         self.conv = conv.Conversation(
-            name="chatglm-api",
+            name=self.model_names[0],
             system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
             messages=[],
             roles=["Human", "Assistant"],
@@ -34,12 +34,11 @@ class ChatGLMWorker(ApiModelWorker):
         )
 
     def generate_stream_gate(self, params):
-        # TODO: support the stream parameter and track request_id; the incoming prompt is also problematic
-        from server.utils import get_model_worker_config
+        # TODO: track request_id
         import zhipuai
 
         super().generate_stream_gate(params)
-        zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
+        zhipuai.api_key = self.get_config().get("api_key")
         response = zhipuai.model_api.sse_invoke(
             model=self.version,