更新0.2.x Agent,之后的Agent在0.3.x更新

This commit is contained in:
zR 2024-01-12 12:01:22 +08:00
parent 75ff268e88
commit 269090ea66
22 changed files with 359 additions and 816 deletions

View File

@ -7,31 +7,35 @@
保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳
'''
import sys
sys.path.append("..")
import os
import torch
from datetime import datetime
from configs import (
MODEL_PATH,
EMBEDDING_MODEL,
EMBEDDING_KEYWORD_FILE,
)
import os
import torch
from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated
@deprecated(
since="0.3.0",
message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words):
tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
# No need to manually convert to tensor as we've set return_tensors="pt"
input_ids = tokenizer_output['input_ids']
# Remove the first and last token for each sequence in the batch
input_ids = input_ids[:, 1:-1]
keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
keyword_embedding = torch.mean(keyword_embedding, 1)
return keyword_embedding
@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out
bert_model = word_embedding_model.auto_model
tokenizer = word_embedding_model.tokenizer
key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
# key_words_embedding = st_model.encode(key_words)
embedding_weight = bert_model.embeddings.word_embeddings.weight
embedding_weight_len = len(embedding_weight)
tokenizer.add_tokens(key_words)
bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
# key_words_embedding_tensor = torch.from_numpy(key_words_embedding)
embedding_weight = bert_model.embeddings.word_embeddings.weight
with torch.no_grad():
embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
output_model_path = os.path.join(model_parent_directory, output_model_name)
add_keyword_to_model(model_name, keyword_file, output_model_path)
if __name__ == '__main__':
add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)
# input_model_name = ""
# output_model_path = ""
# # 以下为加入关键字前后tokenizer的测试用例对比
# def print_token_ids(output, tokenizer, sentences):
# for idx, ids in enumerate(output['input_ids']):
# print(f'sentence={sentences[idx]}')
# print(f'ids={ids}')
# for id in ids:
# decoded_id = tokenizer.decode(id)
# print(f' {decoded_id}->{id}')
#
# sentences = [
# '数据科学与大数据技术',
# 'Langchain-Chatchat'
# ]
#
# st_no_keywords = SentenceTransformer(input_model_name)
# tokenizer_without_keywords = st_no_keywords.tokenizer
# print("===== tokenizer with no keywords added =====")
# output = tokenizer_without_keywords(sentences)
# print_token_ids(output, tokenizer_without_keywords, sentences)
# print(f'-------- embedding with no keywords added -----')
# embeddings = st_no_keywords.encode(sentences)
# print(embeddings)
#
# print("--------------------------------------------")
# print("--------------------------------------------")
# print("--------------------------------------------")
#
# st_with_keywords = SentenceTransformer(output_model_path)
# tokenizer_with_keywords = st_with_keywords.tokenizer
# print("===== tokenizer with keyword added =====")
# output = tokenizer_with_keywords(sentences)
# print_token_ids(output, tokenizer_with_keywords, sentences)
#
# print(f'-------- embedding with keywords added -----')
# embeddings = st_with_keywords.encode(sentences)
# print(embeddings)

View File

@ -38,13 +38,11 @@ numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
llama-index
vllm==0.2.7; sys_platform == "linux"
# optional document loaders
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
#rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
@ -69,9 +67,11 @@ metaphor-python~=0.1.23
# WebUI requirements
streamlit~=1.29.0
streamlit-option-menu>=0.3.6
streamlit==1.30.0
streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
watchdog>=3.0.0
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0

View File

@ -36,8 +36,8 @@ numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
llama-index
# optional document loaders

View File

@ -27,7 +27,7 @@ numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux"
vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
requests
pathlib
@ -54,11 +54,11 @@ metaphor-python~=0.1.23
# WebUI requirements
streamlit>=1.29.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.3.0
streamlit-chatbox>=1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0
streamlit==1.30.0
streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0

View File

@ -1,10 +1,10 @@
# WebUI requirements
streamlit>=1.29.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.3.0
streamlit-chatbox>=1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0
streamlit==1.30.0
streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0

View File

@ -1,22 +1,19 @@
"""
This file is a modified version for ChatGLM3-6B the original ChatGLM3Agent.py file from the langchain repo.
This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
"""
from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate, MessagesPlaceholder,
)
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
text = text[:first_index]
if "tool_call" in text:
tool_name_end = text.find("```")
tool_name = text[:tool_name_end].strip()
input_para = text.split("='")[-1].split("'")[0]
action_end = text.find("```")
action = text[:action_end].strip()
params_str_start = text.find("(") + 1
params_str_end = text.rfind(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
action_json = {
"action": tool_name,
"action_input": input_para
"action": action,
"action_input": params
}
else:
action_json = {
@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent):
else:
return agent_scratchpad
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod
def _get_default_output_parser(
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent):
@property
def _stop(self) -> List[str]:
return ["```<observation>"]
return ["<|observation|>"]
@classmethod
def create_prompt(
@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent):
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
def tool_config_from_file(tool_name, directory="server/agent/tools/"):
"""search tool yaml and return simplified json format"""
file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
try:
with open(file_path, 'r', encoding='utf-8') as file:
tool_config = yaml.safe_load(file)
# Simplify the structure if needed
simplified_config = {
"name": tool_config.get("name", ""),
"description": tool_config.get("description", ""),
"parameters": tool_config.get("parameters", {})
}
return simplified_config
except FileNotFoundError:
logger.error(f"File not found: {file_path}")
return None
except Exception as e:
logger.error(f"An error occurred while reading {file_path}: {e}")
return None
tools_json = []
tool_names = []
for tool in tools:
tool_config = tool_config_from_file(tool.name)
if tool_config:
tools_json.append(tool_config)
tool_names.append(tool.name)
# Format the tools for output
tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
simplified_config_langchain = {
"name": tool.name,
"description": tool.description,
"parameters": tool_schema.get("properties", {})
}
tools_json.append(simplified_config_langchain)
tool_names.append(tool.name)
formatted_tools = "\n".join([
f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
for tool in tools_json
])
formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
template = prompt.format(tool_names=tool_names,
tools=formatted_tools,
history="{history}",
history="None",
input="{input}",
agent_scratchpad="{agent_scratchpad}")
@ -225,7 +205,6 @@ def initialize_glm3_agent(
tools: Sequence[BaseTool],
llm: BaseLanguageModel,
prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None,
memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None,
*,
@ -238,14 +217,12 @@ def initialize_glm3_agent(
llm=llm,
tools=tools,
prompt=prompt,
callback_manager=callback_manager, **agent_kwargs
**agent_kwargs
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callback_manager=callback_manager,
memory=memory,
tags=tags_,
**kwargs,
)
)

View File

@ -1,5 +1,3 @@
## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
class ModelContainer:
def __init__(self):
self.MODEL = None

View File

@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple
from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
from .calculate import calculate, CalculatorInput
from .weather_check import weathercheck, WhetherSchema
from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput

View File

@ -1,10 +0,0 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query

View File

@ -1,10 +0,0 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -1,10 +0,0 @@
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
required:
- query

View File

@ -1,10 +0,0 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query

View File

@ -1,10 +0,0 @@
name: search_youtube
description: Use this tools to search youtube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query

View File

@ -1,10 +0,0 @@
name: shell
description: Use Linux Shell to execute Linux commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query

View File

@ -1,338 +1,25 @@
from __future__ import annotations
## 单独运行的时候需要添加
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
import requests
from typing import List, Any, Optional
from datetime import datetime
from langchain.prompts import PromptTemplate
from server.agent import model_container
from pydantic import BaseModel, Field
## 使用和风天气API查询天气
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
# key长这样这里提供了示例的key这个key没法使用你需要自己去注册和风天气的账号然后在这里填入你的key
_PROMPT_TEMPLATE = """
用户会提出一个关于天气的问题你的目标是拆分出用户问题中的区 并按照我提供的工具回答
例如 用户提出的问题是: 上海浦东未来1小时天气情况
提取的市和区是: 上海 浦东
如果用户提出的问题是: 上海未来1小时天气情况
提取的市和区是: 上海 None
请注意以下内容:
1. 如果你没有找到区的内容,则一定要使用 None 替代否则程序无法运行
2. 如果用户没有指定市 则直接返回缺少信息
问题: ${{用户的问题}}
你的回答格式应该按照下面的内容请注意格式内的```text 等标记都必须输出这是我用来提取答案的标记
```text
${{拆分的市和区中间用空格隔开}}
```
... weathercheck( )...
```output
${{提取后的答案}}
```
答案: ${{答案}}
这是一个例子
问题: 上海浦东未来1小时天气情况
```text
上海 浦东
```
...weathercheck(上海 浦东)...
```output
预报时间: 1小时后
具体时间: 今天 18:00
温度: 24°C
天气: 多云
风向: 西南风
风速: 7
湿度: 88%
降水概率: 16%
Answer: 上海浦东一小时后的天气是多云
现在这是我的问题
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
更简单的单参数输入工具实现用于查询现在天气的情况
"""
from pydantic import BaseModel, Field
import requests
def weather(location: str, api_key: str):
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
def get_city_info(location, adm, key):
base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
params = {'location': location, 'adm': adm, 'key': key}
response = requests.get(base_url, params=params)
data = response.json()
return data
def format_weather_data(data, place):
hourly_forecast = data['hourly']
formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
for forecast in hourly_forecast:
# 将预报时间转换为datetime对象
forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
# 获取预报时间的时区
forecast_tz = forecast_time.tzinfo
# 获取当前时间(使用预报时间的时区)
now = datetime.now(forecast_tz)
# 计算预报日期与当前日期的差值
days_diff = (forecast_time.date() - now.date()).days
if days_diff == 0:
forecast_date_str = '今天'
elif days_diff == 1:
forecast_date_str = '明天'
elif days_diff == 2:
forecast_date_str = '后天'
else:
forecast_date_str = str(days_diff) + '天后'
forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
# 计算预报时间与当前时间的差值
time_diff = forecast_time - now
# 将差值转换为小时
hours_diff = time_diff.total_seconds() // 3600
if hours_diff < 1:
hours_diff_str = '1小时后'
elif hours_diff >= 24:
# 如果超过24小时转换为天数
days_diff = hours_diff // 24
hours_diff_str = str(int(days_diff)) + ''
else:
hours_diff_str = str(int(hours_diff)) + '小时'
# 将预报时间和当前时间的差值添加到输出中
formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
formatted_data += '温度: ' + forecast['temp'] + '°C\n'
formatted_data += '天气: ' + forecast['text'] + '\n'
formatted_data += '风向: ' + forecast['windDir'] + '\n'
formatted_data += '风速: ' + forecast['windSpeed'] + '\n'
formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
# formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
formatted_data += '\n'
return formatted_data
def get_weather(key, location_id, place):
url = "https://devapi.qweather.com/v7/weather/24h?"
params = {
'location': location_id,
'key': key,
}
response = requests.get(url, params=params)
data = response.json()
return format_weather_data(data, place)
def split_query(query):
parts = query.split()
adm = parts[0]
if len(parts) == 1:
return adm, adm
location = parts[1] if parts[1] != 'None' else adm
return location, adm
def weather(query):
location, adm = split_query(query)
key = KEY
if key == "":
return "请先在代码中填入和风天气API Key"
try:
city_info = get_city_info(location=location, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + "" + location + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "以上是查询到的天气信息,请你查收\n"
except KeyError:
try:
city_info = get_city_info(location=adm, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
except KeyError:
return "输入的地区不存在,无法提供天气预报"
class LLMWeatherChain(Chain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""[Deprecated] Prompt to use to translate to python if necessary."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMWeatherChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, expression: str) -> str:
try:
output = weather(expression)
except Exception as e:
output = "输入的信息有误,请再次尝试"
return output
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_weather_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMWeatherChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
def weathercheck(query: str):
model = model_container.MODEL
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_weather.run(query)
return ans
class WhetherSchema(BaseModel):
location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
if __name__ == '__main__':
result = weathercheck("苏州姑苏区今晚热不热?")
def weathercheck(location: str):
return weather(location, "S8vrB4U_-c5mvAMiK")
class WeatherInput(BaseModel):
location: str = Field(description="City name,include city and county,like '厦门'")

View File

@ -1,10 +0,0 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name,include city and county,like "厦门市思明区"
required:
- query

View File

@ -1,10 +0,0 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -1,8 +1,6 @@
from langchain.tools import Tool
from server.agent.tools import *
## 请注意如果你是为了使用AgentLM在这里你应该使用英文版本。
tools = [
Tool.from_function(
func=calculate,
@ -20,7 +18,7 @@ tools = [
func=weathercheck,
name="weather_check",
description="",
args_schema=WhetherSchema,
args_schema=WeatherInput,
),
Tool.from_function(
func=shell,

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

View File

@ -1,23 +1,23 @@
from langchain.memory import ConversationBufferWindowMemory
import json
import asyncio
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import AsyncIterable, Optional, List
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
@ -33,7 +33,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):
history = [History.from_data(h) for h in history]
@ -55,12 +54,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
callbacks=[callback],
)
## 传入全局变量来实现agent调用
kb_list = {x["kb_name"]: x for x in get_kb_details()}
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
if Agent_MODEL:
## 如果有指定使用Agent模型来完成任务
model_agent = get_ChatOpenAI(
model_name=Agent_MODEL,
temperature=temperature,
@ -79,15 +76,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
)
output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
# 把history转成agent的memory
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
# 检查消息的角色
if message.role == 'user':
# 添加用户消息
memory.chat_memory.add_user_message(message.content)
else:
# 添加AI消息
memory.chat_memory.add_ai_message(message.content)
if "chatglm3" in model_container.MODEL.model_name:
@ -95,7 +88,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
llm=model,
tools=tools,
callback_manager=None,
# Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
prompt=prompt_template,
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
answer = ""
final_answer = ""
async for chunk in callback.aiter():
# Use server-sent-events to stream the response
data = json.loads(chunk)
if data["status"] == Status.start or data["status"] == Status.complete:
continue
@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
await task
return EventSourceResponse(agent_chat_iterator(query=query,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)
history=history,
model_name=model_name,
prompt_name=prompt_name),
)

View File

@ -1,23 +1,23 @@
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from sse_starlette import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from sse_starlette import EventSourceResponse
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from server.chat.utils import History
from typing import AsyncIterable
import asyncio
import json
from typing import List, Optional, Dict
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
def metaphor_search(
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
from metaphor_python import Metaphor
@ -58,13 +58,13 @@ def metaphor_search(
# metaphor 返回的内容都是长文本,需要分词再检索
if split_result:
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs)
# 将切分好的文档放入临时向量库重新筛选出TOP_K个文档
if len(splitted_docs) > result_len:
normal = NormalizedLevenshtein()
@ -74,13 +74,13 @@ def metaphor_search(
splitted_docs = splitted_docs[:result_len]
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
else:
docs = [{"snippet": x.extract,
"link": x.url,
"title": x.title}
"link": x.url,
"title": x.title}
for x in contents]
return docs
@ -113,25 +113,27 @@ async def lookup_search_engine(
docs = search_result2docs(results)
return docs
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None,
description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False,
description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
await task
return EventSourceResponse(search_engine_chat_iterator(query=query,
search_engine_name=search_engine_name,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)
search_engine_name=search_engine_name,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)

View File

@ -1,7 +1,6 @@
# 该文件封装了对api.py的请求可以被不同的webui使用
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用
from typing import *
from pathlib import Path
# 此处导入的配置为发起请求如WEBUI机器上的配置主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
@ -27,7 +26,7 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
from langchain_core._api import deprecated
set_httpx_config()
@ -36,10 +35,11 @@ class ApiRequest:
'''
api.py调用的封装同步模式,简化api调用方式
'''
def __init__(
self,
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
self,
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
):
self.base_url = base_url
self.timeout = timeout
@ -55,12 +55,12 @@ class ApiRequest:
return self._client
def get(
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@ -75,13 +75,13 @@ class ApiRequest:
retry -= 1
def post(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@ -97,13 +97,13 @@ class ApiRequest:
retry -= 1
def delete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@ -118,24 +118,25 @@ class ApiRequest:
retry -= 1
def _httpx_stream2generator(
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
):
'''
将httpx.stream返回的GeneratorContextManager转化为普通生成器
'''
async def ret_async(response, as_json):
try:
async with response as r:
async for chunk in r.aiter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:-2])
elif chunk.startswith(":"): # skip sse comment line
elif chunk.startswith(":"): # skip sse comment line
continue
else:
data = json.loads(chunk)
@ -143,7 +144,7 @@ class ApiRequest:
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
@ -158,20 +159,20 @@ class ApiRequest:
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
def ret_sync(response, as_json):
try:
with response as r:
for chunk in r.iter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:-2])
elif chunk.startswith(":"): # skip sse comment line
elif chunk.startswith(":"): # skip sse comment line
continue
else:
data = json.loads(chunk)
@ -179,7 +180,7 @@ class ApiRequest:
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
@ -194,7 +195,7 @@ class ApiRequest:
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
if self._use_async:
@ -203,16 +204,17 @@ class ApiRequest:
return ret_sync(response, as_json)
def _get_response_value(
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
):
'''
转换同步或异步请求返回的响应
`as_json`: 返回json
`value_func`: 用户可以自定义返回值该函数接受response或json
'''
def to_json(r):
try:
return r.json()
@ -220,7 +222,7 @@ class ApiRequest:
msg = "API未能返回正确的JSON。" + str(e)
if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg, "data": None}
if value_func is None:
@ -250,10 +252,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
def get_prompt_template(
self,
type: str = "llm_chat",
name: str = "default",
**kwargs,
self,
type: str = "llm_chat",
name: str = "default",
**kwargs,
) -> str:
data = {
"type": type,
@ -297,15 +299,19 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0")
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/agent_chat 接口
@ -327,17 +333,17 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True)
def knowledge_base_chat(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/knowledge_base_chat接口
@ -366,28 +372,29 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True)
def upload_temp_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_id: str = None,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
self,
files: List[Union[str, Path, bytes]],
knowledge_id: str = None,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
对应api.py/knowledge_base/upload_tmep_docs接口
'''
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
data={
data = {
"knowledge_id": knowledge_id,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
@ -402,17 +409,17 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def file_chat(
self,
query: str,
knowledge_id: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
self,
query: str,
knowledge_id: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/file_chat接口
@ -440,18 +447,23 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
):
'''
对应api.py/chat/search_engine_chat接口
@ -482,7 +494,7 @@ class ApiRequest:
# 知识库相关操作
def list_knowledge_bases(
self,
self,
):
'''
对应api.py/knowledge_base/list_knowledge_bases接口
@ -493,10 +505,10 @@ class ApiRequest:
value_func=lambda r: r.get("data", []))
def create_knowledge_base(
self,
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
self,
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
):
'''
对应api.py/knowledge_base/create_knowledge_base接口
@ -514,8 +526,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def delete_knowledge_base(
self,
knowledge_base_name: str,
self,
knowledge_base_name: str,
):
'''
对应api.py/knowledge_base/delete_knowledge_base接口
@ -527,8 +539,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def list_kb_docs(
self,
knowledge_base_name: str,
self,
knowledge_base_name: str,
):
'''
对应api.py/knowledge_base/list_files接口
@ -542,13 +554,13 @@ class ApiRequest:
value_func=lambda r: r.get("data", []))
def search_kb_docs(
self,
knowledge_base_name: str,
query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
file_name: str = "",
metadata: dict = {},
self,
knowledge_base_name: str,
query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
file_name: str = "",
metadata: dict = {},
) -> List:
'''
对应api.py/knowledge_base/search_docs接口
@ -569,9 +581,9 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def update_docs_by_id(
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
) -> bool:
'''
对应api.py/knowledge_base/update_docs_by_id接口
@ -587,32 +599,33 @@ class ApiRequest:
return self._get_response_value(response)
def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
对应api.py/knowledge_base/upload_docs接口
'''
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
data={
data = {
"knowledge_base_name": knowledge_base_name,
"override": override,
"to_vector_store": to_vector_store,
@ -633,11 +646,11 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def delete_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
self,
knowledge_base_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
):
'''
对应api.py/knowledge_base/delete_docs接口
@ -655,8 +668,7 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True)
def update_kb_info(self,knowledge_base_name,kb_info):
def update_kb_info(self, knowledge_base_name, kb_info):
'''
对应api.py/knowledge_base/update_info接口
'''
@ -672,15 +684,15 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def update_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
self,
knowledge_base_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
对应api.py/knowledge_base/update_docs接口
@ -706,14 +718,14 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def recreate_vector_store(
self,
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
self,
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
对应api.py/knowledge_base/recreate_vector_store接口
@ -738,8 +750,8 @@ class ApiRequest:
# LLM模型相关操作
def list_running_models(
self,
controller_address: str = None,
self,
controller_address: str = None,
):
'''
获取Fastchat中正运行的模型列表
@ -755,8 +767,7 @@ class ApiRequest:
"/llm_model/list_running_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
@ -764,6 +775,7 @@ class ApiRequest:
local_first=True 优先返回运行中的本地模型否则优先按LLM_MODELS配置顺序返回
返回类型为model_name, is_local_model
'''
def ret_sync():
running_models = self.list_running_models()
if not running_models:
@ -780,7 +792,7 @@ class ApiRequest:
model = m
break
if not model: # LLM_MODELS中配置的模型都不在running_models里
if not model: # LLM_MODELS中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@ -801,7 +813,7 @@ class ApiRequest:
model = m
break
if not model: # LLM_MODELS中配置的模型都不在running_models里
if not model: # LLM_MODELS中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@ -812,8 +824,8 @@ class ApiRequest:
return ret_sync()
def list_config_models(
self,
types: List[str] = ["local", "online"],
self,
types: List[str] = ["local", "online"],
) -> Dict[str, Dict]:
'''
获取服务器configs中配置的模型列表返回形式为{"type": {model_name: config}, ...}
@ -825,23 +837,23 @@ class ApiRequest:
"/llm_model/list_config_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def get_model_config(
self,
model_name: str = None,
self,
model_name: str = None,
) -> Dict:
'''
获取服务器上模型配置
'''
data={
data = {
"model_name": model_name,
}
response = self.post(
"/llm_model/get_model_config",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def list_search_engines(self) -> List[str]:
'''
@ -850,12 +862,12 @@ class ApiRequest:
response = self.post(
"/server/list_search_engines",
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型
@ -873,10 +885,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型
@ -959,10 +971,10 @@ class ApiRequest:
return ret_sync()
def embed_texts(
self,
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
self,
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> List[List[float]]:
'''
对文本进行向量化可选模型包括本地 embed_models 和支持 embeddings 的在线模型
@ -979,10 +991,10 @@ class ApiRequest:
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
def chat_feedback(
self,
message_id: str,
score: int,
reason: str = "",
self,
message_id: str,
score: int,
reason: str = "",
) -> int:
'''
反馈对话评价
@ -1019,9 +1031,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
return error message if error occured when requests API
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""