gemini api: fix the API call

zR 2024-01-22 13:14:13 +08:00
parent b6d2bc71ce
commit 17803cb7c1
4 changed files with 58 additions and 60 deletions


@@ -92,11 +92,10 @@ FSCHAT_MODEL_WORKERS = {
         #     'disable_log_requests': False
     },
-    # The default config can be overridden as in the following example
-    # "Qwen-1_8B-Chat": {  # uses the IP and port from "default"
-    #     "device": "cpu",
-    # },
-    "chatglm3-6b": {  # uses the IP and port from "default"
+    "Qwen-1_8B-Chat": {
+        "device": "cpu",
+    },
+    "chatglm3-6b": {
         "device": "cuda",
     },
@@ -129,7 +128,7 @@ FSCHAT_MODEL_WORKERS = {
         "port": 21009,
     },
     "gemini-api": {
-        "port": 21012,
+        "port": 21010,
     },
 }
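For orientation, here is a minimal sketch of how the affected FSCHAT_MODEL_WORKERS entries read after this commit; only the keys visible in the diff are included, and everything else is still inherited from the "default" worker entry.

```python
# Sketch of the relevant FSCHAT_MODEL_WORKERS entries after this commit.
# Only keys visible in the diff are shown; all other settings fall back
# to the "default" worker configuration.
FSCHAT_MODEL_WORKERS = {
    "Qwen-1_8B-Chat": {
        "device": "cpu",    # local worker pinned to CPU
    },
    "chatglm3-6b": {
        "device": "cuda",   # local worker on GPU
    },
    "gemini-api": {
        "port": 21010,      # changed from 21012, directly after the 21009 worker above
    },
}
```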


@@ -40,7 +40,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
 def delete_kb(
         knowledge_base_name: str = Body(..., examples=["samples"])
 ) -> BaseResponse:
     # Delete selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
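delete_kb only proceeds once the name passes validate_kb_name; otherwise it answers with a 403 BaseResponse. As an illustration only (the project's actual helper lives elsewhere in the repo and may apply different rules), a minimal sketch of such a guard:

```python
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Illustrative only: reject names that could traverse out of the
    knowledge-base directory. The project's real validate_kb_name may
    implement additional checks."""
    if "../" in knowledge_base_id:
        return False
    return True
```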


@@ -3,7 +3,7 @@ from fastchat.conversation import Conversation
 from server.model_workers.base import *
 from server.utils import get_httpx_client
 from fastchat import conversation as conv
-import json,httpx
+import json, httpx
 from typing import List, Dict
 from configs import logger, log_verbose
@@ -14,14 +14,14 @@ class GeminiWorker(ApiModelWorker):
             *,
             controller_addr: str = None,
             worker_addr: str = None,
-            model_names: List[str] = ["Gemini-api"],
+            model_names: List[str] = ["gemini-api"],
             **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
         kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)

-    def create_gemini_messages(self,messages) -> json:
+    def create_gemini_messages(self, messages) -> json:
         has_history = any(msg['role'] == 'assistant' for msg in messages)
         gemini_msg = []
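create_gemini_messages turns the OpenAI-style message list into the shape Google's generateContent endpoint expects. A hedged sketch of that target structure, using a hypothetical helper name since the worker's full implementation is not part of this diff:

```python
from typing import Dict, List

def to_gemini_contents(messages: List[Dict]) -> Dict:
    """Hypothetical helper mirroring what create_gemini_messages produces:
    the Gemini REST API takes a "contents" list of {role, parts} items,
    with OpenAI-style roles mapped as user -> "user", assistant -> "model"."""
    role_map = {"user": "user", "assistant": "model"}
    contents = []
    for msg in messages:
        if msg["role"] not in role_map:   # e.g. system prompts need separate handling
            continue
        contents.append({
            "role": role_map[msg["role"]],
            "parts": [{"text": msg["content"]}],
        })
    return {"contents": contents}

# Example:
# to_gemini_contents([{"role": "user", "content": "Hi"}])
# -> {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]}
```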
@@ -46,7 +46,7 @@ class GeminiWorker(ApiModelWorker):
     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
         data = self.create_gemini_messages(messages=params.messages)
-        generationConfig=dict(
+        generationConfig = dict(
             temperature=params.temperature,
             topK=1,
             topP=1,
@@ -55,7 +55,7 @@ class GeminiWorker(ApiModelWorker):
         )
         data['generationConfig'] = generationConfig
-        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
+        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
         headers = {
             'Content-Type': 'application/json',
         }
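Combined with the converted messages, do_chat ends up posting a JSON body of "contents" plus "generationConfig" to the keyed URL above. A standalone sketch of that request shape, with placeholder text and temperature; only the fields visible in the diff are included:

```python
import json

# Illustrative payload assembled the same way as do_chat: the converted
# message "contents" plus the generationConfig shown in the diff.
payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hello"}]},
    ],
    "generationConfig": {
        "temperature": 0.7,  # placeholder; the worker uses params.temperature
        "topK": 1,
        "topP": 1,
    },
}

api_key = "YOUR_API_KEY"  # placeholder; the worker reads params.api_key
url = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    "gemini-pro:generateContent?key=" + api_key
)
headers = {"Content-Type": "application/json"}
body = json.dumps(payload)
```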
@@ -67,7 +67,7 @@ class GeminiWorker(ApiModelWorker):
         text = ""
         json_string = ""
         timeout = httpx.Timeout(60.0)
-        client=get_httpx_client(timeout=timeout)
+        client = get_httpx_client(timeout=timeout)
         with client.stream("POST", url, headers=headers, json=data) as response:
             for line in response.iter_lines():
                 line = line.strip()
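The response is consumed as a stream and buffered line by line into json_string before being decoded. A hedged sketch of that loop; anything beyond what the diff shows (skipping empty lines, using a plain httpx.Client instead of get_httpx_client) is an assumption:

```python
import httpx

def stream_generate(url: str, headers: dict, data: dict) -> str:
    """Sketch of the streaming loop in do_chat: POST the payload and buffer
    the response body line by line. The worker builds its client through
    get_httpx_client(); a plain httpx.Client is used here instead."""
    json_string = ""
    timeout = httpx.Timeout(60.0)
    with httpx.Client(timeout=timeout) as client:
        with client.stream("POST", url, headers=headers, json=data) as response:
            for line in response.iter_lines():
                line = line.strip()
                if not line:
                    continue
                json_string += line  # accumulate until the full JSON document arrives
    return json_string
```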


@@ -12,7 +12,6 @@ from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict

 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -257,7 +256,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 ## Bge model scores can exceed 1
                 score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

-            if st.button("开始上传", disabled=len(files)==0):
+            if st.button("开始上传", disabled=len(files) == 0):
                 st.session_state["file_chat_id"] = upload_temp_docs(files, api)
         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
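For context, the sidebar logic touched here lets the user pick a score threshold (the slider goes up to 2.0 because, as the comment notes, Bge model scores can exceed 1) and then upload files, storing the returned temp-doc id in the session state. A minimal standalone Streamlit sketch with stand-ins for the project's helpers:

```python
import streamlit as st

SCORE_THRESHOLD = 1.0  # assumed default; the project imports it from configs

def upload_temp_docs(files, api=None):
    """Stand-in for the project's helper: upload files and return a temp chat id."""
    return "demo-file-chat-id"

files = st.file_uploader("上传文件:", accept_multiple_files=True) or []
# Bge model scores can exceed 1, hence the 0.0-2.0 slider range
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
if st.button("开始上传", disabled=len(files) == 0):
    st.session_state["file_chat_id"] = upload_temp_docs(files)
```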