hzg0601 2023-10-20 19:22:36 +08:00
commit 6e9acfc1af
26 changed files with 321 additions and 192 deletions

View File

@ -1,5 +1,7 @@
import os
# Default knowledge base to use
DEFAULT_KNOWLEDGE_BASE = "samples"
# Default vector store type. Options: faiss, milvus (offline) & zilliz (online), pg.
DEFAULT_VS_TYPE = "faiss"
@ -19,6 +21,9 @@ VECTOR_SEARCH_TOP_K = 3
# Relevance score threshold for knowledge base matching, in the range 0-1. The smaller the SCORE, the higher the relevance; 1 means no filtering. A value around 0.5 is recommended.
SCORE_THRESHOLD = 1
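A minimal sketch of how these two settings are typically consumed (hypothetical helper, not part of this commit; `vector_store` stands for any LangChain vector store such as FAISS, and the constants are the ones defined above):

def filter_kb_hits(vector_store, query: str,
                   top_k: int = VECTOR_SEARCH_TOP_K,
                   threshold: float = SCORE_THRESHOLD):
    # smaller score = more relevant; a threshold of 1 keeps every hit
    docs_and_scores = vector_store.similarity_search_with_score(query, k=top_k)
    return [doc for doc, score in docs_and_scores if score <= threshold]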
# Default search engine. Options: bing, duckduckgo, metaphor
DEFAULT_SEARCH_ENGINE = "duckduckgo"
# Number of search engine results to match
SEARCH_ENGINE_TOP_K = 3
@ -36,6 +41,10 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# it is because the server is behind a firewall; contact the administrator to be whitelisted (if it is a corporate server, don't even bother, GG)
BING_SUBSCRIPTION_KEY = ""
# metaphor search requires a KEY
METAPHOR_API_KEY = ""
# Whether to enable Chinese title enhancement, plus its related settings.
# A title check identifies which text chunks are titles and marks them in the metadata;
# each chunk is then concatenated with its parent title to enrich the text.
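The idea in a rough sketch (hypothetical helper, not the project's actual implementation; assumes the title detector flags title chunks in metadata):

from langchain.docstore.document import Document

def join_titles(docs):
    last_title = ""
    out = []
    for doc in docs:
        if doc.metadata.get("category") == "Title":  # assumed flag set by the title detector
            last_title = doc.page_content
            out.append(doc)
        else:
            text = f"{last_title} {doc.page_content}" if last_title else doc.page_content
            out.append(Document(page_content=text, metadata=doc.metadata))
    return out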
@ -47,7 +56,6 @@ KB_INFO = {
"知识库名称": "知识库介绍",
"samples": "关于本项目issue的解答",
}
# The following normally does not need to be changed
# Default storage path for knowledge bases
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")

View File

@ -44,7 +44,7 @@ MODEL_PATH = {
"chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
@ -112,7 +112,8 @@ TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
ONLINE_LLM_MODEL = {
LANGCHAIN_LLM_MODEL = {
# Models supported natively by Langchain that do not need the Fschat wrapper.
# When calling chatgpt, if urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# is raised, downgrade urllib3 to version 1.25.11
@ -128,11 +129,29 @@ ONLINE_LLM_MODEL = {
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
# A proxy is required (a locally running proxy tool may not intercept these requests); set the openai_proxy config option or the OPENAI_PROXY environment variable.
# e.g.: "openai_proxy": 'http://127.0.0.1:4780'
"gpt-3.5-turbo": {
# The names of these entries must not be changed
"Azure-OpenAI": {
"deployment_name": "your Azure deployment name",
"model_version": "0701",
"openai_api_type": "azure",
"api_base_url": "https://your Azure point.azure.com",
"api_version": "2023-07-01-preview",
"api_key": "your Azure api key",
"openai_proxy": "",
},
"OpenAI": {
"model_name": "your openai model name(such as gpt-4)",
"api_base_url": "https://api.openai.com/v1",
"api_key": "your OPENAI_API_KEY",
"openai_proxy": "your OPENAI_PROXY",
"openai_proxy": "",
},
"Anthropic": {
"model_name": "your claude model name(such as claude2-100k)",
"api_key":"your ANTHROPIC_API_KEY",
}
}
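These entries are consumed by get_ChatOpenAI in server/utils.py (see that file's diff below); a hedged usage sketch:

from server.utils import get_ChatOpenAI

# "Azure-OpenAI" / "OpenAI" / "Anthropic" select the native LangChain chat models configured above;
# any other name falls back to the fschat-served OpenAI-compatible endpoint.
llm = get_ChatOpenAI(model_name="OpenAI", temperature=0.7, max_tokens=None)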
ONLINE_LLM_MODEL = {
# Online models. In server_config, set a different port for each online API.
# For registration and API key instructions, visit http://open.bigmodel.cn
"zhipu-api": {

View File

@ -1,3 +1,5 @@
import sys
sys.path.append(".")
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from configs.model_config import NLTK_DATA_PATH
import nltk

View File

@ -1,14 +1,15 @@
langchain>=0.0.314
langchain>=0.0.319
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.30
openai
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4
openai>=0.28.1
sentence_transformers
transformers>=4.34
torch>=2.0.1 # 2.1 recommended
torchvision
torchaudio
fastapi>=0.103.2
nltk~=3.8.1
fastapi>=0.104
nltk>=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
@ -43,7 +44,7 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.9
streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog

View File

@ -1,13 +1,14 @@
langchain==0.0.313
langchain-experimental==0.0.30
fschat[model_worker]==0.2.30
openai
langchain>=0.0.319
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4
openai>=0.28.1
sentence_transformers>=2.2.2
transformers>=4.34
torch>=2.0.1
torch>=2.1
torchvision
torchaudio
fastapi>=0.103.1
fastapi>=0.104
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0

View File

@ -1,11 +1,11 @@
numpy~=1.24.4
pandas~=2.0.3
streamlit>=1.26.0
streamlit>=1.27.2
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.9
streamlit-antd-components>=0.2.3
streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
nltk
httpx>=0.25.0
nltk>=3.8.1
watchdog
websockets

View File

@ -97,8 +97,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any,
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
self.cur_tool.update(
status=Status.start,

View File

@ -4,7 +4,6 @@ from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from server.agent import model_container
begin = False
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
@ -38,7 +37,7 @@ class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
# Check if agent should finish
support_agent = ["gpt","Qwen","qwen-api","baichuan-api"]
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
self.begin = False
stop_words = ["Observation:"]

View File

@ -2,7 +2,6 @@
from .search_knowledge_simple import knowledge_search_simple
from .search_all_knowledge_once import knowledge_search_once
from .search_all_knowledge_more import knowledge_search_more
from .travel_assistant import travel_assistant
from .calculate import calculate
from .translator import translate
from .weather import weathercheck

View File

@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
# TODO: after fastchat is updated, set the default to None so the LLM's maximum is used automatically.
max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):

View File

@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
# TODO: after fastchat is updated, set the default to None so the LLM's maximum is used automatically.
max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):

View File

@ -31,8 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
# TODO: after fastchat is updated, set the default to None so the LLM's maximum is used automatically.
max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

View File

@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
messages: List[OpenAiMessage]
temperature: float = 0.7
n: int = 1
max_tokens: int = 1024
max_tokens: int = None
stop: List[str] = []
stream: bool = False
presence_penalty: int = 0

View File

@ -1,6 +1,7 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
@ -11,7 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
import json
@ -32,8 +33,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
return search.results(text, result_len)
def metaphor_search(
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
splitter_name: str = "SpacyTextSplitter",
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
from metaphor_python import Metaphor
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from server.knowledge_base.utils import make_text_splitter
if not METAPHOR_API_KEY:
return []
client = Metaphor(METAPHOR_API_KEY)
search = client.search(text, num_results=result_len, use_autoprompt=True)
contents = search.get_contents().contents
# metaphor returns long texts, which need to be split and re-retrieved
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = make_text_splitter(splitter_name=splitter_name,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs)
# put the split documents into a temporary vector store and re-select the TOP_K documents
if len(splitted_docs) > result_len:
vs = memo_faiss_pool.new_vector_store()
vs.add_documents(splitted_docs)
splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
return docs
SEARCH_ENGINES = {"bing": bing_search,
"duckduckgo": duckduckgo_search,
"metaphor": metaphor_search,
}
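A quick way to exercise the new metaphor backend (sketch; requires METAPHOR_API_KEY to be set, otherwise the function returns an empty list):

if __name__ == "__main__":
    from pprint import pprint
    for hit in metaphor_search("retrieval augmented generation"):
        # each hit has the same shape as the bing/duckduckgo results: snippet, link, title
        pprint({"title": hit["title"], "link": hit["link"]})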
@ -72,8 +114,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
# TODO: after fastchat is updated, set the default to None so the LLM's maximum is used automatically.
max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():

View File

@ -140,7 +140,7 @@ if __name__ == "__main__":
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
pprint(ids)
elif r == 2: # search docs
docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
pprint(docs)
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")

View File

@ -1,7 +1,5 @@
import os
from transformers import AutoTokenizer
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,

View File

@ -1,5 +1,5 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
@ -16,7 +16,7 @@ def list_running_models(
with get_httpx_client() as client:
r = client.post(controller_address + "/list_models")
models = r.json()["models"]
data = {m: get_model_worker_config(m) for m in models}
data = {m: get_model_config(m).data for m in models}
return BaseResponse(data=data)
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
@ -52,7 +52,6 @@ def get_model_config(
Get the merged configuration of an LLM model.
'''
config = get_model_worker_config(model_name=model_name)
# remove sensitive information from the ONLINE_MODEL config
del_keys = set(["worker_class"])
for k in config:

View File

@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature):
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"max_tokens": None,
"auditing": "default",
"temperature": temperature,
}

View File

@ -1,15 +1,15 @@
# import os
# import sys
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import requests
import json
import time
import hashlib
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal
from typing import List, Literal, Dict
from configs import TEMPERATURE
@ -20,29 +20,29 @@ def calculate_md5(input_string):
return encrypted
def do_request():
url = "https://api.baichuan-ai.com/v1/stream/chat"
api_key = ""
secret_key = ""
def request_baichuan_api(
messages: List[Dict[str, str]],
api_key: str = None,
secret_key: str = None,
version: str = "Baichuan2-53B",
temperature: float = TEMPERATURE,
model_name: str = "baichuan-api",
):
config = get_model_worker_config(model_name)
api_key = api_key or config.get("api_key")
secret_key = secret_key or config.get("secret_key")
version = version or config.get("version")
url = "https://api.baichuan-ai.com/v1/stream/chat"
data = {
"model": "Baichuan2-53B",
"messages": [
{
"role": "user",
"content": "世界第一高峰是"
}
],
"parameters": {
"temperature": 0.1,
"top_k": 10
}
"model": version,
"messages": messages,
"parameters": {"temperature": temperature}
}
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + api_key,
@ -52,18 +52,17 @@ def do_request():
"X-BC-Sign-Algo": "MD5",
}
response = requests.post(url, data=json_data, headers=headers)
if response.status_code == 200:
print("请求成功!")
print("响应header:", response.headers)
print("响应body:", response.text)
else:
print("请求失败,状态码:", response.status_code)
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip():
continue
resp = json.loads(line)
yield resp
class BaiChuanWorker(ApiModelWorker):
BASE_URL = "https://api.baichuan-ai.com/v1/chat"
BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
SUPPORT_MODELS = ["Baichuan2-53B"]
def __init__(
@ -95,54 +94,34 @@ class BaiChuanWorker(ApiModelWorker):
self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params):
data = {
"model": self.version,
"messages": [
{
"role": "user",
"content": params["prompt"]
}
],
"parameters": {
"temperature": params.get("temperature",TEMPERATURE),
"top_k": params.get("top_k",1)
}
}
super().generate_stream_gate(params)
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
messages = self.prompt_to_messages(params["prompt"])
response = requests.post(self.BASE_URL, data=json_data, headers=headers)
text = ""
for resp in request_baichuan_api(messages=messages,
api_key=self.api_key,
secret_key=self.secret_key,
version=self.version,
temperature=params.get("temperature")):
if resp["code"] == 0:
text += resp["data"]["messages"][-1]["content"]
yield json.dumps(
{
"error_code": resp["code"],
"text": text
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps(
{
"error_code": resp["code"],
"text": resp["msg"]
},
ensure_ascii=False
).encode() + b"\0"
if response.status_code == 200:
resp = eval(response.text)
yield json.dumps(
{
"error_code": resp["code"],
"text": resp["data"]["messages"][-1]["content"]
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps(
{
"error_code": resp["code"],
"text": resp["msg"]
},
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")

View File

@ -1,13 +1,13 @@
from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
import asyncio
from typing import Dict, List
@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
worker_addr=worker_addr,
**kwargs)
self.context_len = context_len
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
self.init_heart_beat()
def count_token(self, params):
@ -62,15 +63,6 @@ class ApiModelWorker(BaseModelWorker):
print("embedding")
print(params)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
# help methods
def get_config(self):
from server.utils import get_model_worker_config

View File

@ -5,12 +5,11 @@ from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
logger, log_verbose,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@ -40,19 +39,64 @@ def get_ChatOpenAI(
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
config = get_model_worker_config(model_name)
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
## The following models are natively supported by Langchain and do not go through the Fschat wrapper
config_models = list_config_llm_models()
if model_name in config_models.get("langchain", {}):
config = config_models["langchain"][model_name]
if model_name == "Azure-OpenAI":
model = AzureChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
deployment_name=config.get("deployment_name"),
model_version=config.get("model_version"),
openai_api_type=config.get("openai_api_type"),
openai_api_base=config.get("api_base_url"),
openai_api_version=config.get("api_version"),
openai_api_key=config.get("api_key"),
openai_proxy=config.get("openai_proxy"),
temperature=temperature,
max_tokens=max_tokens,
)
elif model_name == "OpenAI":
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=config.get("model_name"),
openai_api_base=config.get("api_base_url"),
openai_api_key=config.get("api_key"),
openai_proxy=config.get("openai_proxy"),
temperature=temperature,
max_tokens=max_tokens,
)
elif model_name == "Anthropic":
model = ChatAnthropic(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=config.get("model_name"),
anthropic_api_key=config.get("api_key"),
)
## TODO: support other models natively supported by Langchain
else:
## Models not natively supported by Langchain go through the Fschat wrapper
config = get_model_worker_config(model_name)
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
return model
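For context, the chat endpoints build the model roughly like this (hedged sketch; the callback wiring is illustrative):

from langchain.callbacks import AsyncIteratorCallbackHandler

callback = AsyncIteratorCallbackHandler()
model = get_ChatOpenAI(
    model_name="Anthropic",   # a LANGCHAIN_LLM_MODEL key, or any fschat-served model name
    temperature=0.7,
    max_tokens=None,          # None lets the backend fall back to the model's maximum
    callbacks=[callback],
)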
@ -249,8 +293,9 @@ def MakeFastAPIOffline(
redoc_favicon_url=favicon,
)
# Get model information from model_config
# Get model information from model_config
def list_embed_models() -> List[str]:
'''
get names of configured embedding models
@ -266,9 +311,9 @@ def list_config_llm_models() -> Dict[str, Dict]:
workers = list(FSCHAT_MODEL_WORKERS)
if LLM_MODEL not in workers:
workers.insert(0, LLM_MODEL)
return {
"local": MODEL_PATH["llm_model"],
"langchain": LANGCHAIN_LLM_MODEL,
"online": ONLINE_LLM_MODEL,
"worker": workers,
}
@ -300,8 +345,9 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
return str(path)
return path_str # THUDM/chatglm06b
# Get service information from server_config
# Get service information from server_config
def get_model_worker_config(model_name: str = None) -> dict:
'''
Load the model worker configuration
@ -316,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
# online model API
if model_name in LANGCHAIN_LLM_MODEL:
config["langchain_model"] = True
config["worker_class"] = ""
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
@ -389,7 +439,7 @@ def webui_address() -> str:
return f"http://{host}:{port}"
def get_prompt_template(type:str,name: str) -> Optional[str]:
def get_prompt_template(type: str, name: str) -> Optional[str]:
'''
Load a prompt template from prompt_config.
type: one of "llm_chat", "agent_chat", "knowledge_base_chat", "search_engine_chat"; new features should add their own type here.
@ -459,8 +509,9 @@ def set_httpx_config(
import urllib.request
urllib.request.getproxies = _get_proxies
# Automatically detect the device available to torch. In distributed deployments, machines that do not run the LLM do not need torch installed.
# Automatically detect the device available to torch. In distributed deployments, machines that do not run the LLM do not need torch installed.
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:
import torch
@ -568,6 +619,8 @@ def get_server_configs() -> Dict:
Get the raw configuration items from configs for use by the frontend
'''
from configs.kb_config import (
DEFAULT_KNOWLEDGE_BASE,
DEFAULT_SEARCH_ENGINE,
DEFAULT_VS_TYPE,
CHUNK_SIZE,
OVERLAP_SIZE,

View File

@ -68,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
controller_address:
worker_address:
For models supported by Langchain:
langchain_model: True
fschat is not used
For online_api:
online_api: True
worker_class: `provider`
@ -78,31 +80,34 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
"""
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import worker_id, logger
import argparse
logger.setLevel(log_level)
parser = argparse.ArgumentParser()
args = parser.parse_args([])
for k, v in kwargs.items():
setattr(args, k, v)
if worker_class := kwargs.get("langchain_model"): #Langchian支持的模型不用做操作
from fastchat.serve.base_model_worker import app
worker = ""
# online model API
if worker_class := kwargs.get("worker_class"):
from fastchat.serve.model_worker import app
elif worker_class := kwargs.get("worker_class"):
from fastchat.serve.base_model_worker import app
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
sys.modules["fastchat.serve.model_worker"].worker = worker
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
# local models
else:
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker,app
from fastchat.serve.vllm_worker import VLLMWorker, app
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
args.tokenizer = args.model_path  # if the tokenizer differs from model_path, set it here
args.tokenizer_mode = 'auto'
args.trust_remote_code= True
@ -126,8 +131,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.engine_use_ray = False
args.disable_log_requests = False
# parameter required since vllm 0.2.0
args.max_model_len = 8192  # maximum sequence length the model can handle; set it according to your model
# parameter required since vllm 0.2.0, but not needed here
args.max_model_len = None
args.revision = None
args.quantization = None
args.max_log_len = None
@ -155,10 +160,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
conv_template = args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker
# sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.max_gpu_memory = "22GiB"
args.num_gpus = 1  # the model worker is split with model parallelism; set this to the number of GPUs
@ -221,8 +228,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].worker = worker
# sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"

View File

@ -0,0 +1,16 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.baichuan import request_baichuan_api
from pprint import pprint
def test_qwen():
messages = [{"role": "user", "content": "hello"}]
for x in request_baichuan_api(messages):
print(type(x))
pprint(x)
assert x["code"] == 0

View File

@ -21,13 +21,6 @@ if __name__ == "__main__":
}
)
if not chat_box.chat_inited:
running_models = api.list_running_models()
st.toast(
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
f"当前运行中的模型`{running_models}`, 您可以开始提问了."
)
pages = {
"对话": {
"icon": "chat",

View File

@ -3,7 +3,8 @@ from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
import os
from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES
from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
from typing import List, Dict
chat_box = ChatBox(
@ -40,7 +41,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
The return type is (model_name, is_local_model)
'''
running_models = api.list_running_models()
if not running_models:
return "", False
@ -50,12 +50,17 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return running_models[0], False
return list(running_models)[0], False
def dialogue_page(api: ApiRequest):
chat_box.init_session()
if not chat_box.chat_inited:
default_model = get_default_llm_model(api)[0]
st.toast(
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
f"当前运行的模型`{default_model}`, 您可以开始提问了."
)
chat_box.init_session()
with st.sidebar:
# TODO: bind the dialogue model to the session
@ -74,16 +79,17 @@ def dialogue_page(api: ApiRequest):
"搜索引擎问答",
"自定义Agent问答",
],
index=3,
index=0,
on_change=on_mode_change,
key="dialogue_mode",
)
def on_llm_change():
config = api.get_model_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model
if llm_model:
config = api.get_model_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model
def llm_model_format_func(x):
if x in running_models:
@ -91,16 +97,18 @@ def dialogue_page(api: ApiRequest):
return x
running_models = list(api.list_running_models())
running_models += LANGCHAIN_LLM_MODEL.keys()
available_models = []
config_models = api.list_config_models()
worker_models = list(config_models.get("worker", {}))  # only list models configured in FSCHAT_MODEL_WORKERS
for m in worker_models:
if m not in running_models and m != "default":
available_models.append(m)
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型如GPT
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
if not v.get("provider") and k not in running_models:
print(k, v)
available_models.append(k)
for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
available_models.append(k)
llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
llm_model = st.selectbox("选择LLM模型",
@ -111,7 +119,8 @@ def dialogue_page(api: ApiRequest):
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not api.get_model_config(llm_model).get("online_api")
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
@ -156,9 +165,13 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True):
kb_list = api.list_knowledge_bases()
index = 0
if DEFAULT_KNOWLEDGE_BASE in kb_list:
index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
selected_kb = st.selectbox(
"请选择知识库:",
kb_list,
index=index,
on_change=on_kb_change,
key="selected_kb",
)
@ -167,11 +180,15 @@ def dialogue_page(api: ApiRequest):
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
if DEFAULT_SEARCH_ENGINE in search_engine_list:
index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
else:
index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
with st.expander("搜索引擎配置", True):
search_engine = st.selectbox(
label="请选择搜索引擎",
options=search_engine_list,
index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
index=index,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
@ -210,9 +227,9 @@ def dialogue_page(api: ApiRequest):
])
text = ""
ans = ""
support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
if not any(agent in llm_model for agent in support_agent):
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐无法正常使用Agent功能</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验 </span> \n\n\n"
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n"
chat_box.update_msg(ans, element_index=0, streaming=False)
for d in api.agent_chat(prompt,
history=history,

View File

@ -245,7 +245,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = 1024,
max_tokens: int = None,
**kwargs: Any,
):
'''
@ -278,7 +278,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = 1024,
max_tokens: int = None,
prompt_name: str = "default",
**kwargs,
):
@ -308,7 +308,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = 1024,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
@ -340,7 +340,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = 1024,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
@ -378,7 +378,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = 1024,
max_tokens: int = None,
prompt_name: str = "default",
):
'''