diff --git a/embeddings/add_embedding_keywords.py b/embeddings/add_embedding_keywords.py index 622a4cac..f46dee29 100644 --- a/embeddings/add_embedding_keywords.py +++ b/embeddings/add_embedding_keywords.py @@ -7,31 +7,35 @@ 保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳 ''' import sys + sys.path.append("..") +import os +import torch + from datetime import datetime from configs import ( MODEL_PATH, EMBEDDING_MODEL, EMBEDDING_KEYWORD_FILE, ) -import os -import torch + from safetensors.torch import save_model from sentence_transformers import SentenceTransformer +from langchain_core._api import deprecated +@deprecated( + since="0.3.0", + message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃", + removal="0.3.0" + ) def get_keyword_embedding(bert_model, tokenizer, key_words): tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True) - - # No need to manually convert to tensor as we've set return_tensors="pt" input_ids = tokenizer_output['input_ids'] - - # Remove the first and last token for each sequence in the batch input_ids = input_ids[:, 1:-1] keyword_embedding = bert_model.embeddings.word_embeddings(input_ids) keyword_embedding = torch.mean(keyword_embedding, 1) - return keyword_embedding @@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out bert_model = word_embedding_model.auto_model tokenizer = word_embedding_model.tokenizer key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words) - # key_words_embedding = st_model.encode(key_words) embedding_weight = bert_model.embeddings.word_embeddings.weight embedding_weight_len = len(embedding_weight) tokenizer.add_tokens(key_words) bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) - - # key_words_embedding_tensor = torch.from_numpy(key_words_embedding) embedding_weight = bert_model.embeddings.word_embeddings.weight with torch.no_grad(): embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding @@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE): output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time) output_model_path = os.path.join(model_parent_directory, output_model_name) add_keyword_to_model(model_name, keyword_file, output_model_path) - - -if __name__ == '__main__': - add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE) - - # input_model_name = "" - # output_model_path = "" - # # 以下为加入关键字前后tokenizer的测试用例对比 - # def print_token_ids(output, tokenizer, sentences): - # for idx, ids in enumerate(output['input_ids']): - # print(f'sentence={sentences[idx]}') - # print(f'ids={ids}') - # for id in ids: - # decoded_id = tokenizer.decode(id) - # print(f' {decoded_id}->{id}') - # - # sentences = [ - # '数据科学与大数据技术', - # 'Langchain-Chatchat' - # ] - # - # st_no_keywords = SentenceTransformer(input_model_name) - # tokenizer_without_keywords = st_no_keywords.tokenizer - # print("===== tokenizer with no keywords added =====") - # output = tokenizer_without_keywords(sentences) - # print_token_ids(output, tokenizer_without_keywords, sentences) - # print(f'-------- embedding with no keywords added -----') - # embeddings = st_no_keywords.encode(sentences) - # print(embeddings) - # - # print("--------------------------------------------") - # print("--------------------------------------------") - # print("--------------------------------------------") - # - # st_with_keywords = SentenceTransformer(output_model_path) - # 
tokenizer_with_keywords = st_with_keywords.tokenizer - # print("===== tokenizer with keyword added =====") - # output = tokenizer_with_keywords(sentences) - # print_token_ids(output, tokenizer_with_keywords, sentences) - # - # print(f'-------- embedding with keywords added -----') - # embeddings = st_with_keywords.encode(sentences) - # print(embeddings) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b041742e..5ab39657 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,13 +38,11 @@ numpy~=1.24.4 pandas~=2.0.3 einops>=0.7.0 transformers_stream_generator==0.0.4 -vllm==0.2.6; sys_platform == "linux" -httpx[brotli,http2,socks]==0.25.2 -llama-index +vllm==0.2.7; sys_platform == "linux" # optional document loaders -# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files +#rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows beautifulsoup4~=4.12.2 # for .mhtml files pysrt~=1.1.2 @@ -69,9 +67,11 @@ metaphor-python~=0.1.23 # WebUI requirements -streamlit~=1.29.0 -streamlit-option-menu>=0.3.6 +streamlit==1.30.0 +streamlit-option-menu==0.3.6 +streamlit-antd-components==0.3.1 streamlit-chatbox==1.1.11 -streamlit-modal>=0.1.0 -streamlit-aggrid>=0.3.4.post3 -watchdog>=3.0.0 +streamlit-modal==0.1.0 +streamlit-aggrid==0.3.4.post3 +httpx==0.26.0 +watchdog==3.0.0 \ No newline at end of file diff --git a/requirements_api.txt b/requirements_api.txt index 5cd1048f..0e2a9eee 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -36,8 +36,8 @@ numpy~=1.24.4 pandas~=2.0.3 einops>=0.7.0 transformers_stream_generator==0.0.4 -vllm==0.2.6; sys_platform == "linux" -httpx[brotli,http2,socks]==0.25.2 +vllm==0.2.7; sys_platform == "linux" +httpx==0.26.0 llama-index # optional document loaders diff --git a/requirements_lite.txt b/requirements_lite.txt index 584d0c75..db57273a 100644 --- a/requirements_lite.txt +++ b/requirements_lite.txt @@ -27,7 +27,7 @@ numpy~=1.24.4 pandas~=2.0.3 einops>=0.7.0 transformers_stream_generator==0.0.4 -vllm==0.2.6; sys_platform == "linux" +vllm==0.2.7; sys_platform == "linux" httpx[brotli,http2,socks]==0.25.2 requests pathlib @@ -54,11 +54,11 @@ metaphor-python~=0.1.23 # WebUI requirements -streamlit>=1.29.0 -streamlit-option-menu>=0.3.6 -streamlit-antd-components>=0.3.0 -streamlit-chatbox>=1.1.11 -streamlit-modal>=0.1.0 -streamlit-aggrid>=0.3.4.post3 -httpx[brotli,http2,socks]>=0.25.2 -watchdog>=3.0.0 \ No newline at end of file +streamlit==1.30.0 +streamlit-option-menu==0.3.6 +streamlit-antd-components==0.3.1 +streamlit-chatbox==1.1.11 +streamlit-modal==0.1.0 +streamlit-aggrid==0.3.4.post3 +httpx==0.26.0 +watchdog==3.0.0 \ No newline at end of file diff --git a/requirements_webui.txt b/requirements_webui.txt index 28810274..111dedaa 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -1,10 +1,10 @@ # WebUI requirements -streamlit>=1.29.0 -streamlit-option-menu>=0.3.6 -streamlit-antd-components>=0.3.0 -streamlit-chatbox>=1.1.11 -streamlit-modal>=0.1.0 -streamlit-aggrid>=0.3.4.post3 -httpx[brotli,http2,socks]>=0.25.2 -watchdog>=3.0.0 \ No newline at end of file +streamlit==1.30.0 +streamlit-option-menu==0.3.6 +streamlit-antd-components==0.3.1 +streamlit-chatbox==1.1.11 +streamlit-modal==0.1.0 +streamlit-aggrid==0.3.4.post3 +httpx==0.26.0 +watchdog==3.0.0 \ No newline at end of file diff --git a/server/agent/custom_agent/ChatGLM3Agent.py b/server/agent/custom_agent/ChatGLM3Agent.py 
index 92ddc28b..65f57567 100644 --- a/server/agent/custom_agent/ChatGLM3Agent.py +++ b/server/agent/custom_agent/ChatGLM3Agent.py @@ -1,22 +1,19 @@ """ -This file is a modified version for ChatGLM3-6B the original ChatGLM3Agent.py file from the langchain repo. +This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo. """ from __future__ import annotations -import yaml -from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser -from langchain.memory import ConversationBufferWindowMemory -from typing import Any, List, Sequence, Tuple, Optional, Union -import os -from langchain.agents.agent import Agent -from langchain.chains.llm import LLMChain -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, MessagesPlaceholder, -) import json import logging +from typing import Any, List, Sequence, Tuple, Optional, Union +from pydantic.schema import model_schema + + +from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser +from langchain.memory import ConversationBufferWindowMemory +from langchain.agents.agent import Agent +from langchain.chains.llm import LLMChain +from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate from langchain.agents.agent import AgentOutputParser from langchain.output_parsers import OutputFixingParser from langchain.pydantic_v1 import Field @@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser): first_index = min([text.find(token) if token in text else len(text) for token in special_tokens]) text = text[:first_index] if "tool_call" in text: - tool_name_end = text.find("```") - tool_name = text[:tool_name_end].strip() - input_para = text.split("='")[-1].split("'")[0] + action_end = text.find("```") + action = text[:action_end].strip() + params_str_start = text.find("(") + 1 + params_str_end = text.rfind(")") + params_str = text[params_str_start:params_str_end] + + params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param] + params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs} + action_json = { - "action": tool_name, - "action_input": input_para + "action": action, + "action_input": params } else: action_json = { @@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent): else: return agent_scratchpad - @classmethod - def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: - pass - @classmethod def _get_default_output_parser( cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any @@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent): @property def _stop(self) -> List[str]: - return ["```"] + return ["<|observation|>"] @classmethod def create_prompt( @@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent): input_variables: Optional[List[str]] = None, memory_prompts: Optional[List[BasePromptTemplate]] = None, ) -> BasePromptTemplate: - def tool_config_from_file(tool_name, directory="server/agent/tools/"): - """search tool yaml and return simplified json format""" - file_path = os.path.join(directory, f"{tool_name.lower()}.yaml") - try: - with open(file_path, 'r', encoding='utf-8') as file: - tool_config = yaml.safe_load(file) - # Simplify the structure if needed - simplified_config = { - "name": tool_config.get("name", ""), - "description": tool_config.get("description", ""), - "parameters": tool_config.get("parameters", {}) - } - return simplified_config - except FileNotFoundError: - 
logger.error(f"File not found: {file_path}") - return None - except Exception as e: - logger.error(f"An error occurred while reading {file_path}: {e}") - return None - tools_json = [] tool_names = [] for tool in tools: - tool_config = tool_config_from_file(tool.name) - if tool_config: - tools_json.append(tool_config) - tool_names.append(tool.name) - - # Format the tools for output + tool_schema = model_schema(tool.args_schema) if tool.args_schema else {} + simplified_config_langchain = { + "name": tool.name, + "description": tool.description, + "parameters": tool_schema.get("properties", {}) + } + tools_json.append(simplified_config_langchain) + tool_names.append(tool.name) formatted_tools = "\n".join([ f"{tool['name']}: {tool['description']}, args: {tool['parameters']}" for tool in tools_json ]) formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}") - template = prompt.format(tool_names=tool_names, tools=formatted_tools, - history="{history}", + history="None", input="{input}", agent_scratchpad="{agent_scratchpad}") @@ -225,7 +205,6 @@ def initialize_glm3_agent( tools: Sequence[BaseTool], llm: BaseLanguageModel, prompt: str = None, - callback_manager: Optional[BaseCallbackManager] = None, memory: Optional[ConversationBufferWindowMemory] = None, agent_kwargs: Optional[dict] = None, *, @@ -238,14 +217,12 @@ def initialize_glm3_agent( llm=llm, tools=tools, prompt=prompt, - callback_manager=callback_manager, **agent_kwargs + **agent_kwargs ) return AgentExecutor.from_agent_and_tools( agent=agent_obj, tools=tools, - callback_manager=callback_manager, memory=memory, tags=tags_, **kwargs, - ) - + ) \ No newline at end of file diff --git a/server/agent/model_contain.py b/server/agent/model_contain.py index 1927c88f..0141ad03 100644 --- a/server/agent/model_contain.py +++ b/server/agent/model_contain.py @@ -1,5 +1,3 @@ - -## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍 class ModelContainer: def __init__(self): self.MODEL = None diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py index 43d57343..21dfd973 100644 --- a/server/agent/tools/__init__.py +++ b/server/agent/tools/__init__.py @@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput from .calculate import calculate, CalculatorInput -from .weather_check import weathercheck, WhetherSchema +from .weather_check import weathercheck, WeatherInput from .shell import shell, ShellInput from .search_internet import search_internet, SearchInternetInput from .wolfram import wolfram, WolframInput diff --git a/server/agent/tools/arxiv.yaml b/server/agent/tools/arxiv.yaml deleted file mode 100644 index 5da8a9f9..00000000 --- a/server/agent/tools/arxiv.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: arxiv -description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields. 
-parameters: - type: object - properties: - query: - type: string - description: The search query title -required: - - query \ No newline at end of file diff --git a/server/agent/tools/calculate.yaml b/server/agent/tools/calculate.yaml deleted file mode 100644 index 2976c011..00000000 --- a/server/agent/tools/calculate.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: calculate -description: Useful for when you need to answer questions about simple calculations -parameters: - type: object - properties: - query: - type: string - description: The formula to be calculated -required: - - query \ No newline at end of file diff --git a/server/agent/tools/search_internet.yaml b/server/agent/tools/search_internet.yaml deleted file mode 100644 index 57608e7e..00000000 --- a/server/agent/tools/search_internet.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: search_internet -description: Use this tool to surf internet and get information -parameters: - type: object - properties: - query: - type: string - description: Query for Internet search -required: - - query \ No newline at end of file diff --git a/server/agent/tools/search_knowledgebase_complex.yaml b/server/agent/tools/search_knowledgebase_complex.yaml deleted file mode 100644 index bccbfd8c..00000000 --- a/server/agent/tools/search_knowledgebase_complex.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: search_knowledgebase_complex -description: Use this tool to search local knowledgebase and get information -parameters: - type: object - properties: - query: - type: string - description: The query to be searched -required: - - query \ No newline at end of file diff --git a/server/agent/tools/search_youtube.yaml b/server/agent/tools/search_youtube.yaml deleted file mode 100644 index ce446068..00000000 --- a/server/agent/tools/search_youtube.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: search_youtube -description: Use this tools to search youtube videos -parameters: - type: object - properties: - query: - type: string - description: Query for Videos search -required: - - query \ No newline at end of file diff --git a/server/agent/tools/shell.yaml b/server/agent/tools/shell.yaml deleted file mode 100644 index 22b94180..00000000 --- a/server/agent/tools/shell.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: shell -description: Use Linux Shell to execute Linux commands -parameters: - type: object - properties: - query: - type: string - description: The command to execute -required: - - query \ No newline at end of file diff --git a/server/agent/tools/weather_check.py b/server/agent/tools/weather_check.py index db88b874..7e55c7cb 100644 --- a/server/agent/tools/weather_check.py +++ b/server/agent/tools/weather_check.py @@ -1,338 +1,25 @@ -from __future__ import annotations - -## 单独运行的时候需要添加 -import sys -import os - -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) - -import re -import warnings -from typing import Dict - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains.base import Chain -from langchain.chains.llm import LLMChain -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel -import requests -from typing import List, Any, Optional -from datetime import datetime -from langchain.prompts import PromptTemplate -from server.agent import model_container -from pydantic import BaseModel, Field - -## 使用和风天气API查询天气 -KEY = 
"ac880e5a877042809ac7ffdd19d95b0d" -# key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key - - -_PROMPT_TEMPLATE = """ -用户会提出一个关于天气的问题,你的目标是拆分出用户问题中的区,市 并按照我提供的工具回答。 -例如 用户提出的问题是: 上海浦东未来1小时天气情况? -则 提取的市和区是: 上海 浦东 -如果用户提出的问题是: 上海未来1小时天气情况? -则 提取的市和区是: 上海 None -请注意以下内容: -1. 如果你没有找到区的内容,则一定要使用 None 替代,否则程序无法运行 -2. 如果用户没有指定市 则直接返回缺少信息 - -问题: ${{用户的问题}} - -你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。 -```text - -${{拆分的市和区,中间用空格隔开}} -``` -... weathercheck(市 区)... -```output - -${{提取后的答案}} -``` -答案: ${{答案}} - - - -这是一个例子: -问题: 上海浦东未来1小时天气情况? - - -```text -上海 浦东 -``` -...weathercheck(上海 浦东)... - -```output -预报时间: 1小时后 -具体时间: 今天 18:00 -温度: 24°C -天气: 多云 -风向: 西南风 -风速: 7级 -湿度: 88% -降水概率: 16% - -Answer: 上海浦东一小时后的天气是多云。 - -现在,这是我的问题: - -问题: {question} """ -PROMPT = PromptTemplate( - input_variables=["question"], - template=_PROMPT_TEMPLATE, -) +更简单的单参数输入工具实现,用于查询现在天气的情况 +""" +from pydantic import BaseModel, Field +import requests + +def weather(location: str, api_key: str): + url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c" + response = requests.get(url) + if response.status_code == 200: + data = response.json() + weather = { + "temperature": data["results"][0]["now"]["temperature"], + "description": data["results"][0]["now"]["text"], + } + return weather + else: + raise Exception( + f"Failed to retrieve weather: {response.status_code}") -def get_city_info(location, adm, key): - base_url = 'https://geoapi.qweather.com/v2/city/lookup?' - params = {'location': location, 'adm': adm, 'key': key} - response = requests.get(base_url, params=params) - data = response.json() - return data - - -def format_weather_data(data, place): - hourly_forecast = data['hourly'] - formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n" - for forecast in hourly_forecast: - # 将预报时间转换为datetime对象 - forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z') - # 获取预报时间的时区 - forecast_tz = forecast_time.tzinfo - # 获取当前时间(使用预报时间的时区) - now = datetime.now(forecast_tz) - # 计算预报日期与当前日期的差值 - days_diff = (forecast_time.date() - now.date()).days - if days_diff == 0: - forecast_date_str = '今天' - elif days_diff == 1: - forecast_date_str = '明天' - elif days_diff == 2: - forecast_date_str = '后天' - else: - forecast_date_str = str(days_diff) + '天后' - forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M') - # 计算预报时间与当前时间的差值 - time_diff = forecast_time - now - # 将差值转换为小时 - hours_diff = time_diff.total_seconds() // 3600 - if hours_diff < 1: - hours_diff_str = '1小时后' - elif hours_diff >= 24: - # 如果超过24小时,转换为天数 - days_diff = hours_diff // 24 - hours_diff_str = str(int(days_diff)) + '天' - else: - hours_diff_str = str(int(hours_diff)) + '小时' - # 将预报时间和当前时间的差值添加到输出中 - formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n' - formatted_data += '温度: ' + forecast['temp'] + '°C\n' - formatted_data += '天气: ' + forecast['text'] + '\n' - formatted_data += '风向: ' + forecast['windDir'] + '\n' - formatted_data += '风速: ' + forecast['windSpeed'] + '级\n' - formatted_data += '湿度: ' + forecast['humidity'] + '%\n' - formatted_data += '降水概率: ' + forecast['pop'] + '%\n' - # formatted_data += '降水量: ' + forecast['precip'] + 'mm\n' - formatted_data += '\n' - return formatted_data - - -def get_weather(key, location_id, place): - url = "https://devapi.qweather.com/v7/weather/24h?" 
- params = { - 'location': location_id, - 'key': key, - } - response = requests.get(url, params=params) - data = response.json() - return format_weather_data(data, place) - - -def split_query(query): - parts = query.split() - adm = parts[0] - if len(parts) == 1: - return adm, adm - location = parts[1] if parts[1] != 'None' else adm - return location, adm - - -def weather(query): - location, adm = split_query(query) - key = KEY - if key == "": - return "请先在代码中填入和风天气API Key" - try: - city_info = get_city_info(location=location, adm=adm, key=key) - location_id = city_info['location'][0]['id'] - place = adm + "市" + location + "区" - - weather_data = get_weather(key=key, location_id=location_id, place=place) - return weather_data + "以上是查询到的天气信息,请你查收\n" - except KeyError: - try: - city_info = get_city_info(location=adm, adm=adm, key=key) - location_id = city_info['location'][0]['id'] - place = adm + "市" - weather_data = get_weather(key=key, location_id=location_id, place=place) - return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n" - except KeyError: - return "输入的地区不存在,无法提供天气预报" - - -class LLMWeatherChain(Chain): - llm_chain: LLMChain - llm: Optional[BaseLanguageModel] = None - """[Deprecated] LLM wrapper to use.""" - prompt: BasePromptTemplate = PROMPT - """[Deprecated] Prompt to use to translate to python if necessary.""" - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: - if "llm" in values: - warnings.warn( - "Directly instantiating an LLMWeatherChain with an llm is deprecated. " - "Please instantiate with llm_chain argument or using the from_llm " - "class method." - ) - if "llm_chain" not in values and values["llm"] is not None: - prompt = values.get("prompt", PROMPT) - values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) - return values - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. 
- - :meta private: - """ - return [self.output_key] - - def _evaluate_expression(self, expression: str) -> str: - try: - output = weather(expression) - except Exception as e: - output = "输入的信息有误,请再次尝试" - return output - - def _process_llm_result( - self, llm_output: str, run_manager: CallbackManagerForChainRun - ) -> Dict[str, str]: - - run_manager.on_text(llm_output, color="green", verbose=self.verbose) - - llm_output = llm_output.strip() - text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) - if text_match: - expression = text_match.group(1) - output = self._evaluate_expression(expression) - run_manager.on_text("\nAnswer: ", verbose=self.verbose) - run_manager.on_text(output, color="yellow", verbose=self.verbose) - answer = "Answer: " + output - elif llm_output.startswith("Answer:"): - answer = llm_output - elif "Answer:" in llm_output: - answer = "Answer: " + llm_output.split("Answer:")[-1] - else: - return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"} - return {self.output_key: answer} - - async def _aprocess_llm_result( - self, - llm_output: str, - run_manager: AsyncCallbackManagerForChainRun, - ) -> Dict[str, str]: - await run_manager.on_text(llm_output, color="green", verbose=self.verbose) - llm_output = llm_output.strip() - text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) - - if text_match: - expression = text_match.group(1) - output = self._evaluate_expression(expression) - await run_manager.on_text("\nAnswer: ", verbose=self.verbose) - await run_manager.on_text(output, color="yellow", verbose=self.verbose) - answer = "Answer: " + output - elif llm_output.startswith("Answer:"): - answer = llm_output - elif "Answer:" in llm_output: - answer = "Answer: " + llm_output.split("Answer:")[-1] - else: - raise ValueError(f"unknown format from LLM: {llm_output}") - return {self.output_key: answer} - - def _call( - self, - inputs: Dict[str, str], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - _run_manager.on_text(inputs[self.input_key]) - llm_output = self.llm_chain.predict( - question=inputs[self.input_key], - stop=["```output"], - callbacks=_run_manager.get_child(), - ) - return self._process_llm_result(llm_output, _run_manager) - - async def _acall( - self, - inputs: Dict[str, str], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() - await _run_manager.on_text(inputs[self.input_key]) - llm_output = await self.llm_chain.apredict( - question=inputs[self.input_key], - stop=["```output"], - callbacks=_run_manager.get_child(), - ) - return await self._aprocess_llm_result(llm_output, _run_manager) - - @property - def _chain_type(self) -> str: - return "llm_weather_chain" - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - prompt: BasePromptTemplate = PROMPT, - **kwargs: Any, - ) -> LLMWeatherChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) - return cls(llm_chain=llm_chain, **kwargs) - - -def weathercheck(query: str): - model = model_container.MODEL - llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT) - ans = llm_weather.run(query) - return ans - - -class WhetherSchema(BaseModel): - location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海") - -if __name__ == '__main__': - result = weathercheck("苏州姑苏区今晚热不热?") +def weathercheck(location: str): + return 
weather(location, "S8vrB4U_-c5mvAMiK") +class WeatherInput(BaseModel): + location: str = Field(description="City name,include city and county,like '厦门'") diff --git a/server/agent/tools/weather_check.yaml b/server/agent/tools/weather_check.yaml deleted file mode 100644 index e72676ac..00000000 --- a/server/agent/tools/weather_check.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: weather_check -description: Use Weather API to get weather information -parameters: - type: object - properties: - query: - type: string - description: City name,include city and county,like "厦门市思明区" -required: - - query \ No newline at end of file diff --git a/server/agent/tools/wolfram.yaml b/server/agent/tools/wolfram.yaml deleted file mode 100644 index 532f8248..00000000 --- a/server/agent/tools/wolfram.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: wolfram -description: Useful for when you need to calculate difficult math formulas -parameters: - type: object - properties: - query: - type: string - description: The formula to be calculated -required: - - query \ No newline at end of file diff --git a/server/agent/tools_select.py b/server/agent/tools_select.py index 821d324e..237c20b6 100644 --- a/server/agent/tools_select.py +++ b/server/agent/tools_select.py @@ -1,8 +1,6 @@ from langchain.tools import Tool from server.agent.tools import * -## 请注意,如果你是为了使用AgentLM,在这里,你应该使用英文版本。 - tools = [ Tool.from_function( func=calculate, @@ -20,7 +18,7 @@ tools = [ func=weathercheck, name="weather_check", description="", - args_schema=WhetherSchema, + args_schema=WeatherInput, ), Tool.from_function( func=shell, diff --git a/server/callback_handler/conversation_callback_handler.py b/server/callback_handler/conversation_callback_handler.py index 6642c1e2..8f09b40d 100644 --- a/server/callback_handler/conversation_callback_handler.py +++ b/server/callback_handler/conversation_callback_handler.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, List from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import LLMResult diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 03c0c2b9..f47958a0 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -1,23 +1,23 @@ -from langchain.memory import ConversationBufferWindowMemory +import json +import asyncio -from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent -from server.agent.tools_select import tools, tool_names -from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status -from langchain.agents import LLMSingleActionAgent, AgentExecutor -from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate from fastapi import Body from sse_starlette.sse import EventSourceResponse from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL -from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template -from langchain.chains import LLMChain -from typing import AsyncIterable, Optional -import asyncio -from typing import List -from server.chat.utils import History -import json -from server.agent import model_container -from server.knowledge_base.kb_service.base import get_kb_details +from langchain.chains import LLMChain +from langchain.memory import ConversationBufferWindowMemory +from langchain.agents import LLMSingleActionAgent, AgentExecutor +from typing import AsyncIterable, Optional, List + +from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template +from server.knowledge_base.kb_service.base 
import get_kb_details +from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent +from server.agent.tools_select import tools, tool_names +from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status +from server.chat.utils import History +from server.agent import model_container +from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), history: List[History] = Body([], @@ -33,7 +33,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), ): history = [History.from_data(h) for h in history] @@ -55,12 +54,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples callbacks=[callback], ) - ## 传入全局变量来实现agent调用 kb_list = {x["kb_name"]: x for x in get_kb_details()} model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()} if Agent_MODEL: - ## 如果有指定使用Agent模型来完成任务 model_agent = get_ChatOpenAI( model_name=Agent_MODEL, temperature=temperature, @@ -79,15 +76,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples ) output_parser = CustomOutputParser() llm_chain = LLMChain(llm=model, prompt=prompt_template_agent) - # 把history转成agent的memory memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2) for message in history: - # 检查消息的角色 if message.role == 'user': - # 添加用户消息 memory.chat_memory.add_user_message(message.content) else: - # 添加AI消息 memory.chat_memory.add_ai_message(message.content) if "chatglm3" in model_container.MODEL.model_name: @@ -95,7 +88,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples llm=model, tools=tools, callback_manager=None, - # Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent. 
prompt=prompt_template, input_variables=["input", "intermediate_steps", "history"], memory=memory, @@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples answer = "" final_answer = "" async for chunk in callback.aiter(): - # Use server-sent-events to stream the response data = json.loads(chunk) if data["status"] == Status.start or data["status"] == Status.complete: continue @@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples await task return EventSourceResponse(agent_chat_iterator(query=query, - history=history, - model_name=model_name, - prompt_name=prompt_name), - ) + history=history, + model_name=model_name, + prompt_name=prompt_name), + ) diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index db00dfa4..42bef3c2 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -1,23 +1,23 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY, - LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, - TEXT_SPLITTER_NAME, OVERLAP_SIZE) -from fastapi import Body -from sse_starlette import EventSourceResponse -from fastapi.concurrency import run_in_threadpool -from server.utils import wrap_done, get_ChatOpenAI -from server.utils import BaseResponse, get_prompt_template + LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE) from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler -from typing import AsyncIterable -import asyncio + from langchain.prompts.chat import ChatPromptTemplate from langchain.text_splitter import RecursiveCharacterTextSplitter -from typing import List, Optional, Dict -from server.chat.utils import History from langchain.docstore.document import Document +from fastapi import Body +from fastapi.concurrency import run_in_threadpool +from sse_starlette import EventSourceResponse +from server.utils import wrap_done, get_ChatOpenAI +from server.utils import BaseResponse, get_prompt_template +from server.chat.utils import History +from typing import AsyncIterable +import asyncio import json +from typing import List, Optional, Dict from strsimpy.normalized_levenshtein import NormalizedLevenshtein from markdownify import markdownify @@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs): def metaphor_search( - text: str, - result_len: int = SEARCH_ENGINE_TOP_K, - split_result: bool = False, - chunk_size: int = 500, - chunk_overlap: int = OVERLAP_SIZE, + text: str, + result_len: int = SEARCH_ENGINE_TOP_K, + split_result: bool = False, + chunk_size: int = 500, + chunk_overlap: int = OVERLAP_SIZE, ) -> List[Dict]: from metaphor_python import Metaphor @@ -58,13 +58,13 @@ def metaphor_search( # metaphor 返回的内容都是长文本,需要分词再检索 if split_result: docs = [Document(page_content=x.extract, - metadata={"link": x.url, "title": x.title}) + metadata={"link": x.url, "title": x.title}) for x in contents] text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "], chunk_size=chunk_size, chunk_overlap=chunk_overlap) splitted_docs = text_splitter.split_documents(docs) - + # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档 if len(splitted_docs) > result_len: normal = NormalizedLevenshtein() @@ -74,13 +74,13 @@ def metaphor_search( splitted_docs = splitted_docs[:result_len] docs = [{"snippet": x.page_content, - "link": 
x.metadata["link"], - "title": x.metadata["title"]} + "link": x.metadata["link"], + "title": x.metadata["title"]} for x in splitted_docs] else: docs = [{"snippet": x.extract, - "link": x.url, - "title": x.title} + "link": x.url, + "title": x.title} for x in contents] return docs @@ -113,25 +113,27 @@ async def lookup_search_engine( docs = search_result2docs(results) return docs - async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), - top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", + search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), + top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", + {"role": "assistant", "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), - prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)") - ): + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, + description="限制LLM生成Token数量,默认None代表模型最大值"), + prompt_name: str = Body("default", + description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + split_result: bool = Body(False, + description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)") + ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", await task return EventSourceResponse(search_engine_chat_iterator(query=query, - search_engine_name=search_engine_name, - top_k=top_k, - history=history, - model_name=model_name, - prompt_name=prompt_name), - ) + search_engine_name=search_engine_name, + top_k=top_k, + history=history, + model_name=model_name, + prompt_name=prompt_name), + ) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 4426d7d3..5a73e628 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -1,7 +1,6 @@ # 该文件封装了对api.py的请求,可以被不同的webui使用 # 通过ApiRequest和AsyncApiRequest支持同步/异步调用 - from typing import * from pathlib import Path # 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同 @@ -27,7 +26,7 @@ from io import BytesIO from server.utils import set_httpx_config, api_address, get_httpx_client from pprint import pprint - +from langchain_core._api import deprecated set_httpx_config() @@ -36,10 +35,11 @@ class ApiRequest: ''' api.py调用的封装(同步模式),简化api调用方式 ''' + def __init__( - self, - base_url: str = api_address(), - timeout: float = HTTPX_DEFAULT_TIMEOUT, + self, + base_url: str = api_address(), + timeout: float = HTTPX_DEFAULT_TIMEOUT, ): self.base_url = base_url self.timeout = timeout @@ -55,12 +55,12 @@ class ApiRequest: return self._client def get( - self, - url: str, - params: Union[Dict, List[Tuple], bytes] = None, - retry: 
int = 3, - stream: bool = False, - **kwargs: Any, + self, + url: str, + params: Union[Dict, List[Tuple], bytes] = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: @@ -75,13 +75,13 @@ class ApiRequest: retry -= 1 def post( - self, - url: str, - data: Dict = None, - json: Dict = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: @@ -97,13 +97,13 @@ class ApiRequest: retry -= 1 def delete( - self, - url: str, - data: Dict = None, - json: Dict = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: @@ -118,24 +118,25 @@ class ApiRequest: retry -= 1 def _httpx_stream2generator( - self, - response: contextlib._GeneratorContextManager, - as_json: bool = False, + self, + response: contextlib._GeneratorContextManager, + as_json: bool = False, ): ''' 将httpx.stream返回的GeneratorContextManager转化为普通生成器 ''' + async def ret_async(response, as_json): try: async with response as r: async for chunk in r.aiter_text(None): - if not chunk: # fastchat api yield empty bytes on start and end + if not chunk: # fastchat api yield empty bytes on start and end continue if as_json: try: if chunk.startswith("data: "): data = json.loads(chunk[6:-2]) - elif chunk.startswith(":"): # skip sse comment line + elif chunk.startswith(":"): # skip sse comment line continue else: data = json.loads(chunk) @@ -143,7 +144,7 @@ class ApiRequest: except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + exc_info=e if log_verbose else None) else: # print(chunk, end="", flush=True) yield chunk @@ -158,20 +159,20 @@ class ApiRequest: except Exception as e: msg = f"API通信遇到错误:{e}" logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + exc_info=e if log_verbose else None) yield {"code": 500, "msg": msg} def ret_sync(response, as_json): try: with response as r: for chunk in r.iter_text(None): - if not chunk: # fastchat api yield empty bytes on start and end + if not chunk: # fastchat api yield empty bytes on start and end continue if as_json: try: if chunk.startswith("data: "): data = json.loads(chunk[6:-2]) - elif chunk.startswith(":"): # skip sse comment line + elif chunk.startswith(":"): # skip sse comment line continue else: data = json.loads(chunk) @@ -179,7 +180,7 @@ class ApiRequest: except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + exc_info=e if log_verbose else None) else: # print(chunk, end="", flush=True) yield chunk @@ -194,7 +195,7 @@ class ApiRequest: except Exception as e: msg = f"API通信遇到错误:{e}" logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + exc_info=e if log_verbose else None) yield {"code": 500, "msg": msg} if self._use_async: @@ -203,16 +204,17 @@ class ApiRequest: return ret_sync(response, as_json) def _get_response_value( - self, - response: httpx.Response, - as_json: bool = False, - value_func: Callable = None, + self, + response: 
httpx.Response, + as_json: bool = False, + value_func: Callable = None, ): ''' 转换同步或异步请求返回的响应 `as_json`: 返回json `value_func`: 用户可以自定义返回值,该函数接受response或json ''' + def to_json(r): try: return r.json() @@ -220,7 +222,7 @@ class ApiRequest: msg = "API未能返回正确的JSON。" + str(e) if log_verbose: logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + exc_info=e if log_verbose else None) return {"code": 500, "msg": msg, "data": None} if value_func is None: @@ -250,10 +252,10 @@ class ApiRequest: return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"]) def get_prompt_template( - self, - type: str = "llm_chat", - name: str = "default", - **kwargs, + self, + type: str = "llm_chat", + name: str = "default", + **kwargs, ) -> str: data = { "type": type, @@ -297,15 +299,19 @@ class ApiRequest: response = self.post("/chat/chat", json=data, stream=True, **kwargs) return self._httpx_stream2generator(response, as_json=True) + @deprecated( + since="0.3.0", + message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃", + removal="0.3.0") def agent_chat( - self, - query: str, - history: List[Dict] = [], - stream: bool = True, - model: str = LLM_MODELS[0], - temperature: float = TEMPERATURE, - max_tokens: int = None, - prompt_name: str = "default", + self, + query: str, + history: List[Dict] = [], + stream: bool = True, + model: str = LLM_MODELS[0], + temperature: float = TEMPERATURE, + max_tokens: int = None, + prompt_name: str = "default", ): ''' 对应api.py/chat/agent_chat 接口 @@ -327,17 +333,17 @@ class ApiRequest: return self._httpx_stream2generator(response, as_json=True) def knowledge_base_chat( - self, - query: str, - knowledge_base_name: str, - top_k: int = VECTOR_SEARCH_TOP_K, - score_threshold: float = SCORE_THRESHOLD, - history: List[Dict] = [], - stream: bool = True, - model: str = LLM_MODELS[0], - temperature: float = TEMPERATURE, - max_tokens: int = None, - prompt_name: str = "default", + self, + query: str, + knowledge_base_name: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, + history: List[Dict] = [], + stream: bool = True, + model: str = LLM_MODELS[0], + temperature: float = TEMPERATURE, + max_tokens: int = None, + prompt_name: str = "default", ): ''' 对应api.py/chat/knowledge_base_chat接口 @@ -366,28 +372,29 @@ class ApiRequest: return self._httpx_stream2generator(response, as_json=True) def upload_temp_docs( - self, - files: List[Union[str, Path, bytes]], - knowledge_id: str = None, - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - zh_title_enhance=ZH_TITLE_ENHANCE, + self, + files: List[Union[str, Path, bytes]], + knowledge_id: str = None, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, ): ''' 对应api.py/knowledge_base/upload_tmep_docs接口 ''' + def convert_file(file, filename=None): - if isinstance(file, bytes): # raw bytes + if isinstance(file, bytes): # raw bytes file = BytesIO(file) - elif hasattr(file, "read"): # a file io like object + elif hasattr(file, "read"): # a file io like object filename = filename or file.name - else: # a local path + else: # a local path file = Path(file).absolute().open("rb") filename = filename or os.path.split(file.name)[-1] return filename, file files = [convert_file(file) for file in files] - data={ + data = { "knowledge_id": knowledge_id, "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, @@ -402,17 +409,17 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def file_chat( - self, - 
query: str, - knowledge_id: str, - top_k: int = VECTOR_SEARCH_TOP_K, - score_threshold: float = SCORE_THRESHOLD, - history: List[Dict] = [], - stream: bool = True, - model: str = LLM_MODELS[0], - temperature: float = TEMPERATURE, - max_tokens: int = None, - prompt_name: str = "default", + self, + query: str, + knowledge_id: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, + history: List[Dict] = [], + stream: bool = True, + model: str = LLM_MODELS[0], + temperature: float = TEMPERATURE, + max_tokens: int = None, + prompt_name: str = "default", ): ''' 对应api.py/chat/file_chat接口 @@ -440,18 +447,23 @@ class ApiRequest: ) return self._httpx_stream2generator(response, as_json=True) + @deprecated( + since="0.3.0", + message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃", + removal="0.3.0" + ) def search_engine_chat( - self, - query: str, - search_engine_name: str, - top_k: int = SEARCH_ENGINE_TOP_K, - history: List[Dict] = [], - stream: bool = True, - model: str = LLM_MODELS[0], - temperature: float = TEMPERATURE, - max_tokens: int = None, - prompt_name: str = "default", - split_result: bool = False, + self, + query: str, + search_engine_name: str, + top_k: int = SEARCH_ENGINE_TOP_K, + history: List[Dict] = [], + stream: bool = True, + model: str = LLM_MODELS[0], + temperature: float = TEMPERATURE, + max_tokens: int = None, + prompt_name: str = "default", + split_result: bool = False, ): ''' 对应api.py/chat/search_engine_chat接口 @@ -482,7 +494,7 @@ class ApiRequest: # 知识库相关操作 def list_knowledge_bases( - self, + self, ): ''' 对应api.py/knowledge_base/list_knowledge_bases接口 @@ -493,10 +505,10 @@ class ApiRequest: value_func=lambda r: r.get("data", [])) def create_knowledge_base( - self, - knowledge_base_name: str, - vector_store_type: str = DEFAULT_VS_TYPE, - embed_model: str = EMBEDDING_MODEL, + self, + knowledge_base_name: str, + vector_store_type: str = DEFAULT_VS_TYPE, + embed_model: str = EMBEDDING_MODEL, ): ''' 对应api.py/knowledge_base/create_knowledge_base接口 @@ -514,8 +526,8 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def delete_knowledge_base( - self, - knowledge_base_name: str, + self, + knowledge_base_name: str, ): ''' 对应api.py/knowledge_base/delete_knowledge_base接口 @@ -527,8 +539,8 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def list_kb_docs( - self, - knowledge_base_name: str, + self, + knowledge_base_name: str, ): ''' 对应api.py/knowledge_base/list_files接口 @@ -542,13 +554,13 @@ class ApiRequest: value_func=lambda r: r.get("data", [])) def search_kb_docs( - self, - knowledge_base_name: str, - query: str = "", - top_k: int = VECTOR_SEARCH_TOP_K, - score_threshold: int = SCORE_THRESHOLD, - file_name: str = "", - metadata: dict = {}, + self, + knowledge_base_name: str, + query: str = "", + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: int = SCORE_THRESHOLD, + file_name: str = "", + metadata: dict = {}, ) -> List: ''' 对应api.py/knowledge_base/search_docs接口 @@ -569,9 +581,9 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def update_docs_by_id( - self, - knowledge_base_name: str, - docs: Dict[str, Dict], + self, + knowledge_base_name: str, + docs: Dict[str, Dict], ) -> bool: ''' 对应api.py/knowledge_base/update_docs_by_id接口 @@ -587,32 +599,33 @@ class ApiRequest: return self._get_response_value(response) def upload_kb_docs( - self, - files: List[Union[str, Path, bytes]], - knowledge_base_name: str, - override: bool = False, - to_vector_store: bool = True, - 
chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - zh_title_enhance=ZH_TITLE_ENHANCE, - docs: Dict = {}, - not_refresh_vs_cache: bool = False, + self, + files: List[Union[str, Path, bytes]], + knowledge_base_name: str, + override: bool = False, + to_vector_store: bool = True, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, + docs: Dict = {}, + not_refresh_vs_cache: bool = False, ): ''' 对应api.py/knowledge_base/upload_docs接口 ''' + def convert_file(file, filename=None): - if isinstance(file, bytes): # raw bytes + if isinstance(file, bytes): # raw bytes file = BytesIO(file) - elif hasattr(file, "read"): # a file io like object + elif hasattr(file, "read"): # a file io like object filename = filename or file.name - else: # a local path + else: # a local path file = Path(file).absolute().open("rb") filename = filename or os.path.split(file.name)[-1] return filename, file files = [convert_file(file) for file in files] - data={ + data = { "knowledge_base_name": knowledge_base_name, "override": override, "to_vector_store": to_vector_store, @@ -633,11 +646,11 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def delete_kb_docs( - self, - knowledge_base_name: str, - file_names: List[str], - delete_content: bool = False, - not_refresh_vs_cache: bool = False, + self, + knowledge_base_name: str, + file_names: List[str], + delete_content: bool = False, + not_refresh_vs_cache: bool = False, ): ''' 对应api.py/knowledge_base/delete_docs接口 @@ -655,8 +668,7 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) - - def update_kb_info(self,knowledge_base_name,kb_info): + def update_kb_info(self, knowledge_base_name, kb_info): ''' 对应api.py/knowledge_base/update_info接口 ''' @@ -672,15 +684,15 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def update_kb_docs( - self, - knowledge_base_name: str, - file_names: List[str], - override_custom_docs: bool = False, - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - zh_title_enhance=ZH_TITLE_ENHANCE, - docs: Dict = {}, - not_refresh_vs_cache: bool = False, + self, + knowledge_base_name: str, + file_names: List[str], + override_custom_docs: bool = False, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, + docs: Dict = {}, + not_refresh_vs_cache: bool = False, ): ''' 对应api.py/knowledge_base/update_docs接口 @@ -706,14 +718,14 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def recreate_vector_store( - self, - knowledge_base_name: str, - allow_empty_kb: bool = True, - vs_type: str = DEFAULT_VS_TYPE, - embed_model: str = EMBEDDING_MODEL, - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - zh_title_enhance=ZH_TITLE_ENHANCE, + self, + knowledge_base_name: str, + allow_empty_kb: bool = True, + vs_type: str = DEFAULT_VS_TYPE, + embed_model: str = EMBEDDING_MODEL, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, ): ''' 对应api.py/knowledge_base/recreate_vector_store接口 @@ -738,8 +750,8 @@ class ApiRequest: # LLM模型相关操作 def list_running_models( - self, - controller_address: str = None, + self, + controller_address: str = None, ): ''' 获取Fastchat中正运行的模型列表 @@ -755,8 +767,7 @@ class ApiRequest: "/llm_model/list_running_models", json=data, ) - return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", [])) - + return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", [])) def 
get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]: ''' @@ -764,6 +775,7 @@ class ApiRequest: 当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。 返回类型为(model_name, is_local_model) ''' + def ret_sync(): running_models = self.list_running_models() if not running_models: @@ -780,7 +792,7 @@ class ApiRequest: model = m break - if not model: # LLM_MODELS中配置的模型都不在running_models里 + if not model: # LLM_MODELS中配置的模型都不在running_models里 model = list(running_models)[0] is_local = not running_models[model].get("online_api") return model, is_local @@ -801,7 +813,7 @@ class ApiRequest: model = m break - if not model: # LLM_MODELS中配置的模型都不在running_models里 + if not model: # LLM_MODELS中配置的模型都不在running_models里 model = list(running_models)[0] is_local = not running_models[model].get("online_api") return model, is_local @@ -812,8 +824,8 @@ class ApiRequest: return ret_sync() def list_config_models( - self, - types: List[str] = ["local", "online"], + self, + types: List[str] = ["local", "online"], ) -> Dict[str, Dict]: ''' 获取服务器configs中配置的模型列表,返回形式为{"type": {model_name: config}, ...}。 @@ -825,23 +837,23 @@ class ApiRequest: "/llm_model/list_config_models", json=data, ) - return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) + return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {})) def get_model_config( - self, - model_name: str = None, + self, + model_name: str = None, ) -> Dict: ''' 获取服务器上模型配置 ''' - data={ + data = { "model_name": model_name, } response = self.post( "/llm_model/get_model_config", json=data, ) - return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) + return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {})) def list_search_engines(self) -> List[str]: ''' @@ -850,12 +862,12 @@ class ApiRequest: response = self.post( "/server/list_search_engines", ) - return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) + return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {})) def stop_llm_model( - self, - model_name: str, - controller_address: str = None, + self, + model_name: str, + controller_address: str = None, ): ''' 停止某个LLM模型。 @@ -873,10 +885,10 @@ class ApiRequest: return self._get_response_value(response, as_json=True) def change_llm_model( - self, - model_name: str, - new_model_name: str, - controller_address: str = None, + self, + model_name: str, + new_model_name: str, + controller_address: str = None, ): ''' 向fastchat controller请求切换LLM模型。 @@ -959,10 +971,10 @@ class ApiRequest: return ret_sync() def embed_texts( - self, - texts: List[str], - embed_model: str = EMBEDDING_MODEL, - to_query: bool = False, + self, + texts: List[str], + embed_model: str = EMBEDDING_MODEL, + to_query: bool = False, ) -> List[List[float]]: ''' 对文本进行向量化,可选模型包括本地 embed_models 和支持 embeddings 的在线模型 @@ -979,10 +991,10 @@ class ApiRequest: return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data")) def chat_feedback( - self, - message_id: str, - score: int, - reason: str = "", + self, + message_id: str, + score: int, + reason: str = "", ) -> int: ''' 反馈对话评价 @@ -1019,9 +1031,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str: return error message if error occured when requests API ''' if (isinstance(data, dict) - and key in data - and "code" in data - and data["code"] == 200): + and key in data + and "code" in 
data + and data["code"] == 200): return data[key] return ""