liunux4odoo b4c68ddd05
优化在线 API,支持 completion 和 embedding,简化在线 API 开发方式 (#1886)
* 优化在线 API,支持 completion 和 embedding,简化在线 API 开发方式

新功能
- 智谱AI、Minimax、千帆、千问 4 个在线模型支持 embeddings(不通过Fastchat,后续会单独提供相关api接口)
- 在线模型自动检测传入参数,在传入非 messages 格式的 prompt 时,自动转换为 completion 形式,以支持 completion 接口

开发者:
- 重构ApiModelWorker:
  - 所有在线 API 请求封装到 do_chat 方法:自动传入参数 ApiChatParams,简化参数与配置项的获取;自动处理与fastchat的接口
  - 加强 API 请求错误处理,返回更有意义的信息
  - 改用 qianfan sdk 重写 qianfan-api
  - 将所有在线模型的测试用例统一在一起,简化测试用例编写

* Delete requirements_langflow.txt
2023-10-26 22:44:48 +08:00

250 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastchat.conversation import Conversation
from configs import LOG_PATH, TEMPERATURE
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel, root_validator
import fastchat
import asyncio
from server.utils import get_model_worker_config
from typing import Dict, List, Optional
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
# Put back the original standard output/error streams that importing fastchat replaced.
sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
class ApiConfigParams(BaseModel):
    '''
    Configuration parameters for an online API.

    Values the caller does not provide (absent or None) are filled in automatically
    from the worker config returned by get_model_worker_config
    (model_config.ONLINE_LLM_MODEL).
    '''
    api_base_url: Optional[str] = None
    api_proxy: Optional[str] = None
    api_key: Optional[str] = None
    secret_key: Optional[str] = None
    group_id: Optional[str] = None  # for minimax
    is_pro: bool = False  # for minimax

    APPID: Optional[str] = None  # for xinghuo
    APISecret: Optional[str] = None  # for xinghuo
    is_v2: bool = False  # for xinghuo

    worker_name: Optional[str] = None

    class Config:
        # pydantic: keep undeclared keyword arguments instead of rejecting them
        extra = "allow"

    @root_validator(pre=True)
    def validate_config(cls, v: Dict) -> Dict:
        '''
        Fill declared fields the caller did not supply from the worker config.

        The config is looked up by v["worker_name"] (None when not given). A field
        is only taken from the config when it is missing from `v` or explicitly None,
        so values passed by the caller take precedence, as the class docstring states.
        '''
        if config := get_model_worker_config(v.get("worker_name")):
            for n in cls.__fields__:
                # previously this overwrote caller-supplied values unconditionally
                if n in config and v.get(n) is None:
                    v[n] = config[n]
        return v

    def load_config(self, worker_name: str):
        '''
        Bind this params object to `worker_name` and load that worker's config.

        Unlike validate_config, every declared field present in the config is
        overwritten, even if it was already set on the instance.
        Returns self so the call can be chained.
        '''
        self.worker_name = worker_name
        if config := get_model_worker_config(worker_name):
            for n in self.__fields__:
                if n in config:
                    setattr(self, n, config[n])
        return self
class ApiModelParams(ApiConfigParams):
    '''
    Model parameters: model version / endpoint selection plus sampling settings,
    on top of the connection settings inherited from ApiConfigParams.
    '''
    version: Optional[str] = None
    version_url: Optional[str] = None
    api_version: Optional[str] = None  # for azure
    deployment_name: Optional[str] = None  # for azure
    resource_name: Optional[str] = None  # for azure

    temperature: float = TEMPERATURE  # default taken from configs.TEMPERATURE; not Optional, so None is rejected
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
class ApiChatParams(ApiModelParams):
    '''
    Chat request parameters.
    '''
    # Required. ApiModelWorker builds entries of the form {"role": ..., "content": ...}.
    messages: List[Dict[str, str]]
    system_message: Optional[str] = None  # for minimax
    role_meta: Dict = {}  # for minimax
class ApiCompletionParams(ApiModelParams):
    '''
    Completion request parameters: a single raw prompt string instead of chat messages.
    '''
    prompt: str
class ApiEmbeddingsParams(ApiConfigParams):
    '''
    Embeddings request parameters.
    '''
    texts: List[str]  # the texts to embed
    # NOTE(review): presumably None means "use the worker's DEFAULT_EMBED_MODEL" — confirm in subclasses
    embed_model: Optional[str] = None
    to_query: bool = False  # for minimax
class ApiModelWorker(BaseModelWorker):
    '''
    Base class for fastchat model workers that forward requests to an online API.

    Subclasses implement do_chat (and optionally do_embeddings) and
    make_conv_template. This class turns fastchat's prompt-based requests into
    ApiChatParams and serialises the results in the format fastchat expects.
    '''
    DEFAULT_EMBED_MODEL: Optional[str] = None  # None means not support embedding

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str = None,
        worker_addr: str = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        '''
        :param model_names: names this worker serves; model_names[0] is used in error messages.
        :param controller_addr: fastchat controller address; the heartbeat is only started when set.
        :param worker_addr: address of this worker.
        :param context_len: stored as self.context_len.
        :param no_register: when True, never start the controller heartbeat.
        :param kwargs: forwarded to BaseModelWorker; worker_id, model_path and
            limit_worker_concurrency get defaults when absent.
        '''
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                         controller_addr=controller_addr,
                         worker_addr=worker_addr,
                         **kwargs)
        # Restore the standard streams overwritten by fastchat
        # (`sys` is imported at module level; no local import needed).
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()

    def count_token(self, params):
        '''
        Rough token count: currently just the character length of params["prompt"].
        '''
        # TODO: needs a real tokenizer-based implementation
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        '''
        fastchat streaming entry point.

        Converts params["prompt"] into chat messages, builds ApiChatParams, calls
        do_chat and yields every result serialised by _jsonify. do_chat may either
        yield result dicts or return a single dict (the documented form, used by the
        base implementation); both are handled. Any exception becomes one error
        chunk with error_code 500.
        '''
        self.call_ct += 1
        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else:  # emulate completion through chat; history messages are not supported here
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            result = self.do_chat(p)
            # Iterating a dict would yield its keys, so wrap a single result dict.
            if isinstance(result, dict):
                result = [result]
            for resp in result:
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误{e}"})

    def generate_gate(self, params):
        '''
        Non-streaming entry point: drain generate_stream_gate and return the last
        chunk decoded back into a dict. Returns an error_code 500 dict when the
        stream produced nothing or the chunk cannot be decoded.
        '''
        try:
            last = None
            for last in self.generate_stream_gate(params):
                pass
            if last is None:
                # previously this surfaced as an UnboundLocalError message
                return {"error_code": 500, "text": f"{self.model_names[0]}未返回任何结果"}
            # drop the trailing b"\0" delimiter appended by _jsonify
            return json.loads(last[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}

    # Methods to be implemented by subclasses

    def do_chat(self, params: ApiChatParams) -> Dict:
        '''
        Perform a chat request against the online API; subclasses override this.

        Each result must have the form {"error_code": int, "text": str}. An override
        may return one such dict or yield several (streaming). The base
        implementation returns a "not implemented" error.
        '''
        return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}

    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     Perform a completion request; expected result: {"error_code": int, "text": str}
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        Compute embeddings through the online API; subclasses override this.

        Expected result: {"code": int, "embeddings": List[List[float]], "msg": str}.
        The base implementation returns a "not implemented" error (code 500).
        '''
        return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}

    def get_embeddings(self, params):
        # fastchat is very restrictive about embeddings for LLM workers; apparently only
        # openai's work. Requests sent from the frontend via OpenAIEmbeddings fail
        # before they ever reach this method. Currently only prints debug output.
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        '''Subclasses must return the fastchat Conversation template used to parse prompts.'''
        raise NotImplementedError

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        Hook for APIs that need a special messages format: override to rewrite the
        default messages. Kept separate from prompt_to_messages because the two are
        used in different situations and take different parameters.
        '''
        return messages

    # helper methods

    @property
    def user_role(self):
        '''Name of the user role in self.conv.'''
        return self.conv.roles[0]

    @property
    def ai_role(self):
        '''Name of the assistant role in self.conv.'''
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> bytes:
        '''
        Serialise a result dict the way fastchat's openai-api-server expects:
        UTF-8 JSON bytes terminated by b"\\0".
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def _is_chat(self, prompt: str) -> bool:
        '''
        Check whether `prompt` was assembled from chat messages, i.e. it contains
        "<sep><user_role>:".
        TODO: may misclassify; passing the raw messages through from fastchat would be better.
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        Split a prompt string back into chat messages.

        The prompt is split on self.conv.sep and the first and last segments are
        dropped (presumably the system text and the trailing empty assistant turn —
        confirm against the conv template). Empty turns are skipped.
        Raises RuntimeError for a segment that starts with neither role.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result

    @classmethod
    def can_embedding(cls):
        '''True when the worker class declares a DEFAULT_EMBED_MODEL.'''
        return cls.DEFAULT_EMBED_MODEL is not None