Merge pull request #2752 from zRzRzRzRzRzRzR/dev

Gemini API fix
zR 2024-01-22 13:15:06 +08:00 committed by GitHub
commit 54e5b41647
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 60 deletions

View File

@@ -92,11 +92,10 @@ FSCHAT_MODEL_WORKERS = {
# 'disable_log_requests': False
},
# The default configuration can be overridden per model as in the examples below
# "Qwen-1_8B-Chat": { # uses the IP and port from "default"
# "device": "cpu",
# },
"chatglm3-6b": { # 使用default中的IP和端口
"Qwen-1_8B-Chat": {
"device": "cpu",
},
"chatglm3-6b": {
"device": "cuda",
},
@@ -129,7 +128,7 @@ FSCHAT_MODEL_WORKERS = {
"port": 21009,
},
"gemini-api": {
"port": 21012,
"port": 21010,
},
}
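
For context, entries in FSCHAT_MODEL_WORKERS are merged over the "default" entry (the removed comments note that per-model blocks reuse the default IP and port), so "gemini-api" only needs to override the port fixed in this commit. A minimal sketch of that merge, with an illustrative get_worker_config helper (not the project's actual loader) and made-up default values:

FSCHAT_MODEL_WORKERS = {
    "default": {"host": "127.0.0.1", "port": 20002, "device": "cuda"},  # illustrative defaults
    "chatglm3-6b": {"device": "cuda"},
    "gemini-api": {"port": 21010},  # port changed in this commit
}

def get_worker_config(model_name: str) -> dict:
    # Start from the shared defaults, then apply per-model overrides.
    config = dict(FSCHAT_MODEL_WORKERS["default"])
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
    return config

print(get_worker_config("gemini-api"))  # {'host': '127.0.0.1', 'port': 21010, 'device': 'cuda'}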

View File

@@ -40,7 +40,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
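
For context, validate_kb_name guards the knowledge base name before the 403 above is returned. A minimal sketch of such a path-traversal check, assuming the project's real implementation may differ:

def validate_kb_name(knowledge_base_name: str) -> bool:
    # Reject names that could escape the knowledge base directory.
    if "../" in knowledge_base_name:
        return False
    return True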

View File

@@ -3,7 +3,7 @@ from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json,httpx
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
@@ -14,14 +14,14 @@ class GeminiWorker(ApiModelWorker):
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["Gemini-api"],
model_names: List[str] = ["gemini-api"],
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)
def create_gemini_messages(self,messages) -> json:
def create_gemini_messages(self, messages) -> json:
has_history = any(msg['role'] == 'assistant' for msg in messages)
gemini_msg = []
@@ -46,7 +46,7 @@ class GeminiWorker(ApiModelWorker):
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = self.create_gemini_messages(messages=params.messages)
generationConfig=dict(
generationConfig = dict(
temperature=params.temperature,
topK=1,
topP=1,
@@ -55,7 +55,7 @@ class GeminiWorker(ApiModelWorker):
)
data['generationConfig'] = generationConfig
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
headers = {
'Content-Type': 'application/json',
}
@@ -67,7 +67,7 @@ class GeminiWorker(ApiModelWorker):
text = ""
json_string = ""
timeout = httpx.Timeout(60.0)
client=get_httpx_client(timeout=timeout)
client = get_httpx_client(timeout=timeout)
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
line = line.strip()
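
For context, do_chat above posts the assembled messages and generationConfig to the Gemini v1beta generateContent endpoint. A minimal standalone sketch of that request, where the API key and prompt text are placeholders and the worker's streaming plus line-by-line parsing are omitted for brevity:

import httpx

api_key = "YOUR_API_KEY"  # placeholder
url = ("https://generativelanguage.googleapis.com/v1beta/models/"
       "gemini-pro:generateContent" + "?key=" + api_key)
data = {
    # Gemini expects chat turns under "contents", each with a role and text parts.
    "contents": [{"role": "user", "parts": [{"text": "Hello"}]}],
    "generationConfig": {"temperature": 0.7, "topK": 1, "topP": 1, "maxOutputTokens": 4096},
}
headers = {"Content-Type": "application/json"}
with httpx.Client(timeout=httpx.Timeout(60.0)) as client:
    response = client.post(url, headers=headers, json=data)
    print(response.json())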

View File

@@ -12,7 +12,6 @@ from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@@ -257,7 +256,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
## BGE models can score above 1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
if st.button("开始上传", disabled=len(files)==0):
if st.button("开始上传", disabled=len(files) == 0):
st.session_state["file_chat_id"] = upload_temp_docs(files, api)
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
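
For context, the hunk above only adds spacing around `==`, which gates the upload button on whether any files are selected. A minimal Streamlit sketch of that pattern, with labels translated to English, simplified names, and a placeholder SCORE_THRESHOLD default:

import streamlit as st

SCORE_THRESHOLD = 1.0  # placeholder default; BGE-style scores can exceed 1, hence the 0.0-2.0 range
score_threshold = st.slider("Knowledge match score threshold:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
files = st.file_uploader("Select files", accept_multiple_files=True) or []
if st.button("Start upload", disabled=len(files) == 0):
    st.write(f"Uploading {len(files)} file(s), score threshold {score_threshold}")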