Fix the Gemini API call

zR 2024-01-22 13:14:13 +08:00
parent b6d2bc71ce
commit 17803cb7c1
4 changed files with 58 additions and 60 deletions

View File

@@ -92,11 +92,10 @@ FSCHAT_MODEL_WORKERS = {
         # 'disable_log_requests': False
     },
-    # The default settings can be overridden per model, as in this example:
-    # "Qwen-1_8B-Chat": {  # uses the IP and port from "default"
-    #     "device": "cpu",
-    # },
-    "chatglm3-6b": {  # uses the IP and port from "default"
+    "Qwen-1_8B-Chat": {
+        "device": "cpu",
+    },
+    "chatglm3-6b": {
         "device": "cuda",
     },
@@ -129,7 +128,7 @@ FSCHAT_MODEL_WORKERS = {
         "port": 21009,
     },
     "gemini-api": {
-        "port": 21012,
+        "port": 21010,
     },
 }
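Note: entries in FSCHAT_MODEL_WORKERS act as shallow overrides of the "default" worker settings, which is why the Qwen-1_8B-Chat block above only sets "device". Below is a minimal sketch of that merge with illustrative values and a hypothetical helper; the project's real merge logic lives in its utils, not in this diff.

# Illustrative sketch: each per-model block shallowly overrides "default".
workers = {
    "default": {"host": "127.0.0.1", "port": 20002, "device": "cuda"},
    "Qwen-1_8B-Chat": {"device": "cpu"},  # inherits host/port from "default"
    "gemini-api": {"port": 21010},        # online worker on its own port
}

def get_worker_config(model_name: str) -> dict:
    config = dict(workers["default"])           # start from the defaults
    config.update(workers.get(model_name, {}))  # shallow per-model override
    return config

# get_worker_config("Qwen-1_8B-Chat")
# -> {'host': '127.0.0.1', 'port': 20002, 'device': 'cpu'}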

View File

@@ -13,9 +13,9 @@ def list_kbs():
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
               embed_model: str = Body(EMBEDDING_MODEL),
               ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -39,8 +39,8 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
 def delete_kb(
         knowledge_base_name: str = Body(..., examples=["samples"])
 ) -> BaseResponse:
     # Delete selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
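As a usage sketch, these handlers take their arguments in the JSON request body (FastAPI Body). The route path and server address below are assumptions for illustration; only the handler signature comes from this diff.

import httpx  # hypothetical client call against the API server

resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/create_knowledge_base",  # assumed route
    json={
        "knowledge_base_name": "samples",  # examples=["samples"] in the handler
        "vector_store_type": "faiss",      # matches the Body("faiss") default
    },
)
print(resp.json())  # a BaseResponse payload, e.g. {"code": 200, "msg": "..."}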

View File

@@ -3,7 +3,7 @@ from fastchat.conversation import Conversation
 from server.model_workers.base import *
 from server.utils import get_httpx_client
 from fastchat import conversation as conv
-import json,httpx
+import json, httpx
 from typing import List, Dict
 from configs import logger, log_verbose
@@ -14,14 +14,14 @@ class GeminiWorker(ApiModelWorker):
             *,
             controller_addr: str = None,
             worker_addr: str = None,
-            model_names: List[str] = ["Gemini-api"],
+            model_names: List[str] = ["gemini-api"],
             **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
         kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)
 
-    def create_gemini_messages(self,messages) -> json:
+    def create_gemini_messages(self, messages) -> json:
         has_history = any(msg['role'] == 'assistant' for msg in messages)
         gemini_msg = []
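create_gemini_messages converts OpenAI-style chat messages into Gemini's request schema. Below is a minimal sketch of such a conversion, assuming the standard v1beta shape (assistant turns become role "model", text is wrapped in a "parts" list); the worker's actual handling of system and history messages may differ.

# Minimal sketch of an OpenAI-style -> Gemini "contents" conversion.
from typing import Dict, List

def to_gemini_contents(messages: List[Dict]) -> Dict:
    contents = []
    for msg in messages:
        if msg["role"] == "system":
            continue  # gemini-pro (v1beta) has no dedicated system role
        role = "model" if msg["role"] == "assistant" else "user"
        contents.append({"role": role, "parts": [{"text": msg["content"]}]})
    return {"contents": contents}

# to_gemini_contents([{"role": "user", "content": "Hello"}])
# -> {'contents': [{'role': 'user', 'parts': [{'text': 'Hello'}]}]}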
@@ -46,7 +46,7 @@ class GeminiWorker(ApiModelWorker):
     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
         data = self.create_gemini_messages(messages=params.messages)
-        generationConfig=dict(
+        generationConfig = dict(
             temperature=params.temperature,
             topK=1,
             topP=1,
@@ -55,7 +55,7 @@ class GeminiWorker(ApiModelWorker):
         )
         data['generationConfig'] = generationConfig
-        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
+        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
         headers = {
             'Content-Type': 'application/json',
         }
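Assembled, do_chat issues a plain JSON POST with the API key as a query parameter. A standalone sketch under those assumptions, showing only the generationConfig keys visible in this diff:

import httpx

def gemini_generate(api_key: str, contents: list, temperature: float) -> dict:
    url = ("https://generativelanguage.googleapis.com/v1beta/models/"
           "gemini-pro:generateContent" + "?key=" + api_key)
    payload = {
        "contents": contents,
        "generationConfig": {"temperature": temperature, "topK": 1, "topP": 1},
    }
    resp = httpx.post(url, headers={"Content-Type": "application/json"},
                      json=payload, timeout=httpx.Timeout(60.0))
    resp.raise_for_status()
    return resp.json()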
@@ -67,7 +67,7 @@ class GeminiWorker(ApiModelWorker):
         text = ""
         json_string = ""
         timeout = httpx.Timeout(60.0)
-        client=get_httpx_client(timeout=timeout)
+        client = get_httpx_client(timeout=timeout)
         with client.stream("POST", url, headers=headers, json=data) as response:
             for line in response.iter_lines():
                 line = line.strip()
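Once the streamed lines are re-joined and parsed, the generated text sits under candidates[0].content.parts[].text in the standard generateContent response. A minimal extraction sketch, with safety-block handling elided:

def extract_text(response_json: dict) -> str:
    candidates = response_json.get("candidates", [])
    if not candidates:
        return ""  # e.g. the prompt was blocked by safety settings
    parts = candidates[0].get("content", {}).get("parts", [])
    return "".join(part.get("text", "") for part in parts)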

View File

@@ -12,7 +12,6 @@ from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
 
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -138,11 +137,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.toast(text)
 
     dialogue_modes = ["LLM 对话",
                       "知识库问答",
                       "文件对话",
                       "搜索引擎问答",
                       "自定义Agent问答",
                       ]
     dialogue_mode = st.selectbox("请选择对话模式:",
                                  dialogue_modes,
                                  index=0,
@@ -166,9 +165,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         available_models = []
         config_models = api.list_config_models()
         if not is_lite:
             for k, v in config_models.get("local", {}).items():  # list models configured with a valid local path
                 if (v.get("model_path_exists")
                         and k not in running_models):
                     available_models.append(k)
         for k, v in config_models.get("online", {}).items():  # list models in ONLINE_MODELS that are accessed directly
             if not v.get("provider") and k not in running_models:
@@ -250,14 +249,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     elif dialogue_mode == "文件对话":
         with st.expander("文件对话配置", True):
             files = st.file_uploader("上传知识文件:",
                                      [i for ls in LOADER_DICT.values() for i in ls],
                                      accept_multiple_files=True,
                                      )
             kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
 
             ## BGE models can score above 1
             score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-            if st.button("开始上传", disabled=len(files)==0):
+            if st.button("开始上传", disabled=len(files) == 0):
                 st.session_state["file_chat_id"] = upload_temp_docs(files, api)
     elif dialogue_mode == "搜索引擎问答":
         search_engine_list = api.list_search_engines()
@@ -279,9 +278,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 "
 
     def on_feedback(
             feedback,
             message_id: str = "",
             history_index: int = -1,
     ):
         reason = feedback["text"]
         score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
@@ -296,7 +295,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     }
 
     if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
         if parse_command(text=prompt, modal=modal):  # the user entered a custom command
             st.rerun()
         else:
             history = get_messages_history(history_len)
@@ -306,11 +305,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = ""
             message_id = ""
             r = api.chat_chat(prompt,
                               history=history,
                               conversation_id=conversation_id,
                               model=llm_model,
                               prompt_name=prompt_template_name,
                               temperature=temperature)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether an error occurred
                     st.error(error_msg)
@@ -321,12 +320,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             metadata = {
                 "message_id": message_id,
             }
             chat_box.update_msg(text, streaming=False, metadata=metadata)  # update the final string and remove the cursor
             chat_box.show_feedback(**feedback_kwargs,
                                    key=message_id,
                                    on_submit=on_feedback,
                                    kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
 
     elif dialogue_mode == "自定义Agent问答":
         if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
@@ -373,13 +372,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         ])
         text = ""
         for d in api.knowledge_base_chat(prompt,
                                          knowledge_base_name=selected_kb,
                                          top_k=kb_top_k,
                                          score_threshold=score_threshold,
                                          history=history,
                                          model=llm_model,
                                          prompt_name=prompt_template_name,
                                          temperature=temperature):
             if error_msg := check_error_msg(d):  # check whether an error occurred
                 st.error(error_msg)
             elif chunk := d.get("answer"):
@@ -397,13 +396,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         ])
         text = ""
         for d in api.file_chat(prompt,
                                knowledge_id=st.session_state["file_chat_id"],
                                top_k=kb_top_k,
                                score_threshold=score_threshold,
                                history=history,
                                model=llm_model,
                                prompt_name=prompt_template_name,
                                temperature=temperature):
             if error_msg := check_error_msg(d):  # check whether an error occurred
                 st.error(error_msg)
             elif chunk := d.get("answer"):