liunux4odoo 6cb1bdf623
Add model switching; support Zhipu AI online models (#1342)
* Add LLM model switching; the switchable models must be configured in server_config (see the config sketch below)
* add tests for api.py/llm_model/*
* - Support model switching
  - Support Zhipu AI online models
  - Add an `--api-worker` option to startup.py that automatically runs all online API models; it is enabled by default with `-a (--all-webui)` and `--all-api`
* Fix stdout being overridden by fastchat
* Finer-grained control of fastchat logging: add a `-q (--quiet)` switch to startup.py to cut down on useless fastchat log output
* Fix the conversation template for the chatglm api


Co-authored-by: liunux4odoo <liunu@qq.com>
2023-09-01 23:58:09 +08:00
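For reference, a sketch of the server_config entry this worker reads. Only the "chatglm-api" key and its "api_key" field are implied by the worker code below, which calls get_model_worker_config("chatglm-api").get("api_key"); the dict name FSCHAT_MODEL_WORKERS and the "port" field are assumptions.

# server_config.py (sketch; FSCHAT_MODEL_WORKERS and "port" are assumptions --
# only "chatglm-api" and "api_key" are implied by the worker code below)
FSCHAT_MODEL_WORKERS = {
    "chatglm-api": {
        "api_key": "<your Zhipu AI API key>",  # read in ChatGLMWorker.generate_stream_gate
        "port": 20003,  # hypothetical; matches the worker_addr in the __main__ block
    },
}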


import zhipuai
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal


class ChatGLMWorker(ApiModelWorker):
    BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
    SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]

    def __init__(
        self,
        *,
        model_names: List[str] = ["chatglm-api"],
        version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
        controller_addr: str,
        worker_addr: str,
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 32768)
        super().__init__(**kwargs)
        self.version = version

        # This is the conversation template for the chatglm api;
        # other APIs need their own custom conv_template.
        self.conv = conv.Conversation(
            name="chatglm-api",
            # "You are a smart AI assistant, helpful to humans; you give useful,
            # detailed and polite answers to the questions humans ask."
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["Human", "Assistant"],
            sep="\n### ",
            stop_str="###",
        )

    def generate_stream_gate(self, params):
        # TODO: support the `stream` parameter and keep track of request_id;
        # the prompt passed in here is also problematic.
        from server.utils import get_model_worker_config

        super().generate_stream_gate(params)
        zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
        response = zhipuai.model_api.sse_invoke(
            model=self.version,
            prompt=[{"role": "user", "content": params["prompt"]}],
            temperature=params.get("temperature"),
            top_p=params.get("top_p"),
            incremental=False,
        )
        for e in response.events():
            if e.event == "add":
                # Each chunk is a NUL-terminated JSON message, the format fastchat expects.
                yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
            # TODO: more robust event handling
            # elif e.event == "finish":
            #     ...

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = ChatGLMWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:20003",
    )
    # Swap our API worker in for the module-level worker fastchat expects.
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=20003)
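
A minimal client sketch for consuming this worker's stream. It assumes the stock fastchat /worker_generate_stream route exposed by fastchat.serve.model_worker.app; the NUL-byte delimiter matches the b"\0" separator yielded by generate_stream_gate above.

import json
import requests

resp = requests.post(
    "http://127.0.0.1:20003/worker_generate_stream",
    json={"model": "chatglm-api", "prompt": "Hello", "temperature": 0.9, "top_p": 0.7},
    stream=True,
)
for chunk in resp.iter_lines(delimiter=b"\0"):  # chunks are NUL-delimited JSON
    if chunk:
        data = json.loads(chunk.decode())
        if data.get("error_code") == 0:
            print(data["text"], end="", flush=True)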