liunux4odoo f7c73b842a
优化configs (#1474)
* remove llm_model_dict

* optimize configs

* fix get_model_path

* 更改一些默认参数,添加千帆的默认配置

* Update server_config.py.example
2023-09-15 17:52:22 +08:00

82 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from langchain.chat_models import ChatOpenAI
from typing import Awaitable, List, Tuple, Dict, Union, Callable
def get_ChatOpenAI(
    model_name: str,
    temperature: float,
    callbacks: Union[List[Callable], None] = None,
) -> ChatOpenAI:
    """Build a streaming ``ChatOpenAI`` client for ``model_name``.

    The client talks to the base URL returned by
    ``fschat_openai_api_address()``. The API key and proxy are read from the
    worker config returned by ``get_model_worker_config(model_name)``
    (``"api_key"`` defaults to ``"EMPTY"``; ``"openai_proxy"`` defaults to
    ``None``).

    :param model_name: model name sent to the API and used to look up the
        worker config.
    :param temperature: sampling temperature passed to ``ChatOpenAI``.
    :param callbacks: callback handlers attached to the client. ``None``
        (the default) means no callbacks.
    :return: a ``ChatOpenAI`` instance with ``streaming=True`` and
        ``verbose=True``.
    """
    if callbacks is None:
        # Build a fresh list on every call. A ``[]`` default is evaluated only
        # once, so every client created without explicit callbacks would
        # share (and could mutate) the same list object.
        callbacks = []
    config = get_model_worker_config(model_name)
    return ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=fschat_openai_api_address(),
        model_name=model_name,
        temperature=temperature,
        openai_proxy=config.get("openai_proxy"),
    )
async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Await ``fn`` and set ``event`` once it has finished, however it ends.

    Any ``Exception`` raised by ``fn`` is logged rather than propagated. The
    traceback is attached only when ``log_verbose`` is truthy. Exceptions
    that do not derive from ``Exception`` still propagate. ``event`` is set
    in every case.
    """
    try:
        await fn
    except Exception as exc:
        # TODO: handle exception
        logger.error(
            f"{exc.__class__.__name__}: Caught exception: {exc}",
            exc_info=exc if log_verbose else None,
        )
    finally:
        # Signal the aiter to stop.
        event.set()
class History(BaseModel):
    """
    A single chat-history message.

    Can be built from a dict:
        h = History(**{"role": "user", "content": "hello"})
    and converted to a tuple:
        h.to_msg_tuple() == ("human", "hello")
    """
    # Message author; "assistant"/"ai" and "user"/"human" are the aliases
    # handled specially by the conversion methods below.
    role: str = Field(...)
    # Message text.
    content: str = Field(...)

    def to_msg_tuple(self):
        """Return ``(role, content)`` with the role normalised to "ai" or "human".

        "assistant" and "ai" map to "ai". These are the same two roles
        ``to_msg_template`` treats as the assistant. Every other role maps
        to "human".
        """
        return "ai" if self.role in ("assistant", "ai") else "human", self.content

    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
        """Convert this message into a jinja2 ``ChatMessagePromptTemplate``.

        The aliases "ai" and "human" are renamed to "assistant" and "user".
        Any other role is passed through unchanged.

        :param is_raw: when true (the default), the content is treated as
            literal text by wrapping it in a jinja2 ``raw`` block. By default,
            history messages are plain text with no input variables. When
            false, the content is used as a jinja2 template as-is.
        """
        import re

        role_maps = {
            "ai": "assistant",
            "human": "user",
        }
        role = role_maps.get(self.role, self.role)
        if is_raw:
            # The content may come from users. A literal "{% endraw %}" (or a
            # whitespace-control variant such as "{%- endraw %}") would close
            # the raw block early and let the remainder be evaluated as
            # jinja2. Before each "{%" that starts such a tag, close the raw
            # block, emit "{%" via a string expression, then reopen the raw
            # block. The rendered text is unchanged.
            escaped = re.sub(
                r"\{%(?=[-+]?\s*endraw)",
                "{% endraw %}{{ '{%' }}{% raw %}",
                self.content,
            )
            content = "{% raw %}" + escaped + "{% endraw %}"
        else:
            content = self.content
        return ChatMessagePromptTemplate.from_template(
            content,
            "jinja2",
            role=role,
        )

    @classmethod
    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
        """Build a ``History`` from a ``(role, content)`` sequence or a dict.

        A list or tuple with at least two items uses the first two items as
        role and content; extra items are ignored. A dict is passed as
        keyword arguments. Anything else is returned unchanged, including an
        existing ``History`` or a sequence with fewer than two items.
        """
        if isinstance(h, (list, tuple)) and len(h) >= 2:
            h = cls(role=h[0], content=h[1])
        elif isinstance(h, dict):
            h = cls(**h)
        return h