liunux4odoo deed92169f
支持在线 Embeddings:zhipu-api, qwen-api, minimax-api, qianfan-api (#1907)
* 新功能:
- 支持在线 Embeddings:zhipu-api, qwen-api, minimax-api, qianfan-api
- API 增加 /other/embed_texts 接口
- init_database.py 增加 --embed-model 参数,可以指定使用的嵌入模型(本地或在线均可)

问题修复:
- API 中 list_config_models 会删除 ONLINE_LLM_MODEL 中的敏感信息,导致第二轮API请求错误

开发者:
- 优化 kb_service 中 Embeddings 操作:
  - 统一加载接口: server.utils.load_embeddings,利用全局缓存避免各处 Embeddings 传参
  - 统一文本嵌入接口:server.embedding_api.[embed_texts, embed_documents]
2023-10-28 23:37:30 +08:00

102 lines
3.3 KiB
Python

import json
import sys
from fastchat.conversation import Conversation
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
class QwenWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text-embedding-v1"
def __init__(
self,
*,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
gen = dashscope.Generation()
responses = gen.call(
model=params.version,
temperature=params.temperature,
api_key=params.api_key,
messages=params.messages,
result_format='message', # set the result is message format.
stream=True,
)
for resp in responses:
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"error_code": 0,
"text": choices[0]["message"]["content"],
}
else:
yield {
"error_code": resp["status_code"],
"text": resp["message"],
}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=params.texts, # 最大25行
api_key=params.api_key,
)
if resp["status_code"] != 200:
return {"code": resp["status_code"], "msg": resp.message}
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)