Langchain-Chatchat/server/chat/openai_chat.py
liunux4odoo d053950aee
新功能: (#1801)
- 更新langchain/fastchat依赖,添加xformers依赖
- 默认max_tokens=None, 生成tokens自动为模型支持的最大值

修复:
- history_len=0 时会带入1条不完整的历史消息,导致LLM错误
- 当对话轮数 达到history_len时,传入的历史消息为空
2023-10-19 22:09:15 +08:00

59 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List, Optional

import openai
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from configs import LLM_MODEL, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
class OpenAiMessage(BaseModel):
    """A single chat message in OpenAI API format."""
    # Role of the speaker, e.g. "user" / "assistant" / "system".
    role: str = "user"
    # Text content of the message.
    content: str = "hello"
class OpenAiChatMsgIn(BaseModel):
    """Request body for the OpenAI-compatible chat endpoint.

    Mirrors the parameters of the OpenAI ChatCompletion API; the whole
    model is passed through via ``msg.dict()`` when calling the backend.
    """
    # Name of the model worker to route the request to.
    model: str = LLM_MODEL
    # Conversation history, oldest first.
    messages: List[OpenAiMessage]
    # Sampling temperature.
    temperature: float = 0.7
    # Number of completions to generate.
    n: int = 1
    # None lets the backend use the model's maximum supported tokens.
    # (was annotated ``int = None`` — invalid, since None is not an int)
    max_tokens: Optional[int] = None
    # Stop sequences; pydantic copies mutable defaults per instance.
    stop: List[str] = []
    # Whether to stream the response chunk by chunk.
    stream: bool = False
    presence_penalty: int = 0
    frequency_penalty: int = 0
async def openai_chat(msg: OpenAiChatMsgIn):
    """Relay an OpenAI-style chat request to the fschat openai-compatible API.

    Configures the module-level ``openai`` client from the model worker
    config, then returns a ``text/event-stream`` StreamingResponse that
    yields completion text (chunked when ``msg.stream`` is true, a single
    answer otherwise).
    """
    config = get_model_worker_config(msg.model)
    # fschat's openai-compatible server accepts any key; "EMPTY" is the
    # conventional placeholder.
    openai.api_key = config.get("api_key", "EMPTY")
    print(f"{openai.api_key=}")
    openai.api_base = config.get("api_base_url", fschat_openai_api_address())
    print(f"{openai.api_base=}")
    print(msg)

    async def get_response(msg: OpenAiChatMsgIn):
        # msg.dict() forwards model/messages/temperature/... unchanged to
        # the ChatCompletion endpoint.  (was named ``data``, which was then
        # shadowed by the async-for loop variable below)
        params = msg.dict()
        try:
            response = await openai.ChatCompletion.acreate(**params)
            if msg.stream:
                # Streaming mode: yield each non-empty delta chunk as it arrives.
                async for piece in response:
                    if choices := piece.choices:
                        if chunk := choices[0].get("delta", {}).get("content"):
                            print(chunk, end="", flush=True)
                            yield chunk
            else:
                # Non-streaming mode: yield the single complete answer.
                if response.choices:
                    answer = response.choices[0].message.content
                    print(answer)
                    yield answer
        except Exception as e:
            # NOTE(review): the error is only logged; the client's stream just
            # ends with no content.  Consider yielding an error payload too.
            # (was assigned to ``msg``, clobbering the request model)
            err_msg = f"获取ChatCompletion时出错{e}"
            logger.error(f'{e.__class__.__name__}: {err_msg}',
                         exc_info=e if log_verbose else None)

    return StreamingResponse(
        get_response(msg),
        media_type='text/event-stream',
    )