Langchain-Chatchat/server/chat/openai_chat.py
liunux4odoo f7c73b842a
优化configs (#1474)
* remove llm_model_dict

* optimize configs

* fix get_model_path

* 更改一些默认参数,添加千帆的默认配置

* Update server_config.py.example
2023-09-15 17:52:22 +08:00

59 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs import LLM_MODEL, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from pydantic import BaseModel
class OpenAiMessage(BaseModel):
    """A single chat message in the OpenAI chat-completion format."""
    role: str = "user"  # speaker role, e.g. "user" / "assistant" / "system"
    content: str = "hello"  # message text
class OpenAiChatMsgIn(BaseModel):
    """Request body for the OpenAI-compatible chat endpoint.

    Mirrors the parameters of OpenAI's ChatCompletion API; the whole
    model is forwarded via ``msg.dict()`` in ``openai_chat``.
    """
    model: str = LLM_MODEL  # model worker name; defaults to the configured LLM
    messages: List[OpenAiMessage]  # conversation history, oldest first
    temperature: float = 0.7  # sampling temperature
    n: int = 1  # number of completions to generate
    max_tokens: int = 1024  # completion length cap
    stop: List[str] = []  # stop sequences (pydantic copies this default per instance)
    stream: bool = False  # if True, tokens are streamed incrementally
    presence_penalty: int = 0
    frequency_penalty: int = 0
async def openai_chat(msg: OpenAiChatMsgIn):
    """Proxy a chat request to the fschat OpenAI-compatible API.

    Configures the ``openai`` client from the model worker's config,
    forwards ``msg`` to ``ChatCompletion.acreate`` and returns the
    answer as a server-sent-event stream. When ``msg.stream`` is True,
    delta chunks are yielded as they arrive; otherwise the full answer
    is yielded once.

    Args:
        msg: request parameters in OpenAI ChatCompletion format.

    Returns:
        StreamingResponse with media type ``text/event-stream``.
    """
    config = get_model_worker_config(msg.model)
    openai.api_key = config.get("api_key", "EMPTY")
    print(f"{openai.api_key=}")
    openai.api_base = fschat_openai_api_address()
    print(f"{openai.api_base=}")
    print(msg)

    async def get_response(msg):
        data = msg.dict()
        try:
            response = await openai.ChatCompletion.acreate(**data)
            if msg.stream:
                # Streaming mode: forward each non-empty delta chunk.
                async for data in response:
                    if choices := data.choices:
                        if chunk := choices[0].get("delta", {}).get("content"):
                            print(chunk, end="", flush=True)
                            yield chunk
            else:
                # Non-streaming mode: yield the complete answer once.
                if response.choices:
                    answer = response.choices[0].message.content
                    print(answer)
                    yield answer
        except Exception as e:
            # Use a distinct name: the original rebound the outer `msg`
            # parameter here, shadowing the request object.
            err_msg = f"获取ChatCompletion时出错{e}"
            logger.error(f'{e.__class__.__name__}: {err_msg}',
                         exc_info=e if log_verbose else None)
            # Surface the failure to the client instead of silently
            # ending the stream with no output.
            yield err_msg

    return StreamingResponse(
        get_response(msg),
        media_type='text/event-stream',
    )