liunux4odoo d053950aee
新功能: (#1801)
- 更新langchain/fastchat依赖,添加xformers依赖
- 默认max_tokens=None, 生成tokens自动为模型支持的最大值

修复:
- history_len=0 时会带入1条不完整的历史消息,导致LLM错误
- 当对话轮数 达到history_len时,传入的历史消息为空
2023-10-19 22:09:15 +08:00

89 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from configs.basic_config import LOG_PATH
import fastchat.constants
# Redirect fastchat's log directory to the project's LOG_PATH.
# NOTE(review): this assignment sits before the base_model_worker import,
# presumably because that module reads LOGDIR at import time -- confirm
# before reordering these lines.
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
from typing import Dict, List
# Restore the standard output streams that were overridden by fastchat.
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel):
    """Pydantic model holding an error code and a text payload.

    NOTE(review): judging by the name this describes the output message of
    an API model worker; it is not referenced elsewhere in this file, so
    confirm against its callers.
    """

    error_code: int = 0  # optional; defaults to 0
    text: str  # required
class ApiModelWorker(BaseModelWorker):
    """Model worker base class built on fastchat's ``BaseModelWorker``.

    The constructor fills in default worker settings, delegates to the
    base initialiser and starts the heart beat.  The class also provides a
    few helper methods (``get_config``, ``prompt_to_messages``).

    NOTE(review): ``prompt_to_messages`` reads ``self.conv`` (its ``roles``
    and ``sep``), which is not assigned anywhere in this class --
    presumably the base class or a subclass provides it; confirm.
    """

    BASE_URL: str
    SUPPORT_MODELS: List

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str,
        worker_addr: str,
        context_len: int = 2048,
        **kwargs,
    ):
        """Initialise the worker and start its heart beat.

        Args:
            model_names: Passed through to ``BaseModelWorker``.
            controller_addr: Passed through to ``BaseModelWorker``.
            worker_addr: Passed through to ``BaseModelWorker``.
            context_len: Stored on the instance as ``context_len``.
            **kwargs: Extra keyword arguments for ``BaseModelWorker``.
                ``worker_id``, ``model_path`` and
                ``limit_worker_concurrency`` get defaults when absent.
        """
        # Only fill in values the caller did not supply explicitly.
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(
            model_names=model_names,
            controller_addr=controller_addr,
            worker_addr=worker_addr,
            **kwargs,
        )
        # Assigned after the base initialiser has run.
        self.context_len = context_len
        self.init_heart_beat()

    def count_token(self, params):
        """Return an approximate token count for ``params["prompt"]``.

        The count is the length of the prompt's string form, not the
        result of a real tokenizer.

        Returns:
            A dict with ``count`` and ``error_code`` (always 0).
        """
        # TODO: needs to be improved
        # NOTE(review): the two prints below look like leftover debug
        # output; they are kept so the method's behaviour is unchanged.
        print("\n\n\n")
        print(params)
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params):
        """Increment the call counter ``call_ct``; yields nothing here.

        ``generate_gate`` iterates over the result of this method, so
        subclasses are presumably expected to override it with a generator
        of encoded chunks -- confirm against the concrete workers.
        """
        self.call_ct += 1

    def generate_gate(self, params):
        """Drain ``generate_stream_gate`` and return its last chunk decoded.

        The final chunk has its last element stripped, is decoded to text
        and parsed as JSON.
        NOTE(review): the stripped trailing byte is presumably a stream
        delimiter -- confirm against the streaming implementation.

        Returns:
            The JSON-decoded content of the last streamed chunk.

        Raises:
            RuntimeError: if ``generate_stream_gate`` returns ``None`` or
                yields no chunk at all.
        """
        stream = self.generate_stream_gate(params)
        last_chunk = None
        if stream is not None:
            for last_chunk in stream:
                pass
        if last_chunk is None:
            # Previously this surfaced as an opaque TypeError (no stream)
            # or UnboundLocalError (empty stream).
            raise RuntimeError("generate_stream_gate produced no output")
        return json.loads(last_chunk[:-1].decode())

    def get_embeddings(self, params):
        """Placeholder: prints the request and returns ``None``."""
        print("embedding")
        print(params)

    # helper methods
    def get_config(self):
        """Return the worker configuration for the first model name."""
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with server.utils -- confirm.
        from server.utils import get_model_worker_config

        return get_model_worker_config(self.model_names[0])

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        """Split a prompt string into a list of role/content messages.

        The prompt is split on ``self.conv.sep``; the first and last pieces
        are dropped (presumably the system text and the trailing open
        assistant turn -- confirm).  Each remaining piece must start with
        ``"<role>:"`` for one of the two roles in ``self.conv.roles``.
        Pieces whose content is empty after stripping are skipped.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts in prompt
            order.

        Raises:
            RuntimeError: if a piece starts with neither role prefix.
        """
        result = []
        user_role = self.conv.roles[0]
        ai_role = self.conv.roles[1]
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result