zR 2c8fc95f7a
Agent大更新合并 (#1666)
* 更新上agent提示词代码

* 更新部分文档,修复了issue中提到的bge匹配得分超过1的bug

* 按需修改

* 解决了部分用户使用最新依赖时出现的bug,加了两个工具,移除google工具

* Agent大幅度优化

1. 修改了UI界面
(1)高亮所有没有进行agent对齐的模型,
(2)优化输出体验和逻辑,使用markdown

2. 降低天气工具使用门槛
3. 依赖更新
(1) vllm 更新到0.2.0,增加了一些参数
(2) torch 建议更新到2.1
(3)pydantic不要更新到1.10.12
2023-10-07 11:26:11 +08:00

98 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from configs.basic_config import LOG_PATH
import fastchat.constants
# Point fastchat's log directory at the project's LOG_PATH. This assignment is
# deliberately placed before the fastchat.serve.model_worker import below,
# presumably so fastchat uses it when that module sets up its logging — keep
# this ordering when reorganising imports.
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
from typing import Dict, List
# Restore the standard output/error streams that fastchat overrides.
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel):
    """Pydantic schema with an integer ``error_code`` and a required ``text`` string."""

    # Defaults to 0; presumably 0 means "no error" — confirm with the callers.
    error_code: int = 0
    text: str
class ApiModelWorker(BaseModelWorker):
    """fastchat model worker base class for models reached through an API.

    The base ``generate_stream_gate`` only counts calls, while ``generate_gate``
    iterates over it, so subclasses presumably override ``generate_stream_gate``
    (calling ``super()`` to keep ``call_ct`` up to date) to yield encoded
    output chunks.
    """

    # Declared (annotation only) here; presumably assigned by subclasses.
    BASE_URL: str
    SUPPORT_MODELS: List

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str,
        worker_addr: str,
        context_len: int = 2048,
        **kwargs,
    ):
        """Initialise the worker, register it with the controller and start the heartbeat.

        Args:
            model_names: model names served by this worker; ``model_names[0]``
                is the key used by ``get_config``.
            controller_addr: address of the fastchat controller.
            worker_addr: address this worker is reachable at.
            context_len: context length stored on the worker (default 2048).
            **kwargs: forwarded to ``BaseModelWorker.__init__``. ``worker_id``,
                ``model_path`` and ``limit_worker_concurrency`` default to a
                random 8-hex-character id, ``""`` and ``5`` respectively.
        """
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                         controller_addr=controller_addr,
                         worker_addr=worker_addr,
                         **kwargs)
        self.context_len = context_len
        self.init_heart_beat()

    def count_token(self, params):
        """Return a rough size estimate for ``params["prompt"]``.

        TODO: this is the character length of ``str(prompt)``, not a real
        token count; replace it with a tokenizer-based count.

        Returns:
            ``{"count": len(str(prompt)), "error_code": 0}``.
        Raises:
            KeyError: if ``params`` has no ``"prompt"`` key.
        """
        # No debug printing here: this runs for every request and would dump
        # the full user prompt to stdout.
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params):
        """Increment ``call_ct``; this base implementation yields nothing and returns None."""
        self.call_ct += 1

    def generate_gate(self, params):
        """Drain ``generate_stream_gate`` and return its last chunk parsed as JSON.

        Each chunk is expected to be ``bytes`` holding a JSON document followed
        by a single delimiter byte (presumably a NUL byte, as in fastchat's
        streaming protocol — confirm); that last byte is dropped before parsing.

        Raises:
            RuntimeError: if the stream produces no chunks at all.
        """
        last_chunk = None
        # Only the final chunk is used; earlier ones are discarded.
        for last_chunk in self.generate_stream_gate(params):
            pass
        if last_chunk is None:
            raise RuntimeError(
                f"{type(self).__name__}.generate_stream_gate produced no output"
            )
        return json.loads(last_chunk[:-1].decode())

    def get_embeddings(self, params):
        """Stub: prints a marker and the request ``params``, then returns None."""
        # NOTE(review): embeddings are not implemented for API workers here.
        print("embedding")
        print(params)

    def init_heart_beat(self):
        """Register with the controller and start the heartbeat thread as a daemon.

        Workaround so the program can exit with Ctrl+C (daemon threads do not
        keep the interpreter alive); it should be deleted once the
        corresponding PR is merged into fastchat.
        """
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
        )
        self.heart_beat_thread.start()

    # helper methods

    def get_config(self):
        """Return the worker config for this worker's first model name."""
        # Imported locally, presumably to avoid a circular import with server.utils.
        from server.utils import get_model_worker_config
        return get_model_worker_config(self.model_names[0])

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        """Split a flattened conversation prompt back into chat messages.

        The prompt is split on ``self.conv.sep`` and the first and last
        segments are discarded (presumably the system prompt and the trailing
        open assistant turn — confirm against the conversation template).
        Every remaining segment must start with ``"<role>:"`` for one of the
        first two entries of ``self.conv.roles``; segments whose content is
        empty after stripping are skipped.

        Returns:
            A list of ``{"role": role, "content": content}`` dicts.
        Raises:
            RuntimeError: if a segment starts with neither role prefix.
        """
        result = []
        user_role = self.conv.roles[0]
        ai_role = self.conv.roles[1]
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result