commit 269090ea66 (parent 75ff268e88)
mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-19 13:23:16 +08:00

Update the 0.2.x Agent; subsequent Agent changes will land in 0.3.x.
@@ -7,31 +7,35 @@
 保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳
 '''
 import sys

 sys.path.append("..")
+import os
+import torch

 from datetime import datetime
 from configs import (
     MODEL_PATH,
     EMBEDDING_MODEL,
     EMBEDDING_KEYWORD_FILE,
 )
-import os
-import torch
 from safetensors.torch import save_model
 from sentence_transformers import SentenceTransformer
+from langchain_core._api import deprecated


+@deprecated(
+        since="0.3.0",
+        message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
+        removal="0.3.0"
+)
 def get_keyword_embedding(bert_model, tokenizer, key_words):
     tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)

-    # No need to manually convert to tensor as we've set return_tensors="pt"
     input_ids = tokenizer_output['input_ids']

-    # Remove the first and last token for each sequence in the batch
     input_ids = input_ids[:, 1:-1]

     keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
     keyword_embedding = torch.mean(keyword_embedding, 1)

     return keyword_embedding


@@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out
     bert_model = word_embedding_model.auto_model
     tokenizer = word_embedding_model.tokenizer
     key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
-    # key_words_embedding = st_model.encode(key_words)

     embedding_weight = bert_model.embeddings.word_embeddings.weight
     embedding_weight_len = len(embedding_weight)
     tokenizer.add_tokens(key_words)
     bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)

-    # key_words_embedding_tensor = torch.from_numpy(key_words_embedding)
     embedding_weight = bert_model.embeddings.word_embeddings.weight
     with torch.no_grad():
         embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
@@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
     output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
     output_model_path = os.path.join(model_parent_directory, output_model_name)
     add_keyword_to_model(model_name, keyword_file, output_model_path)
-
-
-if __name__ == '__main__':
-    add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)
-
-    # input_model_name = ""
-    # output_model_path = ""
-    # # 以下为加入关键字前后tokenizer的测试用例对比
-    # def print_token_ids(output, tokenizer, sentences):
-    #     for idx, ids in enumerate(output['input_ids']):
-    #         print(f'sentence={sentences[idx]}')
-    #         print(f'ids={ids}')
-    #         for id in ids:
-    #             decoded_id = tokenizer.decode(id)
-    #             print(f'    {decoded_id}->{id}')
-    #
-    # sentences = [
-    #     '数据科学与大数据技术',
-    #     'Langchain-Chatchat'
-    # ]
-    #
-    # st_no_keywords = SentenceTransformer(input_model_name)
-    # tokenizer_without_keywords = st_no_keywords.tokenizer
-    # print("===== tokenizer with no keywords added =====")
-    # output = tokenizer_without_keywords(sentences)
-    # print_token_ids(output, tokenizer_without_keywords, sentences)
-    # print(f'-------- embedding with no keywords added -----')
-    # embeddings = st_no_keywords.encode(sentences)
-    # print(embeddings)
-    #
-    # print("--------------------------------------------")
-    # print("--------------------------------------------")
-    # print("--------------------------------------------")
-    #
-    # st_with_keywords = SentenceTransformer(output_model_path)
-    # tokenizer_with_keywords = st_with_keywords.tokenizer
-    # print("===== tokenizer with keyword added =====")
-    # output = tokenizer_with_keywords(sentences)
-    # print_token_ids(output, tokenizer_with_keywords, sentences)
-    #
-    # print(f'-------- embedding with keywords added -----')
-    # embeddings = st_with_keywords.encode(sentences)
-    # print(embeddings)
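For context on the hunks above: the script extends the embedding model's vocabulary by adding each keyword as a single token whose vector is initialized to the mean of its sub-token embeddings. A minimal standalone sketch of that technique, with a placeholder model name and keyword list (not values from this commit):

    import torch
    from sentence_transformers import SentenceTransformer

    keywords = ["Langchain-Chatchat"]                    # hypothetical keyword list
    st_model = SentenceTransformer("BAAI/bge-large-zh")  # placeholder model name
    bert_model = st_model[0].auto_model                  # underlying HF transformer
    tokenizer = st_model[0].tokenizer

    # Mean-pool each keyword's sub-token embeddings, dropping the [CLS]/[SEP] positions.
    ids = tokenizer(keywords, return_tensors="pt", padding=True, truncation=True)["input_ids"][:, 1:-1]
    new_vectors = bert_model.embeddings.word_embeddings(ids).mean(dim=1)

    # Grow the vocabulary, then write the pooled vectors into the freshly added rows.
    old_len = bert_model.embeddings.word_embeddings.weight.shape[0]
    tokenizer.add_tokens(keywords)
    bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
    with torch.no_grad():
        bert_model.embeddings.word_embeddings.weight[old_len:old_len + len(keywords)] = new_vectors

The weight tensor is re-read after resize_token_embeddings, as the diff also does, because resizing swaps in a new embedding matrix and any earlier reference would point at the old tensor.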
@@ -38,13 +38,11 @@ numpy~=1.24.4
 pandas~=2.0.3
 einops>=0.7.0
 transformers_stream_generator==0.0.4
-vllm==0.2.6; sys_platform == "linux"
-httpx[brotli,http2,socks]==0.25.2
-llama-index
+vllm==0.2.7; sys_platform == "linux"

 # optional document loaders

-# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
+#rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
 jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
 beautifulsoup4~=4.12.2 # for .mhtml files
 pysrt~=1.1.2
@@ -69,9 +67,11 @@ metaphor-python~=0.1.23

 # WebUI requirements

-streamlit~=1.29.0
-streamlit-option-menu>=0.3.6
+streamlit==1.30.0
+streamlit-option-menu==0.3.6
+streamlit-antd-components==0.3.1
 streamlit-chatbox==1.1.11
-streamlit-modal>=0.1.0
-streamlit-aggrid>=0.3.4.post3
-watchdog>=3.0.0
+streamlit-modal==0.1.0
+streamlit-aggrid==0.3.4.post3
+httpx==0.26.0
+watchdog==3.0.0
@@ -36,8 +36,8 @@ numpy~=1.24.4
 pandas~=2.0.3
 einops>=0.7.0
 transformers_stream_generator==0.0.4
-vllm==0.2.6; sys_platform == "linux"
-httpx[brotli,http2,socks]==0.25.2
+vllm==0.2.7; sys_platform == "linux"
+httpx==0.26.0
 llama-index

 # optional document loaders
@@ -27,7 +27,7 @@ numpy~=1.24.4
 pandas~=2.0.3
 einops>=0.7.0
 transformers_stream_generator==0.0.4
-vllm==0.2.6; sys_platform == "linux"
+vllm==0.2.7; sys_platform == "linux"
 httpx[brotli,http2,socks]==0.25.2
 requests
 pathlib
@@ -54,11 +54,11 @@ metaphor-python~=0.1.23

 # WebUI requirements

-streamlit>=1.29.0
-streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.3.0
-streamlit-chatbox>=1.1.11
-streamlit-modal>=0.1.0
-streamlit-aggrid>=0.3.4.post3
-httpx[brotli,http2,socks]>=0.25.2
-watchdog>=3.0.0
+streamlit==1.30.0
+streamlit-option-menu==0.3.6
+streamlit-antd-components==0.3.1
+streamlit-chatbox==1.1.11
+streamlit-modal==0.1.0
+streamlit-aggrid==0.3.4.post3
+httpx==0.26.0
+watchdog==3.0.0
@@ -1,10 +1,10 @@
 # WebUI requirements

-streamlit>=1.29.0
-streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.3.0
-streamlit-chatbox>=1.1.11
-streamlit-modal>=0.1.0
-streamlit-aggrid>=0.3.4.post3
-httpx[brotli,http2,socks]>=0.25.2
-watchdog>=3.0.0
+streamlit==1.30.0
+streamlit-option-menu==0.3.6
+streamlit-antd-components==0.3.1
+streamlit-chatbox==1.1.11
+streamlit-modal==0.1.0
+streamlit-aggrid==0.3.4.post3
+httpx==0.26.0
+watchdog==3.0.0
@@ -1,22 +1,19 @@
 """
-This file is a modified version for ChatGLM3-6B the original ChatGLM3Agent.py file from the langchain repo.
+This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
 """
 from __future__ import annotations

-import yaml
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from typing import Any, List, Sequence, Tuple, Optional, Union
-import os
-from langchain.agents.agent import Agent
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate, MessagesPlaceholder,
-)
 import json
 import logging
+from typing import Any, List, Sequence, Tuple, Optional, Union
+from pydantic.schema import model_schema


+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.agents.agent import Agent
+from langchain.chains.llm import LLMChain
+from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
 from langchain.agents.agent import AgentOutputParser
 from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
@@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
         first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
         text = text[:first_index]
         if "tool_call" in text:
-            tool_name_end = text.find("```")
-            tool_name = text[:tool_name_end].strip()
-            input_para = text.split("='")[-1].split("'")[0]
+            action_end = text.find("```")
+            action = text[:action_end].strip()
+            params_str_start = text.find("(") + 1
+            params_str_end = text.rfind(")")
+            params_str = text[params_str_start:params_str_end]
+
+            params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
+            params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
+
             action_json = {
-                "action": tool_name,
-                "action_input": input_para
+                "action": action,
+                "action_input": params
             }
         else:
             action_json = {
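To follow the parser change above: the old code recovered a single value from a ='...' fragment, while the new code slices everything between the outermost parentheses and builds a keyword-argument dict, so multi-parameter tool calls parse too. A standalone sketch of the same logic (the sample text is hypothetical model output, not taken from this commit):

    # Assumes ChatGLM3 emitted a call such as: tool_call(location='上海')
    text = "weather_check\n```python\ntool_call(location='上海')\n```"
    params_str = text[text.find("(") + 1:text.rfind(")")]
    pairs = [p.split("=") for p in params_str.split(",") if "=" in p]
    params = {p[0].strip(): p[1].strip().strip("'\"") for p in pairs}
    # params -> {'location': '上海'}

Because it splits on "," and "=", argument values containing either character would still confuse it; this is a heuristic for ChatGLM3's simple tool-call syntax, not a general expression parser.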
@@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent):
         else:
             return agent_scratchpad

-    @classmethod
-    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
-        pass
-
     @classmethod
     def _get_default_output_parser(
             cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
@@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent):

     @property
     def _stop(self) -> List[str]:
-        return ["```<observation>"]
+        return ["<|observation|>"]

     @classmethod
     def create_prompt(
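The _stop correction above matters because ChatGLM3 emits its special <|observation|> token immediately after requesting a tool, so stopping generation on exactly that string hands control back to the executor at the right moment; the previous value, ```<observation>, is not a token the model actually produces. A hedged illustration with a generic langchain 0.0.x-era model object (the prompt text is a placeholder):

    # Decoding halts as soon as the model asks for a tool observation.
    llm_output = llm.predict("...rendered agent prompt...", stop=["<|observation|>"])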
@@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent):
             input_variables: Optional[List[str]] = None,
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
-        def tool_config_from_file(tool_name, directory="server/agent/tools/"):
-            """search tool yaml and return simplified json format"""
-            file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
-            try:
-                with open(file_path, 'r', encoding='utf-8') as file:
-                    tool_config = yaml.safe_load(file)
-                # Simplify the structure if needed
-                simplified_config = {
-                    "name": tool_config.get("name", ""),
-                    "description": tool_config.get("description", ""),
-                    "parameters": tool_config.get("parameters", {})
-                }
-                return simplified_config
-            except FileNotFoundError:
-                logger.error(f"File not found: {file_path}")
-                return None
-            except Exception as e:
-                logger.error(f"An error occurred while reading {file_path}: {e}")
-                return None

         tools_json = []
         tool_names = []
         for tool in tools:
-            tool_config = tool_config_from_file(tool.name)
-            if tool_config:
-                tools_json.append(tool_config)
-                tool_names.append(tool.name)
-
-            # Format the tools for output
+            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
+            simplified_config_langchain = {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool_schema.get("properties", {})
+            }
+            tools_json.append(simplified_config_langchain)
+            tool_names.append(tool.name)
         formatted_tools = "\n".join([
             f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
             for tool in tools_json
         ])
         formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")

         template = prompt.format(tool_names=tool_names,
                                  tools=formatted_tools,
-                                 history="{history}",
+                                 history="None",
                                  input="{input}",
                                  agent_scratchpad="{agent_scratchpad}")

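The rewritten create_prompt above derives each tool's parameter description from its pydantic args_schema via model_schema, which is why this commit can delete the per-tool YAML files that the old tool_config_from_file helper read from disk (those deletions appear further down). A sketch of what that call returns for a one-field schema, using the pydantic v1 API as imported in the hunk (output abbreviated):

    from pydantic import BaseModel, Field
    from pydantic.schema import model_schema

    class WeatherInput(BaseModel):
        location: str = Field(description="City name,include city and county,like '厦门'")

    print(model_schema(WeatherInput).get("properties", {}))
    # {'location': {'title': 'Location', 'description': "City name,include city and county,like '厦门'", 'type': 'string'}}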
@@ -225,7 +205,6 @@ def initialize_glm3_agent(
         tools: Sequence[BaseTool],
         llm: BaseLanguageModel,
         prompt: str = None,
-        callback_manager: Optional[BaseCallbackManager] = None,
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
@@ -238,14 +217,12 @@ def initialize_glm3_agent(
         llm=llm,
         tools=tools,
         prompt=prompt,
-        callback_manager=callback_manager, **agent_kwargs
+        **agent_kwargs
     )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
         memory=memory,
         tags=tags_,
         **kwargs,
     )

@@ -1,5 +1,3 @@
-
-## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
 class ModelContainer:
     def __init__(self):
         self.MODEL = None
@@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple
 from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
 from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
 from .calculate import calculate, CalculatorInput
-from .weather_check import weathercheck, WhetherSchema
+from .weather_check import weathercheck, WeatherInput
 from .shell import shell, ShellInput
 from .search_internet import search_internet, SearchInternetInput
 from .wolfram import wolfram, WolframInput
@@ -1,10 +0,0 @@
-name: arxiv
-description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The search query title
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: calculate
-description: Useful for when you need to answer questions about simple calculations
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The formula to be calculated
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: search_internet
-description: Use this tool to surf internet and get information
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: Query for Internet search
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: search_knowledgebase_complex
-description: Use this tool to search local knowledgebase and get information
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The query to be searched
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: search_youtube
-description: Use this tools to search youtube videos
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: Query for Videos search
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: shell
-description: Use Linux Shell to execute Linux commands
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The command to execute
-  required:
-    - query
@@ -1,338 +1,25 @@
-from __future__ import annotations
-
-## 单独运行的时候需要添加
-import sys
-import os
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
-import re
-import warnings
-from typing import Dict
-
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForChainRun,
-    CallbackManagerForChainRun,
-)
-from langchain.chains.base import Chain
-from langchain.chains.llm import LLMChain
-from langchain.pydantic_v1 import Extra, root_validator
-from langchain.schema import BasePromptTemplate
-from langchain.schema.language_model import BaseLanguageModel
-import requests
-from typing import List, Any, Optional
-from datetime import datetime
-from langchain.prompts import PromptTemplate
-from server.agent import model_container
-from pydantic import BaseModel, Field
-
-## 使用和风天气API查询天气
-KEY = "ac880e5a877042809ac7ffdd19d95b0d"
-# key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
-
-
-_PROMPT_TEMPLATE = """
-用户会提出一个关于天气的问题,你的目标是拆分出用户问题中的区,市 并按照我提供的工具回答。
-例如 用户提出的问题是: 上海浦东未来1小时天气情况?
-则 提取的市和区是: 上海 浦东
-如果用户提出的问题是: 上海未来1小时天气情况?
-则 提取的市和区是: 上海 None
-请注意以下内容:
-1. 如果你没有找到区的内容,则一定要使用 None 替代,否则程序无法运行
-2. 如果用户没有指定市 则直接返回缺少信息
-
-问题: ${{用户的问题}}
-
-你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
-```text
-
-${{拆分的市和区,中间用空格隔开}}
-```
-... weathercheck(市 区)...
-```output
-
-${{提取后的答案}}
-```
-答案: ${{答案}}
-
-
-
-这是一个例子:
-问题: 上海浦东未来1小时天气情况?
-
-
-```text
-上海 浦东
-```
-...weathercheck(上海 浦东)...
-
-```output
-预报时间: 1小时后
-具体时间: 今天 18:00
-温度: 24°C
-天气: 多云
-风向: 西南风
-风速: 7级
-湿度: 88%
-降水概率: 16%
-
-Answer: 上海浦东一小时后的天气是多云。
-
-现在,这是我的问题:
-
-问题: {question}
 """
-PROMPT = PromptTemplate(
-    input_variables=["question"],
-    template=_PROMPT_TEMPLATE,
-)
+更简单的单参数输入工具实现,用于查询现在天气的情况
+"""
+from pydantic import BaseModel, Field
+import requests

+def weather(location: str, api_key: str):
+    url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
+    response = requests.get(url)
+    if response.status_code == 200:
+        data = response.json()
+        weather = {
+            "temperature": data["results"][0]["now"]["temperature"],
+            "description": data["results"][0]["now"]["text"],
+        }
+        return weather
+    else:
+        raise Exception(
+            f"Failed to retrieve weather: {response.status_code}")


-def get_city_info(location, adm, key):
-    base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
-    params = {'location': location, 'adm': adm, 'key': key}
-    response = requests.get(base_url, params=params)
-    data = response.json()
-    return data
+def weathercheck(location: str):
+    return weather(location, "S8vrB4U_-c5mvAMiK")
+
+
+class WeatherInput(BaseModel):
+    location: str = Field(description="City name,include city and county,like '厦门'")
-
-
-def format_weather_data(data, place):
-    hourly_forecast = data['hourly']
-    formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
-    for forecast in hourly_forecast:
-        # 将预报时间转换为datetime对象
-        forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
-        # 获取预报时间的时区
-        forecast_tz = forecast_time.tzinfo
-        # 获取当前时间(使用预报时间的时区)
-        now = datetime.now(forecast_tz)
-        # 计算预报日期与当前日期的差值
-        days_diff = (forecast_time.date() - now.date()).days
-        if days_diff == 0:
-            forecast_date_str = '今天'
-        elif days_diff == 1:
-            forecast_date_str = '明天'
-        elif days_diff == 2:
-            forecast_date_str = '后天'
-        else:
-            forecast_date_str = str(days_diff) + '天后'
-        forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
-        # 计算预报时间与当前时间的差值
-        time_diff = forecast_time - now
-        # 将差值转换为小时
-        hours_diff = time_diff.total_seconds() // 3600
-        if hours_diff < 1:
-            hours_diff_str = '1小时后'
-        elif hours_diff >= 24:
-            # 如果超过24小时,转换为天数
-            days_diff = hours_diff // 24
-            hours_diff_str = str(int(days_diff)) + '天'
-        else:
-            hours_diff_str = str(int(hours_diff)) + '小时'
-        # 将预报时间和当前时间的差值添加到输出中
-        formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
-        formatted_data += '温度: ' + forecast['temp'] + '°C\n'
-        formatted_data += '天气: ' + forecast['text'] + '\n'
-        formatted_data += '风向: ' + forecast['windDir'] + '\n'
-        formatted_data += '风速: ' + forecast['windSpeed'] + '级\n'
-        formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
-        formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
-        # formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
-        formatted_data += '\n'
-    return formatted_data
-
-
-def get_weather(key, location_id, place):
-    url = "https://devapi.qweather.com/v7/weather/24h?"
-    params = {
-        'location': location_id,
-        'key': key,
-    }
-    response = requests.get(url, params=params)
-    data = response.json()
-    return format_weather_data(data, place)
-
-
-def split_query(query):
-    parts = query.split()
-    adm = parts[0]
-    if len(parts) == 1:
-        return adm, adm
-    location = parts[1] if parts[1] != 'None' else adm
-    return location, adm
-
-
-def weather(query):
-    location, adm = split_query(query)
-    key = KEY
-    if key == "":
-        return "请先在代码中填入和风天气API Key"
-    try:
-        city_info = get_city_info(location=location, adm=adm, key=key)
-        location_id = city_info['location'][0]['id']
-        place = adm + "市" + location + "区"
-
-        weather_data = get_weather(key=key, location_id=location_id, place=place)
-        return weather_data + "以上是查询到的天气信息,请你查收\n"
-    except KeyError:
-        try:
-            city_info = get_city_info(location=adm, adm=adm, key=key)
-            location_id = city_info['location'][0]['id']
-            place = adm + "市"
-            weather_data = get_weather(key=key, location_id=location_id, place=place)
-            return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
-        except KeyError:
-            return "输入的地区不存在,无法提供天气预报"
-
-
-class LLMWeatherChain(Chain):
-    llm_chain: LLMChain
-    llm: Optional[BaseLanguageModel] = None
-    """[Deprecated] LLM wrapper to use."""
-    prompt: BasePromptTemplate = PROMPT
-    """[Deprecated] Prompt to use to translate to python if necessary."""
-    input_key: str = "question"  #: :meta private:
-    output_key: str = "answer"  #: :meta private:
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    @root_validator(pre=True)
-    def raise_deprecation(cls, values: Dict) -> Dict:
-        if "llm" in values:
-            warnings.warn(
-                "Directly instantiating an LLMWeatherChain with an llm is deprecated. "
-                "Please instantiate with llm_chain argument or using the from_llm "
-                "class method."
-            )
-            if "llm_chain" not in values and values["llm"] is not None:
-                prompt = values.get("prompt", PROMPT)
-                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
-        return values
-
-    @property
-    def input_keys(self) -> List[str]:
-        """Expect input key.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Expect output key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
-    def _evaluate_expression(self, expression: str) -> str:
-        try:
-            output = weather(expression)
-        except Exception as e:
-            output = "输入的信息有误,请再次尝试"
-        return output
-
-    def _process_llm_result(
-            self, llm_output: str, run_manager: CallbackManagerForChainRun
-    ) -> Dict[str, str]:
-
-        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
-
-        llm_output = llm_output.strip()
-        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
-        if text_match:
-            expression = text_match.group(1)
-            output = self._evaluate_expression(expression)
-            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
-            run_manager.on_text(output, color="yellow", verbose=self.verbose)
-            answer = "Answer: " + output
-        elif llm_output.startswith("Answer:"):
-            answer = llm_output
-        elif "Answer:" in llm_output:
-            answer = "Answer: " + llm_output.split("Answer:")[-1]
-        else:
-            return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
-        return {self.output_key: answer}
-
-    async def _aprocess_llm_result(
-            self,
-            llm_output: str,
-            run_manager: AsyncCallbackManagerForChainRun,
-    ) -> Dict[str, str]:
-        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
-        llm_output = llm_output.strip()
-        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
-
-        if text_match:
-            expression = text_match.group(1)
-            output = self._evaluate_expression(expression)
-            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
-            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
-            answer = "Answer: " + output
-        elif llm_output.startswith("Answer:"):
-            answer = llm_output
-        elif "Answer:" in llm_output:
-            answer = "Answer: " + llm_output.split("Answer:")[-1]
-        else:
-            raise ValueError(f"unknown format from LLM: {llm_output}")
-        return {self.output_key: answer}
-
-    def _call(
-            self,
-            inputs: Dict[str, str],
-            run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
-        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        _run_manager.on_text(inputs[self.input_key])
-        llm_output = self.llm_chain.predict(
-            question=inputs[self.input_key],
-            stop=["```output"],
-            callbacks=_run_manager.get_child(),
-        )
-        return self._process_llm_result(llm_output, _run_manager)
-
-    async def _acall(
-            self,
-            inputs: Dict[str, str],
-            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
-        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
-        await _run_manager.on_text(inputs[self.input_key])
-        llm_output = await self.llm_chain.apredict(
-            question=inputs[self.input_key],
-            stop=["```output"],
-            callbacks=_run_manager.get_child(),
-        )
-        return await self._aprocess_llm_result(llm_output, _run_manager)
-
-    @property
-    def _chain_type(self) -> str:
-        return "llm_weather_chain"
-
-    @classmethod
-    def from_llm(
-            cls,
-            llm: BaseLanguageModel,
-            prompt: BasePromptTemplate = PROMPT,
-            **kwargs: Any,
-    ) -> LLMWeatherChain:
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
-        return cls(llm_chain=llm_chain, **kwargs)
-
-
-def weathercheck(query: str):
-    model = model_container.MODEL
-    llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
-    ans = llm_weather.run(query)
-    return ans
-
-
-class WhetherSchema(BaseModel):
-    location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
-
-
-if __name__ == '__main__':
-    result = weathercheck("苏州姑苏区今晚热不热?")
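The rewrite above collapses the old multi-step QWeather pipeline (an LLM prompt to split city and district, a geo lookup, a 24-hour forecast request, and formatting) into one direct Seniverse (心知天气) current-weather call with a single location parameter. A hedged usage sketch; substitute your own Seniverse API key for the one hard-coded in the commit:

    # Hypothetical call; returns current conditions, e.g. {'temperature': '24', 'description': '多云'}
    print(weather("厦门", "your-seniverse-api-key"))

Hard-coding the key inside weathercheck keeps the tool single-parameter for the agent, at the cost of shipping a credential in source; reading it from configs would preserve the same call shape.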
@@ -1,10 +0,0 @@
-name: weather_check
-description: Use Weather API to get weather information
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: City name,include city and county,like "厦门市思明区"
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: wolfram
-description: Useful for when you need to calculate difficult math formulas
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The formula to be calculated
-  required:
-    - query
@@ -1,8 +1,6 @@
 from langchain.tools import Tool
 from server.agent.tools import *

-## 请注意,如果你是为了使用AgentLM,在这里,你应该使用英文版本。
-
 tools = [
     Tool.from_function(
         func=calculate,
@@ -20,7 +18,7 @@ tools = [
         func=weathercheck,
         name="weather_check",
         description="",
-        args_schema=WhetherSchema,
+        args_schema=WeatherInput,
     ),
     Tool.from_function(
         func=shell,
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List

 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult
@@ -1,23 +1,23 @@
-from langchain.memory import ConversationBufferWindowMemory
+import json
+import asyncio

-from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
-from server.agent.tools_select import tools, tool_names
-from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
-from langchain.agents import LLMSingleActionAgent, AgentExecutor
-from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
-from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
-from langchain.chains import LLMChain
-from typing import AsyncIterable, Optional
-import asyncio
-from typing import List
-from server.chat.utils import History
-import json
-from server.agent import model_container
-from server.knowledge_base.kb_service.base import get_kb_details

+from langchain.chains import LLMChain
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.agents import LLMSingleActionAgent, AgentExecutor
+from typing import AsyncIterable, Optional, List

+from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
+from server.knowledge_base.kb_service.base import get_kb_details
+from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
+from server.agent.tools_select import tools, tool_names
+from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
+from server.chat.utils import History
+from server.agent import model_container
+from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate


 async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      history: List[History] = Body([],
@@ -33,7 +33,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                      max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                      prompt_name: str = Body("default",
                                              description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
-                     # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      ):
     history = [History.from_data(h) for h in history]

@@ -55,12 +54,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
            callbacks=[callback],
        )

-        ## 传入全局变量来实现agent调用
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}

        if Agent_MODEL:
-            ## 如果有指定使用Agent模型来完成任务
            model_agent = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
@@ -79,15 +76,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
-        # 把history转成agent的memory
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
-            # 检查消息的角色
            if message.role == 'user':
-                # 添加用户消息
                memory.chat_memory.add_user_message(message.content)
            else:
-                # 添加AI消息
                memory.chat_memory.add_ai_message(message.content)

        if "chatglm3" in model_container.MODEL.model_name:
@@ -95,7 +88,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                llm=model,
                tools=tools,
                callback_manager=None,
-                # Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
@@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
        answer = ""
        final_answer = ""
        async for chunk in callback.aiter():
-            # Use server-sent-events to stream the response
            data = json.loads(chunk)
            if data["status"] == Status.start or data["status"] == Status.complete:
                continue
@@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
        await task

    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )
@@ -1,23 +1,23 @@
 from langchain.utilities.bing_search import BingSearchAPIWrapper
 from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
-                     LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
-                     TEXT_SPLITTER_NAME, OVERLAP_SIZE)
-from fastapi import Body
-from sse_starlette import EventSourceResponse
-from fastapi.concurrency import run_in_threadpool
-from server.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse, get_prompt_template
+                     LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from typing import List, Optional, Dict
-from server.chat.utils import History
 from langchain.docstore.document import Document
+from fastapi import Body
+from fastapi.concurrency import run_in_threadpool
+from sse_starlette import EventSourceResponse
+from server.utils import wrap_done, get_ChatOpenAI
+from server.utils import BaseResponse, get_prompt_template
+from server.chat.utils import History
+from typing import AsyncIterable
+import asyncio
 import json
+from typing import List, Optional, Dict
 from strsimpy.normalized_levenshtein import NormalizedLevenshtein
 from markdownify import markdownify

@@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):


 def metaphor_search(
         text: str,
         result_len: int = SEARCH_ENGINE_TOP_K,
         split_result: bool = False,
         chunk_size: int = 500,
         chunk_overlap: int = OVERLAP_SIZE,
 ) -> List[Dict]:
     from metaphor_python import Metaphor

@@ -58,13 +58,13 @@ def metaphor_search(
     # metaphor 返回的内容都是长文本,需要分词再检索
     if split_result:
         docs = [Document(page_content=x.extract,
                          metadata={"link": x.url, "title": x.title})
                 for x in contents]
         text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                                        chunk_size=chunk_size,
                                                        chunk_overlap=chunk_overlap)
         splitted_docs = text_splitter.split_documents(docs)

         # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
         if len(splitted_docs) > result_len:
             normal = NormalizedLevenshtein()
@@ -74,13 +74,13 @@ def metaphor_search(
         splitted_docs = splitted_docs[:result_len]

         docs = [{"snippet": x.page_content,
                  "link": x.metadata["link"],
                  "title": x.metadata["title"]}
                 for x in splitted_docs]
     else:
         docs = [{"snippet": x.extract,
                  "link": x.url,
                  "title": x.title}
                 for x in contents]

     return docs
@@ -113,25 +113,27 @@ async def lookup_search_engine(
     docs = search_result2docs(results)
     return docs


 async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                              top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                              history: List[History] = Body([],
                                                            description="历史对话",
                                                            examples=[[
                                                                {"role": "user",
                                                                 "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                                {"role": "assistant",
                                                                 "content": "虎头虎脑"}]]
                                                            ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                             max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
-                             prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
-                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
-                             ):
+                             max_tokens: Optional[int] = Body(None,
+                                                              description="限制LLM生成Token数量,默认None代表模型最大值"),
+                             prompt_name: str = Body("default",
+                                                     description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                             split_result: bool = Body(False,
+                                                       description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
+                             ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
        await task

    return EventSourceResponse(search_engine_chat_iterator(query=query,
                                                           search_engine_name=search_engine_name,
                                                           top_k=top_k,
                                                           history=history,
                                                           model_name=model_name,
                                                           prompt_name=prompt_name),
                               )
@@ -1,7 +1,6 @@
 # 该文件封装了对api.py的请求,可以被不同的webui使用
 # 通过ApiRequest和AsyncApiRequest支持同步/异步调用
-

 from typing import *
 from pathlib import Path
 # 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
@@ -27,7 +26,7 @@ from io import BytesIO
 from server.utils import set_httpx_config, api_address, get_httpx_client

 from pprint import pprint
+from langchain_core._api import deprecated

 set_httpx_config()

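The deprecated import added above is the same langchain_core helper this commit applies to the embedding-keywords script near the top of the diff; presumably it will flag 0.2.x-only methods in this API client as well. A minimal sketch of the decorator as this commit uses it (the decorated function name is a placeholder):

    from langchain_core._api import deprecated

    @deprecated(
        since="0.3.0",
        message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
        removal="0.3.0"
    )
    def some_legacy_api():  # hypothetical name; the real targets are 0.2.x-only functions
        ...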
@ -36,10 +35,11 @@ class ApiRequest:
|
|||||||
'''
|
'''
|
||||||
api.py调用的封装(同步模式),简化api调用方式
|
api.py调用的封装(同步模式),简化api调用方式
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str = api_address(),
|
base_url: str = api_address(),
|
||||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||||
):
|
):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
@ -55,12 +55,12 @@ class ApiRequest:
|
|||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
params: Union[Dict, List[Tuple], bytes] = None,
|
params: Union[Dict, List[Tuple], bytes] = None,
|
||||||
retry: int = 3,
|
retry: int = 3,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||||
while retry > 0:
|
while retry > 0:
|
||||||
try:
|
try:
|
||||||
@ -75,13 +75,13 @@ class ApiRequest:
|
|||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
def post(
|
def post(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
data: Dict = None,
|
data: Dict = None,
|
||||||
json: Dict = None,
|
json: Dict = None,
|
||||||
retry: int = 3,
|
retry: int = 3,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||||
while retry > 0:
|
while retry > 0:
|
||||||
try:
|
try:
|
||||||
@ -97,13 +97,13 @@ class ApiRequest:
|
|||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
data: Dict = None,
|
data: Dict = None,
|
||||||
json: Dict = None,
|
json: Dict = None,
|
||||||
retry: int = 3,
|
retry: int = 3,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||||
while retry > 0:
|
while retry > 0:
|
||||||
try:
|
try:
|
@@ -118,24 +118,25 @@ class ApiRequest:
                 retry -= 1

     def _httpx_stream2generator(
             self,
             response: contextlib._GeneratorContextManager,
             as_json: bool = False,
     ):
         '''
         Convert the GeneratorContextManager returned by httpx.stream into a plain generator.
         '''

         async def ret_async(response, as_json):
             try:
                 async with response as r:
                     async for chunk in r.aiter_text(None):
                         if not chunk:  # fastchat api yield empty bytes on start and end
                             continue
                         if as_json:
                             try:
                                 if chunk.startswith("data: "):
                                     data = json.loads(chunk[6:-2])
                                 elif chunk.startswith(":"):  # skip sse comment line
                                     continue
                                 else:
                                     data = json.loads(chunk)
@@ -143,7 +144,7 @@ class ApiRequest:
                             except Exception as e:
                                 msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
                                 logger.error(f'{e.__class__.__name__}: {msg}',
                                              exc_info=e if log_verbose else None)
                         else:
                             # print(chunk, end="", flush=True)
                             yield chunk
@@ -158,20 +159,20 @@ class ApiRequest:
             except Exception as e:
                 msg = f"API通信遇到错误:{e}"
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
                 yield {"code": 500, "msg": msg}

         def ret_sync(response, as_json):
             try:
                 with response as r:
                     for chunk in r.iter_text(None):
                         if not chunk:  # fastchat api yield empty bytes on start and end
                             continue
                         if as_json:
                             try:
                                 if chunk.startswith("data: "):
                                     data = json.loads(chunk[6:-2])
                                 elif chunk.startswith(":"):  # skip sse comment line
                                     continue
                                 else:
                                     data = json.loads(chunk)
@@ -179,7 +180,7 @@ class ApiRequest:
                             except Exception as e:
                                 msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
                                 logger.error(f'{e.__class__.__name__}: {msg}',
                                              exc_info=e if log_verbose else None)
                         else:
                             # print(chunk, end="", flush=True)
                             yield chunk
@@ -194,7 +195,7 @@ class ApiRequest:
             except Exception as e:
                 msg = f"API通信遇到错误:{e}"
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
                 yield {"code": 500, "msg": msg}

         if self._use_async:
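The json branches above rely on the server's SSE framing: each event arrives as "data: {...}\n\n", so chunk[6:-2] drops the 6-character "data: " prefix and the 2-character event terminator. A standalone sketch of that parsing rule on a hypothetical chunk:

import json

chunk = 'data: {"answer": "你好"}\n\n'
if chunk.startswith("data: "):
    payload = json.loads(chunk[6:-2])  # strip "data: " and the trailing "\n\n"
elif chunk.startswith(":"):
    payload = None  # SSE comment line, skipped by the generator
else:
    payload = json.loads(chunk)  # plain JSON chunk without SSE framing
print(payload)  # {'answer': '你好'}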
@@ -203,16 +204,17 @@ class ApiRequest:
             return ret_sync(response, as_json)

     def _get_response_value(
             self,
             response: httpx.Response,
             as_json: bool = False,
             value_func: Callable = None,
     ):
         '''
         Convert the response returned by a sync or async request.
         `as_json`: return the response parsed as JSON
         `value_func`: caller-defined post-processing; the function receives the response (or its JSON)
         '''

         def to_json(r):
             try:
                 return r.json()
@@ -220,7 +222,7 @@ class ApiRequest:
                 msg = "API未能返回正确的JSON。" + str(e)
                 if log_verbose:
                     logger.error(f'{e.__class__.__name__}: {msg}',
                                  exc_info=e if log_verbose else None)
                 return {"code": 500, "msg": msg, "data": None}

         if value_func is None:
@@ -250,10 +252,10 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])

     def get_prompt_template(
             self,
             type: str = "llm_chat",
             name: str = "default",
             **kwargs,
     ) -> str:
         data = {
             "type": type,
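The value_func hook is what lets the wrapper methods below unwrap the {"code", "msg", "data"} envelope in one place. A minimal sketch of the same pattern, detached from httpx:

def get_value(resp_json, value_func=None):
    value_func = value_func or (lambda r: r)  # default: return the parsed JSON unchanged
    return value_func(resp_json)

envelope = {"code": 200, "msg": "success", "data": ["samples"]}
print(get_value(envelope, value_func=lambda r: r.get("data", [])))  # ['samples']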
@@ -297,15 +299,19 @@ class ApiRequest:
         response = self.post("/chat/chat", json=data, stream=True, **kwargs)
         return self._httpx_stream2generator(response, as_json=True)

+    @deprecated(
+        since="0.3.0",
+        message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
+        removal="0.3.0")
     def agent_chat(
             self,
             query: str,
             history: List[Dict] = [],
             stream: bool = True,
             model: str = LLM_MODELS[0],
             temperature: float = TEMPERATURE,
             max_tokens: int = None,
             prompt_name: str = "default",
     ):
         '''
         Corresponds to the api.py/chat/agent_chat endpoint
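A consumption sketch for these streaming chat methods, assuming a running 0.2.x api server on the default port; the response keys are illustrative:

from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
for d in api.agent_chat("查询今天上海的天气"):
    # each d is already a parsed dict because of as_json=True
    print(d.get("answer", ""), end="", flush=True)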
@@ -327,17 +333,17 @@ class ApiRequest:
         return self._httpx_stream2generator(response, as_json=True)

     def knowledge_base_chat(
             self,
             query: str,
             knowledge_base_name: str,
             top_k: int = VECTOR_SEARCH_TOP_K,
             score_threshold: float = SCORE_THRESHOLD,
             history: List[Dict] = [],
             stream: bool = True,
             model: str = LLM_MODELS[0],
             temperature: float = TEMPERATURE,
             max_tokens: int = None,
             prompt_name: str = "default",
     ):
         '''
         Corresponds to the api.py/chat/knowledge_base_chat endpoint
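A similar sketch for knowledge base Q&A; the knowledge base name and keys are illustrative, and the final chunk typically carries the retrieved docs:

api = ApiRequest()
for d in api.knowledge_base_chat(
        "如何新建知识库?",
        knowledge_base_name="samples",
        top_k=3):
    print(d.get("answer", ""), end="")
    if docs := d.get("docs"):
        print("\nsources:", len(docs))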
@@ -366,28 +372,29 @@ class ApiRequest:
         return self._httpx_stream2generator(response, as_json=True)

     def upload_temp_docs(
             self,
             files: List[Union[str, Path, bytes]],
             knowledge_id: str = None,
             chunk_size=CHUNK_SIZE,
             chunk_overlap=OVERLAP_SIZE,
             zh_title_enhance=ZH_TITLE_ENHANCE,
     ):
         '''
         Corresponds to the api.py/knowledge_base/upload_temp_docs endpoint
         '''

         def convert_file(file, filename=None):
             if isinstance(file, bytes):  # raw bytes
                 file = BytesIO(file)
             elif hasattr(file, "read"):  # a file io like object
                 filename = filename or file.name
             else:  # a local path
                 file = Path(file).absolute().open("rb")
                 filename = filename or os.path.split(file.name)[-1]
             return filename, file

         files = [convert_file(file) for file in files]
-        data={
+        data = {
             "knowledge_id": knowledge_id,
             "chunk_size": chunk_size,
             "chunk_overlap": chunk_overlap,
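convert_file normalizes the three accepted input shapes into (filename, file-like) pairs for the multipart upload; note that for raw bytes no filename can be inferred unless one is passed explicitly. A sketch of the three call shapes (paths hypothetical):

api = ApiRequest()
api.upload_temp_docs(["/tmp/manual.pdf"])              # local path: opened and named from the path
api.upload_temp_docs([open("/tmp/manual.pdf", "rb")])  # file-like object: name taken from .name
api.upload_temp_docs([b"%PDF-1.4 ..."])                # raw bytes: wrapped in BytesIO, filename stays None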
@@ -402,17 +409,17 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def file_chat(
             self,
             query: str,
             knowledge_id: str,
             top_k: int = VECTOR_SEARCH_TOP_K,
             score_threshold: float = SCORE_THRESHOLD,
             history: List[Dict] = [],
             stream: bool = True,
             model: str = LLM_MODELS[0],
             temperature: float = TEMPERATURE,
             max_tokens: int = None,
             prompt_name: str = "default",
     ):
         '''
         Corresponds to the api.py/chat/file_chat endpoint
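file_chat pairs with upload_temp_docs above: the id returned by the upload becomes the knowledge_id here. A two-step sketch, with the field name assumed from the {"code", "msg", "data"} envelope convention:

api = ApiRequest()
resp = api.upload_temp_docs(["/tmp/report.docx"])
knowledge_id = resp.get("data", {}).get("id")  # field name assumed
for d in api.file_chat("总结这份报告的要点", knowledge_id=knowledge_id):
    print(d.get("answer", ""), end="")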
@@ -440,18 +447,23 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response, as_json=True)

+    @deprecated(
+        since="0.3.0",
+        message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
+        removal="0.3.0"
+    )
     def search_engine_chat(
             self,
             query: str,
             search_engine_name: str,
             top_k: int = SEARCH_ENGINE_TOP_K,
             history: List[Dict] = [],
             stream: bool = True,
             model: str = LLM_MODELS[0],
             temperature: float = TEMPERATURE,
             max_tokens: int = None,
             prompt_name: str = "default",
             split_result: bool = False,
     ):
         '''
         Corresponds to the api.py/chat/search_engine_chat endpoint
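The langchain_core @deprecated decorator used above only warns at call time; the method keeps working until the stated removal. A sketch of what a caller would observe, with the warning class assumed from langchain_core._api:

import warnings

api = ApiRequest()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    api.search_engine_chat("Langchain-Chatchat 是什么", search_engine_name="duckduckgo")
print(caught[0].category.__name__)  # expected: LangChainDeprecationWarning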
@@ -482,7 +494,7 @@ class ApiRequest:
     # Knowledge base operations

     def list_knowledge_bases(
             self,
     ):
         '''
         Corresponds to the api.py/knowledge_base/list_knowledge_bases endpoint
@@ -493,10 +505,10 @@ class ApiRequest:
                                         value_func=lambda r: r.get("data", []))

     def create_knowledge_base(
             self,
             knowledge_base_name: str,
             vector_store_type: str = DEFAULT_VS_TYPE,
             embed_model: str = EMBEDDING_MODEL,
     ):
         '''
         Corresponds to the api.py/knowledge_base/create_knowledge_base endpoint
@@ -514,8 +526,8 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def delete_knowledge_base(
             self,
             knowledge_base_name: str,
     ):
         '''
         Corresponds to the api.py/knowledge_base/delete_knowledge_base endpoint
@@ -527,8 +539,8 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def list_kb_docs(
             self,
             knowledge_base_name: str,
     ):
         '''
         Corresponds to the api.py/knowledge_base/list_files endpoint
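A lifecycle sketch for the four knowledge base endpoints above (names illustrative):

api = ApiRequest()
api.create_knowledge_base("demo_kb", vector_store_type="faiss")
print(api.list_knowledge_bases())   # e.g. ['samples', 'demo_kb']
print(api.list_kb_docs("demo_kb"))  # [] until files are uploaded
api.delete_knowledge_base("demo_kb")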
@@ -542,13 +554,13 @@ class ApiRequest:
                                         value_func=lambda r: r.get("data", []))

     def search_kb_docs(
             self,
             knowledge_base_name: str,
             query: str = "",
             top_k: int = VECTOR_SEARCH_TOP_K,
             score_threshold: int = SCORE_THRESHOLD,
             file_name: str = "",
             metadata: dict = {},
     ) -> List:
         '''
         Corresponds to the api.py/knowledge_base/search_docs endpoint
@@ -569,9 +581,9 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def update_docs_by_id(
             self,
             knowledge_base_name: str,
             docs: Dict[str, Dict],
     ) -> bool:
         '''
         Corresponds to the api.py/knowledge_base/update_docs_by_id endpoint
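search_kb_docs returns scored document payloads whose ids can be fed back into update_docs_by_id to patch content in place. A sketch, with the "id" field name assumed:

api = ApiRequest()
hits = api.search_kb_docs("samples", query="langchain", top_k=3)
doc_id = hits[0]["id"]  # field name assumed
api.update_docs_by_id("samples", docs={doc_id: {"page_content": "修订后的内容"}})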
@@ -587,32 +599,33 @@ class ApiRequest:
         return self._get_response_value(response)

     def upload_kb_docs(
             self,
             files: List[Union[str, Path, bytes]],
             knowledge_base_name: str,
             override: bool = False,
             to_vector_store: bool = True,
             chunk_size=CHUNK_SIZE,
             chunk_overlap=OVERLAP_SIZE,
             zh_title_enhance=ZH_TITLE_ENHANCE,
             docs: Dict = {},
             not_refresh_vs_cache: bool = False,
     ):
         '''
         Corresponds to the api.py/knowledge_base/upload_docs endpoint
         '''

         def convert_file(file, filename=None):
             if isinstance(file, bytes):  # raw bytes
                 file = BytesIO(file)
             elif hasattr(file, "read"):  # a file io like object
                 filename = filename or file.name
             else:  # a local path
                 file = Path(file).absolute().open("rb")
                 filename = filename or os.path.split(file.name)[-1]
             return filename, file

         files = [convert_file(file) for file in files]
-        data={
+        data = {
             "knowledge_base_name": knowledge_base_name,
             "override": override,
             "to_vector_store": to_vector_store,
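A usage sketch; the docs parameter lets a caller ship pre-split Document dicts keyed by filename instead of relying on server-side splitting:

api = ApiRequest()
api.upload_kb_docs(
    ["/tmp/faq.md"],
    knowledge_base_name="samples",
    override=True,
    chunk_size=250,
    chunk_overlap=50,
)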
@@ -633,11 +646,11 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def delete_kb_docs(
             self,
             knowledge_base_name: str,
             file_names: List[str],
             delete_content: bool = False,
             not_refresh_vs_cache: bool = False,
     ):
         '''
         Corresponds to the api.py/knowledge_base/delete_docs endpoint
@@ -655,8 +668,7 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True)

-    def update_kb_info(self,knowledge_base_name,kb_info):
+    def update_kb_info(self, knowledge_base_name, kb_info):
         '''
         Corresponds to the api.py/knowledge_base/update_info endpoint
         '''
@@ -672,15 +684,15 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def update_kb_docs(
             self,
             knowledge_base_name: str,
             file_names: List[str],
             override_custom_docs: bool = False,
             chunk_size=CHUNK_SIZE,
             chunk_overlap=OVERLAP_SIZE,
             zh_title_enhance=ZH_TITLE_ENHANCE,
             docs: Dict = {},
             not_refresh_vs_cache: bool = False,
     ):
         '''
         Corresponds to the api.py/knowledge_base/update_docs endpoint
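delete_kb_docs and update_kb_docs act on filenames already registered in the knowledge base; updating re-splits and re-embeds the file with the given chunk settings. A sketch:

api = ApiRequest()
api.update_kb_docs("samples", file_names=["faq.md"], chunk_size=300)
api.delete_kb_docs("samples", file_names=["faq.md"], delete_content=True)  # also remove the stored file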
@@ -706,14 +718,14 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def recreate_vector_store(
             self,
             knowledge_base_name: str,
             allow_empty_kb: bool = True,
             vs_type: str = DEFAULT_VS_TYPE,
             embed_model: str = EMBEDDING_MODEL,
             chunk_size=CHUNK_SIZE,
             chunk_overlap=OVERLAP_SIZE,
             zh_title_enhance=ZH_TITLE_ENHANCE,
     ):
         '''
         Corresponds to the api.py/knowledge_base/recreate_vector_store endpoint
@@ -738,8 +750,8 @@ class ApiRequest:

     # LLM model operations
     def list_running_models(
             self,
             controller_address: str = None,
     ):
         '''
         Get the list of models currently running in FastChat
@@ -755,8 +767,7 @@ class ApiRequest:
             "/llm_model/list_running_models",
             json=data,
         )
-        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
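The returned mapping is keyed by model name; the online_api flag in each config is what get_default_llm_model below uses to tell local models from API-backed ones. A sketch of a plausible shape:

api = ApiRequest()
running = api.list_running_models()
# e.g. {'chatglm3-6b': {'device': 'cuda', ...}, 'zhipu-api': {'online_api': True, ...}}
for name, cfg in running.items():
    print(name, "online" if cfg.get("online_api") else "local")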


     def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
         '''
@@ -764,6 +775,7 @@ class ApiRequest:
         When local_first=True, prefer a running local model; otherwise follow the order configured in LLM_MODELS.
         Returns (model_name, is_local_model).
         '''

         def ret_sync():
             running_models = self.list_running_models()
             if not running_models:
@@ -780,7 +792,7 @@ class ApiRequest:
                     model = m
                     break

             if not model:  # none of the models configured in LLM_MODELS is in running_models
                 model = list(running_models)[0]
             is_local = not running_models[model].get("online_api")
             return model, is_local
@@ -801,7 +813,7 @@ class ApiRequest:
                     model = m
                     break

             if not model:  # none of the models configured in LLM_MODELS is in running_models
                 model = list(running_models)[0]
             is_local = not running_models[model].get("online_api")
             return model, is_local
@@ -812,8 +824,8 @@ class ApiRequest:
             return ret_sync()

     def list_config_models(
             self,
             types: List[str] = ["local", "online"],
     ) -> Dict[str, Dict]:
         '''
         Get the models configured in the server's configs, returned as {"type": {model_name: config}, ...}.
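The fallback chain in ret_sync/ret_async reduces to: first LLM_MODELS entry that is running, else the first running model. A standalone sketch of that selection logic with assumed values:

LLM_MODELS = ["chatglm3-6b", "zhipu-api"]     # assumed config order
running = {"qwen-api": {"online_api": True}}  # assumed running models

model = next((m for m in LLM_MODELS if m in running), None)
if not model:  # none of the configured models is running
    model = next(iter(running))
is_local = not running[model].get("online_api")
print(model, is_local)  # qwen-api False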
@@ -825,23 +837,23 @@ class ApiRequest:
             "/llm_model/list_config_models",
             json=data,
         )
-        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

     def get_model_config(
             self,
             model_name: str = None,
     ) -> Dict:
         '''
         Get a model's configuration from the server.
         '''
-        data={
+        data = {
             "model_name": model_name,
         }
         response = self.post(
             "/llm_model/get_model_config",
             json=data,
         )
-        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

     def list_search_engines(self) -> List[str]:
         '''
@@ -850,12 +862,12 @@ class ApiRequest:
         response = self.post(
             "/server/list_search_engines",
         )
-        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

     def stop_llm_model(
             self,
             model_name: str,
             controller_address: str = None,
     ):
         '''
         Stop a given LLM model.
@@ -873,10 +885,10 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True)

     def change_llm_model(
             self,
             model_name: str,
             new_model_name: str,
             controller_address: str = None,
     ):
         '''
         Ask the fastchat controller to switch the LLM model.
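A sketch of swapping the served model at runtime; the target model name is assumed to exist in the server's configs:

api = ApiRequest()
cur, _ = api.get_default_llm_model()
resp = api.change_llm_model(cur, "Qwen-14B-Chat")
print(resp.get("msg"))  # envelope message from the controller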
@@ -959,10 +971,10 @@ class ApiRequest:
             return ret_sync()

     def embed_texts(
             self,
             texts: List[str],
             embed_model: str = EMBEDDING_MODEL,
             to_query: bool = False,
     ) -> List[List[float]]:
         '''
         Vectorize the given texts; available models include local embed_models and online models that support embeddings
@@ -979,10 +991,10 @@ class ApiRequest:
         return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))

     def chat_feedback(
             self,
             message_id: str,
             score: int,
             reason: str = "",
     ) -> int:
         '''
         Submit feedback (rating) for a chat message
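embed_texts returns one vector per input, so similarity checks reduce to plain vector math. A sketch with no dependencies beyond the standard library:

import math

api = ApiRequest()
vecs = api.embed_texts(["如何新建知识库", "怎样创建一个知识库"], to_query=True)
dot = sum(a * b for a, b in zip(vecs[0], vecs[1]))
norm = math.sqrt(sum(a * a for a in vecs[0])) * math.sqrt(sum(b * b for b in vecs[1]))
print(dot / norm)  # cosine similarity; near 1.0 for paraphrases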
@@ -1019,9 +1031,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
     return the msg field if the API request succeeded (code == 200)
     '''
     if (isinstance(data, dict)
             and key in data
             and "code" in data
             and data["code"] == 200):
         return data[key]
     return ""

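A sketch of how the webui typically uses this helper on a response envelope (values illustrative):

from webui_pages.utils import check_success_msg

resp = {"code": 200, "msg": "已成功上传文件", "data": []}
if msg := check_success_msg(resp):
    print("OK:", msg)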