Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-07 15:38:27 +08:00)

commit 17803cb7c1 (parent b6d2bc71ce)

    gemini api 修复调用 (fix the Gemini API call)
@@ -92,11 +92,10 @@ FSCHAT_MODEL_WORKERS = {
         # 'disable_log_requests': False
 
     },
-    # 可以如下示例方式更改默认配置
-    # "Qwen-1_8B-Chat": { # 使用default中的IP和端口
-    #     "device": "cpu",
-    # },
-    "chatglm3-6b": { # 使用default中的IP和端口
+    "Qwen-1_8B-Chat": {
+        "device": "cpu",
+    },
+    "chatglm3-6b": {
         "device": "cuda",
     },
 
@@ -129,7 +128,7 @@ FSCHAT_MODEL_WORKERS = {
         "port": 21009,
     },
     "gemini-api": {
-        "port": 21012,
+        "port": 21010,
     },
 }
 
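Every key in FSCHAT_MODEL_WORKERS other than "default" overrides the shared defaults for that one worker, and an online worker's key has to match the name its worker class registers (see the model_names fix in the GeminiWorker hunks below). A minimal sketch of the override semantics — get_worker_config here is a hypothetical stand-in, not the project's real helper:

# Minimal sketch (assumption: this mirrors how per-model worker settings
# are shallow-merged over the shared "default" entry).
FSCHAT_MODEL_WORKERS = {
    "default": {"host": "127.0.0.1", "port": 20002, "device": "auto"},
    "chatglm3-6b": {"device": "cuda"},
    "gemini-api": {"port": 21010},
}

def get_worker_config(model_name: str) -> dict:
    config = dict(FSCHAT_MODEL_WORKERS["default"])           # shared defaults
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))  # per-model overrides win
    return config

print(get_worker_config("gemini-api"))  # port 21010 overrides the default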
@@ -13,9 +13,9 @@ def list_kbs():
 
 
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
               embed_model: str = Body(EMBEDDING_MODEL),
               ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -39,8 +39,8 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
 
 
 def delete_kb(
         knowledge_base_name: str = Body(..., examples=["samples"])
 ) -> BaseResponse:
     # Delete selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
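These two hunks touch only whitespace; the endpoints behave as before. For reference, a hedged smoke test of create_kb as exposed over HTTP — the base URL, route path, and embed model name are assumptions about a default local deployment:

import httpx

# Hedged smoke test; adjust host/port and model name to your deployment.
resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/create_knowledge_base",
    json={
        "knowledge_base_name": "samples",
        "vector_store_type": "faiss",
        "embed_model": "bge-large-zh",  # stand-in for your configured EMBEDDING_MODEL
    },
)
print(resp.json())  # expect code 200; code 403 means validate_kb_name rejected the name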
@@ -3,7 +3,7 @@ from fastchat.conversation import Conversation
 from server.model_workers.base import *
 from server.utils import get_httpx_client
 from fastchat import conversation as conv
-import json,httpx
+import json, httpx
 from typing import List, Dict
 from configs import logger, log_verbose
 
@@ -14,14 +14,14 @@ class GeminiWorker(ApiModelWorker):
             *,
             controller_addr: str = None,
             worker_addr: str = None,
-            model_names: List[str] = ["Gemini-api"],
+            model_names: List[str] = ["gemini-api"],
             **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
         kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)
 
-    def create_gemini_messages(self,messages) -> json:
+    def create_gemini_messages(self, messages) -> json:
         has_history = any(msg['role'] == 'assistant' for msg in messages)
         gemini_msg = []
 
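The model_names fix is the substance of this commit: do_chat resolves the worker's API credentials through params.load_config(self.model_names[0]), so the capitalized "Gemini-api" never matched the "gemini-api" config key. create_gemini_messages itself is truncated by the diff; a hedged sketch of the role mapping it performs, following the public generateContent request schema:

# Hedged sketch of the OpenAI-role -> Gemini "contents" mapping; the exact
# helper body is cut off in this diff.
def to_gemini_contents(messages: list) -> dict:
    role_map = {"user": "user", "assistant": "model"}
    contents = []
    for msg in messages:
        role = role_map.get(msg["role"])
        if role is None:
            continue  # system prompts would need separate handling
        contents.append({"role": role, "parts": [{"text": msg["content"]}]})
    return {"contents": contents}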
@@ -46,7 +46,7 @@ class GeminiWorker(ApiModelWorker):
     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
         data = self.create_gemini_messages(messages=params.messages)
-        generationConfig=dict(
+        generationConfig = dict(
             temperature=params.temperature,
             topK=1,
             topP=1,
@@ -55,7 +55,7 @@ class GeminiWorker(ApiModelWorker):
         )
 
         data['generationConfig'] = generationConfig
-        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
+        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
         headers = {
             'Content-Type': 'application/json',
         }
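The worker posts that payload to the public generateContent endpoint, passing the API key as a query parameter. A self-contained sketch of the same call outside the worker — synchronous rather than streamed, with the 60-second timeout mirroring the diff:

import httpx

def generate(api_key: str, payload: dict) -> dict:
    # Same endpoint and auth style as the worker above; plain POST instead
    # of the worker's streaming client.
    url = ("https://generativelanguage.googleapis.com/v1beta/models/"
           "gemini-pro:generateContent?key=" + api_key)
    resp = httpx.post(url, headers={"Content-Type": "application/json"},
                      json=payload, timeout=60.0)
    resp.raise_for_status()
    # The reply text lives at candidates[0].content.parts[0].text
    return resp.json()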
@@ -67,7 +67,7 @@ class GeminiWorker(ApiModelWorker):
         text = ""
         json_string = ""
         timeout = httpx.Timeout(60.0)
-        client=get_httpx_client(timeout=timeout)
+        client = get_httpx_client(timeout=timeout)
         with client.stream("POST", url, headers=headers, json=data) as response:
             for line in response.iter_lines():
                 line = line.strip()
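The tail of do_chat is cut off by the diff. Since gemini-pro:generateContent returns a single JSON document (unlike streamGenerateContent), the loop presumably buffers the streamed lines into json_string and parses once the body is complete; a sketch under that assumption, reusing the names defined above:

# Assumption: reconstruction of the elided remainder of do_chat, based on
# the text/json_string variables initialized in this hunk.
json_string = ""
with client.stream("POST", url, headers=headers, json=data) as response:
    for line in response.iter_lines():
        json_string += line.strip()
result = json.loads(json_string)
text = result["candidates"][0]["content"]["parts"][0]["text"]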
@@ -12,7 +12,6 @@ from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
 
-
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -138,11 +137,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.toast(text)
 
         dialogue_modes = ["LLM 对话",
                           "知识库问答",
                           "文件对话",
                           "搜索引擎问答",
                           "自定义Agent问答",
                           ]
         dialogue_mode = st.selectbox("请选择对话模式:",
                                      dialogue_modes,
                                      index=0,
@@ -166,9 +165,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         available_models = []
         config_models = api.list_config_models()
         if not is_lite:
             for k, v in config_models.get("local", {}).items():  # 列出配置了有效本地路径的模型
                 if (v.get("model_path_exists")
                         and k not in running_models):
                     available_models.append(k)
         for k, v in config_models.get("online", {}).items():  # 列出ONLINE_MODELS中直接访问的模型
             if not v.get("provider") and k not in running_models:
@@ -250,14 +249,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         elif dialogue_mode == "文件对话":
             with st.expander("文件对话配置", True):
                 files = st.file_uploader("上传知识文件:",
                                          [i for ls in LOADER_DICT.values() for i in ls],
                                          accept_multiple_files=True,
                                          )
                 kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
 
                 ## Bge 模型会超过1
                 score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-                if st.button("开始上传", disabled=len(files)==0):
+                if st.button("开始上传", disabled=len(files) == 0):
                     st.session_state["file_chat_id"] = upload_temp_docs(files, api)
         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
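The file types accepted by st.file_uploader come from flattening LOADER_DICT, which maps each document loader to the suffixes it handles. A toy illustration with an abbreviated, assumed mapping:

# Abbreviated, assumed version of the real LOADER_DICT in
# server/knowledge_base/utils.py; only the flattening matters here.
LOADER_DICT = {
    "UnstructuredMarkdownLoader": [".md"],
    "CSVLoader": [".csv"],
    "RapidOCRPDFLoader": [".pdf"],
}
allowed_suffixes = [i for ls in LOADER_DICT.values() for i in ls]
print(allowed_suffixes)  # ['.md', '.csv', '.pdf']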
@@ -279,9 +278,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
 
     def on_feedback(
             feedback,
             message_id: str = "",
             history_index: int = -1,
     ):
         reason = feedback["text"]
         score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
@@ -296,7 +295,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     }
 
     if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
         if parse_command(text=prompt, modal=modal):  # 用户输入自定义命令
             st.rerun()
         else:
             history = get_messages_history(history_len)
@@ -306,11 +305,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = ""
             message_id = ""
             r = api.chat_chat(prompt,
                               history=history,
                               conversation_id=conversation_id,
                               model=llm_model,
                               prompt_name=prompt_template_name,
                               temperature=temperature)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether error occured
                     st.error(error_msg)
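api.chat_chat returns a generator of chunks; the loop screens each chunk with check_error_msg before accumulating it, though the accumulation itself is cut off by the hunk. A hedged sketch of the loop body, reusing names from the surrounding code:

# Hedged sketch of the truncated loop body; the "text" and "message_id"
# fields match the metadata handling later in this diff.
text, message_id = "", ""
for t in r:
    if error_msg := check_error_msg(t):
        st.error(error_msg)
        break
    text += t.get("text", "")
    message_id = t.get("message_id", message_id)
    chat_box.update_msg(text)  # re-render the partial answer while streaming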
@@ -321,12 +320,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
 
                 metadata = {
                     "message_id": message_id,
                 }
                 chat_box.update_msg(text, streaming=False, metadata=metadata)  # 更新最终的字符串,去除光标
                 chat_box.show_feedback(**feedback_kwargs,
                                        key=message_id,
                                        on_submit=on_feedback,
                                        kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
 
         elif dialogue_mode == "自定义Agent问答":
             if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
@@ -373,13 +372,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             ])
             text = ""
             for d in api.knowledge_base_chat(prompt,
                                              knowledge_base_name=selected_kb,
                                              top_k=kb_top_k,
                                              score_threshold=score_threshold,
                                              history=history,
                                              model=llm_model,
                                              prompt_name=prompt_template_name,
                                              temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
@@ -397,13 +396,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             ])
             text = ""
             for d in api.file_chat(prompt,
                                    knowledge_id=st.session_state["file_chat_id"],
                                    top_k=kb_top_k,
                                    score_threshold=score_threshold,
                                    history=history,
                                    model=llm_model,
                                    prompt_name=prompt_template_name,
                                    temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 elif chunk := d.get("answer"):