liunux4odoo 1de4258aa0
给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 (#2169)
* 给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。
/chat/chat 接口 conversation_id参数改为默认 "",避免 swagger 页面默认值错误导致历史消息失效

* 修复在线模型一些bug
2023-11-25 13:51:07 +08:00

131 lines
4.5 KiB
Python

import json
import sys
from fastchat.conversation import Conversation
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from configs import logger, log_verbose
class QwenWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text-embedding-v1"
def __init__(
self,
*,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
gen = dashscope.Generation()
responses = gen.call(
model=params.version,
temperature=params.temperature,
api_key=params.api_key,
messages=params.messages,
result_format='message', # set the result is message format.
stream=True,
)
for resp in responses:
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"error_code": 0,
"text": choices[0]["message"]["content"],
}
else:
data = {
"error_code": resp["status_code"],
"text": resp["message"],
"error": {
"message": resp["message"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
result = []
i = 0
while i < len(params.texts):
texts = params.texts[i:i+25]
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=texts, # 最大25行
api_key=params.api_key,
)
if resp["status_code"] != 200:
data = {
"code": resp["status_code"],
"msg": resp.message,
"error": {
"message": resp["message"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings
i += 25
return {"code": 200, "data": result}
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)