liunux4odoo 03e55e11c4
Support lite mode: run LLM chat and search-engine chat through online APIs, without installing heavy dependencies such as torch (#1860)
* move get_default_llm_model from webui to ApiRequest

Add API endpoints and their test cases:
- /server/get_prompt_template: fetch the prompt template configured on the server
- add a test case for multi-threaded knowledge-base access

Support lite mode: run LLM chat and search-engine chat through online APIs, without installing heavy dependencies such as torch

* fix bug in server.api

---------

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
2023-10-25 08:30:23 +08:00
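As a quick illustration of the new /server/get_prompt_template endpoint, a request along the following lines should return the configured template. This is a minimal sketch: the API port (7861) and the "type" request field are assumptions drawn from this changeset, not a documented contract.

import httpx

# Hypothetical call to the endpoint added in this commit; the port and
# the "type" field in the request body are assumptions for illustration.
resp = httpx.post(
    "http://127.0.0.1:7861/server/get_prompt_template",
    json={"type": "llm_chat"},
)
print(resp.text)  # the prompt template configured on the server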


from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from server.utils import get_httpx_client
from pprint import pprint
from typing import List, Dict

class MiniMaxWorker(ApiModelWorker):
    BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'

    def __init__(
        self,
        *,
        model_names: List[str] = ["minimax-api"],
        controller_addr: str,
        worker_addr: str,
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        # Convert generic role/content messages into MiniMax's
        # sender_type/text message format.
        result = super().prompt_to_messages(prompt)
        messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]
        return messages
    def generate_stream_gate(self, params):
        # Call the abab5.5 model directly, as recommended by the official docs.
        # TODO: support customizing reply requirements, user name and AI name.
        super().generate_stream_gate(params)
        config = self.get_config()
        group_id = config.get("group_id")
        api_key = config.get("api_key")
        pro = "_pro" if config.get("is_pro") else ""

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": "abab5.5-chat",
            "stream": True,
            "tokens_to_generate": 1024,  # TODO: 1024 is the official default
            "mask_sensitive_info": True,
            "messages": self.prompt_to_messages(params["prompt"]),
            "temperature": params.get("temperature"),
            "top_p": params.get("top_p"),
            "bot_setting": [],
        }
        print("request data sent to minimax:")
        pprint(data)

        with get_httpx_client() as client:
            response = client.stream("POST",
                                     self.BASE_URL.format(pro=pro, group_id=group_id),
                                     headers=headers,
                                     json=data)
            with response as r:
                text = ""
                for e in r.iter_text():
                    if e.startswith("data: "):  # MiniMax streams SSE lines prefixed with "data: "
                        data = json.loads(e[6:])
                        if not data.get("usage"):
                            if choices := data.get("choices"):
                                chunk = choices[0].get("delta", "").strip()
                                if chunk:
                                    print(chunk)
                                    text += chunk
                                    yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)
    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: confirm whether this template needs to be adjusted
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["USER", "BOT"],
            sep="\n### ",
            stop_str="###",
        )

if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = MiniMaxWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21002",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21002)
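For reference, generate_stream_gate above expects MiniMax to stream SSE-style lines prefixed with "data: ". The following is a minimal sketch of the shape it parses; the event payload is fabricated for illustration and real API output may differ.

import json

event = 'data: {"choices": [{"delta": "Hello"}], "usage": null}'  # fabricated sample event
if event.startswith("data: "):
    payload = json.loads(event[6:])
    if not payload.get("usage"):  # a non-null "usage" marks the final accounting event
        if choices := payload.get("choices"):
            chunk = choices[0].get("delta", "").strip()
            print(chunk)  # -> Hello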