zR 46225ad784
Dev (#1811)
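* Beijing hackathon update

Knowledge base support:
Support for the Zilliz database
Agent support:
The following tool calls are supported:
1. Internet Agent calls
2. Knowledge base Agent calls
3. Travel assistant tool (not uploaded)

Knowledge base updates
1. Knowledge base descriptions are supported and are used by the Agent to choose a knowledge base
2. The UI now shows the knowledge base descriptions

Prompt selection
1. The UI and templates support switching between prompt templates

* Fixed the issue with updating descriptions in the database

* Regarding models supported natively by Langchain

1. Fixed a bug where OpenAI could not be called
2. Added support for Azure OpenAI and Claude models
(Because of a priority issue, the model switching screen shows other online models instead)
3. Fixed the 422 issue by using an alternative approach.
4. Updated some dependencies

* Replaced some images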
2023-10-20 20:07:59 +08:00


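from configs.basic_config import LOG_PATH
import fastchat.constants
# Point fastchat's log directory at the project log path before importing the
# worker module, presumably so that modules imported afterwards pick it up.
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import asyncio
from typing import Dict, List


# Restore the standard output streams overridden by fastchat
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__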


class ApiModelOutMsg(BaseModel):
    error_code: int = 0
    text: str


class ApiModelWorker(BaseModelWorker):
    BASE_URL: str
    SUPPORT_MODELS: List

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str,
        worker_addr: str,
        context_len: int = 2048,
        **kwargs,
    ):
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                         controller_addr=controller_addr,
                         worker_addr=worker_addr,
                         **kwargs)
        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.init_heart_beat()

    def count_token(self, params):
        # TODO: needs improvement; this counts characters, not tokens
        # print("count token")
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params):
        # Only bumps the call counter here; subclasses are expected to
        # override this method and yield encoded response chunks.
        self.call_ct += 1

    def generate_gate(self, params):
        # Drain the stream and keep only the final chunk.
        for x in self.generate_stream_gate(params):
            pass
        # Drop the last byte of the chunk before decoding the JSON payload.
        return json.loads(x[:-1].decode())

    def get_embeddings(self, params):
        print("embedding")
        # print(params)

    # helper methods
    def get_config(self):
        from server.utils import get_model_worker_config
        return get_model_worker_config(self.model_names[0])

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        Split the prompt string into messages.
        '''
        # self.conv is not defined in this file; subclasses are expected to set it.
        result = []
        user_role = self.conv.roles[0]
        ai_role = self.conv.roles[1]
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        # [1:-1] skips the text before the first separator and the final segment.
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result
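
The splitting logic in `prompt_to_messages` can be tried out without a running worker. The following is a minimal standalone sketch, not part of the file above: the role names `user` and `assistant`, the separator `"\n### "`, and the sample prompt are all assumptions made for illustration, whereas in the class they come from `self.conv`.

```python
from typing import Dict, List, Tuple


def prompt_to_messages(
    prompt: str,
    roles: Tuple[str, str] = ("user", "assistant"),  # assumed role names
    sep: str = "\n### ",                             # assumed separator
) -> List[Dict]:
    """Standalone version of the splitting logic, for illustration only."""
    user_role, ai_role = roles
    user_start = user_role + ":"
    ai_start = ai_role + ":"
    result = []
    # [1:-1] skips the text before the first separator (here a system
    # preamble) and the final segment (here the open assistant turn).
    for msg in prompt.split(sep)[1:-1]:
        if msg.startswith(user_start):
            if content := msg[len(user_start):].strip():
                result.append({"role": user_role, "content": content})
        elif msg.startswith(ai_start):
            if content := msg[len(ai_start):].strip():
                result.append({"role": ai_role, "content": content})
        else:
            raise RuntimeError(f"unknown role in msg: {msg}")
    return result


# Hypothetical flattened conversation in the assumed template format.
prompt = (
    "A chat between a user and an assistant."
    "\n### user: Hello"
    "\n### assistant: Hi, how can I help?"
    "\n### user: What is FastChat?"
    "\n### assistant:"
)
messages = prompt_to_messages(prompt)
for m in messages:
    print(m)
```

Note that the `[1:-1]` slice assumes the prompt starts with a preamble and ends with an open turn. A prompt without either would silently lose its first or last message.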