Langchain-Chatchat/server/chat/openai_chat.py
liunux4odoo 775870a516
改变api视图函数的sync/async,提高api并发能力: (#1414)
1. 4个chat类接口改为async
2. 知识库操作,涉及向量库修改的使用async,避免FAISS写入错误;涉及向量库读取的改为sync,提高并发
2023-09-08 12:25:02 +08:00

55 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL, logger
from pydantic import BaseModel
class OpenAiMessage(BaseModel):
    """One chat message (a role plus its text) as passed to ChatCompletion."""

    # Who authored the message; defaults to "user".
    role: str = "user"
    # The message text; defaults to the placeholder "hello".
    content: str = "hello"
class OpenAiChatMsgIn(BaseModel):
    """Request body for the OpenAI-compatible chat endpoint.

    Field names mirror the keyword arguments of ``openai.ChatCompletion``:
    ``openai_chat`` forwards ``msg.dict()`` to ``acreate`` unchanged, so every
    field here is sent to the backend as-is.
    """

    # Defaults to the server's configured model name.
    model: str = LLM_MODEL
    messages: List[OpenAiMessage]
    temperature: float = 0.7
    n: int = 1
    max_tokens: int = 1024
    stop: List[str] = []
    # When true, the reply is streamed back as incremental delta chunks.
    stream: bool = False
    # The OpenAI API defines both penalties as floats in [-2.0, 2.0]. Declaring
    # them as ``int`` made pydantic truncate (v1) or reject (v2) fractional
    # values such as 0.5. ``float`` still accepts integer input.
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
async def openai_chat(msg: OpenAiChatMsgIn):
    """Proxy a chat request to the OpenAI-compatible backend of ``LLM_MODEL``.

    The API key and base URL are read from ``llm_model_dict[LLM_MODEL]`` and
    assigned to the module-level ``openai`` settings. The request fields are
    then forwarded unchanged to ``openai.ChatCompletion.acreate``.

    Args:
        msg: The request body; ``msg.dict()`` becomes the completion kwargs.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``. When
        ``msg.stream`` is true it yields each delta's text as it arrives;
        otherwise it yields the full answer as one chunk. Backend errors are
        logged and end the stream early rather than being raised.
    """
    model_config = llm_model_dict[LLM_MODEL]
    # Set the key but never echo it: stdout typically ends up in server logs,
    # and printing the secret on every request leaks it there.
    openai.api_key = model_config["api_key"]
    openai.api_base = model_config["api_base_url"]
    print(f"{openai.api_base=}")
    print(msg)

    async def get_response(msg):
        # Yield reply text chunks for StreamingResponse to send to the client.
        params = msg.dict()
        try:
            response = await openai.ChatCompletion.acreate(**params)
            if msg.stream:
                async for chunk_data in response:
                    if choices := chunk_data.choices:
                        if chunk := choices[0].get("delta", {}).get("content"):
                            print(chunk, end="", flush=True)
                            yield chunk
            elif response.choices:
                answer = response.choices[0].message.content
                print(answer)
                # The message content may be None. StreamingResponse can only
                # send str/bytes, so skip empty answers instead of crashing
                # mid-response.
                if answer:
                    yield answer
        except Exception as e:
            # Keep the traceback in the log; the client just sees the stream end.
            logger.error(f"获取ChatCompletion时出错{e}", exc_info=True)

    return StreamingResponse(
        get_response(msg),
        media_type='text/event-stream',
    )