diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 74e90233..46481da5 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -108,6 +108,10 @@ LLM_DEVICE = "auto" # 历史对话轮数 HISTORY_LEN = 3 +# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 + +MAX_TOKENS = None + # LLM通用对话参数 TEMPERATURE = 0.7 # TOP_P = 0.95 # ChatOpenAI暂不支持该参数 @@ -132,7 +136,7 @@ ONLINE_LLM_MODEL = { "APPID": "", "APISecret": "", "api_key": "", - "is_v2": False, + "version": "v1.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v1.5", "v2.0" "provider": "XingHuoWorker", }, # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf diff --git a/server/agent/tools/search_all_knowledge_more.py b/server/agent/tools/search_all_knowledge_more.py index 4108887d..a2d76b00 100644 --- a/server/agent/tools/search_all_knowledge_more.py +++ b/server/agent/tools/search_all_knowledge_more.py @@ -11,7 +11,7 @@ from langchain.schema.language_model import BaseLanguageModel from typing import List, Any, Optional from langchain.prompts import PromptTemplate from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD +from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS import asyncio from server.agent import model_container @@ -23,7 +23,7 @@ async def search_knowledge_base_iter(database: str, query: str) -> str: temperature=0.01, history=[], top_k=VECTOR_SEARCH_TOP_K, - max_tokens=None, + max_tokens=MAX_TOKENS, prompt_name="default", score_threshold=SCORE_THRESHOLD, stream=False) diff --git a/server/agent/tools/search_all_knowledge_once.py b/server/agent/tools/search_all_knowledge_once.py index 98ad2bbc..6e9d9659 100644 --- a/server/agent/tools/search_all_knowledge_once.py +++ b/server/agent/tools/search_all_knowledge_once.py @@ -19,7 +19,7 @@ import json sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD +from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS import asyncio from server.agent import model_container @@ -32,7 +32,7 @@ async def search_knowledge_base_iter(database: str, query: str): temperature=0.01, history=[], top_k=VECTOR_SEARCH_TOP_K, - max_tokens=None, + max_tokens=MAX_TOKENS, prompt_name="knowledge_base_chat", score_threshold=SCORE_THRESHOLD, stream=False) diff --git a/server/agent/tools/search_internet.py b/server/agent/tools/search_internet.py index 5266efc9..0d52789c 100644 --- a/server/agent/tools/search_internet.py +++ b/server/agent/tools/search_internet.py @@ -1,6 +1,6 @@ import json from server.chat.search_engine_chat import search_engine_chat -from configs import VECTOR_SEARCH_TOP_K +from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS import asyncio from server.agent import model_container @@ -11,7 +11,7 @@ async def search_engine_iter(query: str): temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01 history=[], top_k = VECTOR_SEARCH_TOP_K, - max_tokens= None, # Agent 搜索互联网的时候,max_tokens设置为None + max_tokens= MAX_TOKENS, prompt_name = "default", stream=False) diff --git a/server/agent/tools/search_knowledge_simple.py b/server/agent/tools/search_knowledge_simple.py index bad5ed5e..fbffbd7e 100644 --- a/server/agent/tools/search_knowledge_simple.py +++ b/server/agent/tools/search_knowledge_simple.py @@ -1,5 +1,5 @@ from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD +from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS import json import asyncio from server.agent import model_container @@ -11,6 +11,7 @@ async def search_knowledge_base_iter(database: str, query: str) -> str: temperature=0.01, history=[], top_k=VECTOR_SEARCH_TOP_K, + max_tokens=MAX_TOKENS, prompt_name="knowledge_base_chat", score_threshold=SCORE_THRESHOLD, stream=False) diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py index c4e090e8..795b1f73 100644 --- a/server/model_workers/SparkApi.py +++ b/server/model_workers/SparkApi.py @@ -52,7 +52,7 @@ class Ws_Param(object): return url -def gen_params(appid, domain,question, temperature): +def gen_params(appid, domain, question, temperature, max_token): """ 通过appid和用户的提问来生成请参数 """ @@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature): "chat": { "domain": domain, "random_threshold": 0.5, - "max_tokens": None, + "max_tokens": max_token, "auditing": "default", "temperature": temperature, } diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 0a3a9c8b..6f2fbfb8 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -9,32 +9,31 @@ from server.utils import iter_over_async, asyncio from typing import List, Dict -async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature): - # print("星火:") +async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token): wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url) wsUrl = wsParam.create_url() - data = SparkApi.gen_params(appid, domain, question, temperature) + data = SparkApi.gen_params(appid, domain, question, temperature, max_token) async with websockets.connect(wsUrl) as ws: await ws.send(json.dumps(data, ensure_ascii=False)) finish = False while not finish: - chunk = await ws.recv() - response = json.loads(chunk) - if response.get("header", {}).get("status") == 2: - finish = True - if text := response.get("payload", {}).get("choices", {}).get("text"): + chunk = await ws.recv() + response = json.loads(chunk) + if response.get("header", {}).get("status") == 2: + finish = True + if text := response.get("payload", {}).get("choices", {}).get("text"): yield text[0]["content"] class XingHuoWorker(ApiModelWorker): def __init__( - self, - *, - model_names: List[str] = ["xinghuo-api"], - controller_addr: str = None, - worker_addr: str = None, - version: str = None, - **kwargs, + self, + *, + model_names: List[str] = ["xinghuo-api"], + controller_addr: str = None, + worker_addr: str = None, + version: str = None, + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 8192) @@ -45,27 +44,34 @@ class XingHuoWorker(ApiModelWorker): # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接 params.load_config(self.model_names[0]) - if params.is_v2: - domain = "generalv2" # v2.0版本 - Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址 - else: - domain = "general" # v1.5版本 - Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址 + version_mapping = { + "v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 2048}, + "v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 4096}, + "v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8192}, + } + def get_version_details(version_key): + return version_mapping.get(version_key, {"domain": None, "url": None}) + + # 使用方法: + details = get_version_details(params.version) + domain = details["domain"] + Spark_url = details["url"] text = "" try: loop = asyncio.get_event_loop() except: loop = asyncio.new_event_loop() - + params.max_tokens = min(details["max_tokens"], params.max_tokens) for chunk in iter_over_async( - request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, params.temperature), - loop=loop, + request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages, + params.temperature, params.max_tokens), + loop=loop, ): if chunk: text += chunk yield {"error_code": 0, "text": text} - + def get_embeddings(self, params): # TODO: 支持embeddings print("embedding")