liunux4odoo 5d422ca9a1 修改模型配置方式,所有模型以 openai 兼容框架的形式接入,chatchat 自身不再加载模型。
改变 Embeddings 模型改为使用框架 API,不再手动加载,删除自定义 Embeddings Keyword 代码
修改依赖文件,移除 torch transformers 等重依赖
暂时移出对 loom 的集成

后续:
1、优化目录结构
2、检查合并中有无被覆盖的 0.2.10 内容
2024-03-06 13:49:38 +08:00

69 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, Optional
import asyncio
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     stream: bool = Body(False, description="流式输出"),
                     echo: bool = Body(False, description="除了输出之外,还回显输入"),
                     model_name: str = Body(None, description="LLM 模型名称。"),
                     temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
                     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
                     prompt_name: str = Body("default",
                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
    """Run a plain text completion through an LLM chain and return it as SSE.

    Args:
        query: User input; passed to the chain as the ``input`` variable.
        stream: When true, every generated token is emitted as its own
            event; otherwise all tokens are concatenated and emitted once.
        echo: Forwarded unchanged to ``get_OpenAI``.
        model_name: Forwarded unchanged to ``get_OpenAI``.
        temperature: Forwarded unchanged to ``get_OpenAI``.
        max_tokens: Token limit forwarded to ``get_OpenAI``. An integer
            value ``<= 0`` is replaced by ``None`` before forwarding.
        prompt_name: Name of the template looked up with
            ``get_prompt_template("completion", prompt_name)``.

    Returns:
        An ``EventSourceResponse`` wrapping the async token iterator.
    """

    # TODO: ApiModelWorker handles requests as chat by default and parses
    # params["prompt"] into messages, so corresponding handling is needed
    # when ApiModelWorker is used.
    async def completion_iterator(query: str,
                                  model_name: Optional[str] = None,
                                  prompt_name: str = prompt_name,
                                  echo: bool = echo,
                                  ) -> AsyncIterable[str]:
        """Yield the completion, token by token or as a single string."""
        callback = AsyncIteratorCallbackHandler()

        # Normalise a non-positive limit to None in a local value instead of
        # rebinding the enclosing request parameter.
        token_limit = max_tokens
        if isinstance(token_limit, int) and token_limit <= 0:
            token_limit = None

        model = get_OpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=token_limit,
            callbacks=[callback],
            echo=echo,
        )
        prompt_template = get_prompt_template("completion", prompt_name)
        prompt = PromptTemplate.from_template(prompt_template)
        chain = LLMChain(prompt=prompt, llm=model)

        # Run the chain in a background task; tokens are consumed below
        # through the callback's async iterator.
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}),
            callback.done),
        )

        try:
            if stream:
                async for token in callback.aiter():
                    # Use server-sent-events to stream the response
                    yield token
            else:
                pieces = []
                async for token in callback.aiter():
                    pieces.append(token)
                yield "".join(pieces)
            # Surface any exception raised by the background task.
            await task
        finally:
            # If the consumer closed this generator early, `await task` above
            # was never reached; cancel the task so the chain call does not
            # keep running with nobody reading its output.
            if not task.done():
                task.cancel()

    return EventSourceResponse(completion_iterator(query=query,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )