From 03eb5e9d2ec5720b6872eb21c0b3715a94faaecb Mon Sep 17 00:00:00 2001
From: yhfgyyf <49861869+yhfgyyf@users.noreply.github.com>
Date: Fri, 12 Jan 2024 10:16:31 +0800
Subject: [PATCH] Gemini api (#2630)

* Gemini-pro api

* Update gemini.py

* Update gemini.py
---
 configs/model_config.py.example  |   6 ++
 configs/server_config.py.example |   3 +
 server/model_workers/__init__.py |   1 +
 server/model_workers/gemini.py   | 124 +++++++++++++++++++++++++++++++
 4 files changed, 134 insertions(+)
 create mode 100644 server/model_workers/gemini.py

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index b203e933..9be56953 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -56,6 +56,12 @@ ONLINE_LLM_MODEL = {
         "openai_proxy": "",
     },
 
+    # To get an api_key, visit https://makersuite.google.com/ or Google Cloud; confirm network access before use, and if you use a proxy, set the https_proxy environment variable in the environment the project is started from (python startup.py -a)
+    "gemini-api": {
+        "api_key": "",
+        "provider": "GeminiWorker",
+    },
+
     # For registration and api key retrieval, visit http://open.bigmodel.cn
     "zhipu-api": {
         "api_key": "",
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 7fa0c412..2f51c3ad 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -128,6 +128,9 @@ FSCHAT_MODEL_WORKERS = {
     "tiangong-api": {
         "port": 21009,
     },
+    "gemini-api": {
+        "port": 21012,
+    },
 }
 
 # fastchat multi model worker server
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
index 67c9aa63..d0320f41 100644
--- a/server/model_workers/__init__.py
+++ b/server/model_workers/__init__.py
@@ -8,3 +8,4 @@ from .qwen import QwenWorker
 from .baichuan import BaiChuanWorker
 from .azure import AzureWorker
 from .tiangong import TianGongWorker
+from .gemini import GeminiWorker
\ No newline at end of file
diff --git a/server/model_workers/gemini.py b/server/model_workers/gemini.py
new file mode 100644
index 00000000..46130212
--- /dev/null
+++ b/server/model_workers/gemini.py
@@ -0,0 +1,124 @@
+import json
+import sys
+
+import httpx
+from fastchat import conversation as conv
+from fastchat.conversation import Conversation
+from typing import Dict, List
+
+from configs import logger, log_verbose
+from server.model_workers.base import *
+from server.utils import get_httpx_client
+
+
+class GeminiWorker(ApiModelWorker):
+    def __init__(
+            self,
+            *,
+            controller_addr: str = None,
+            worker_addr: str = None,
+            model_names: List[str] = ["Gemini-api"],
+            **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 4096)  # TODO: change to 16384 for 16K models
+        super().__init__(**kwargs)
+
+    def create_gemini_messages(self, messages) -> Dict:
+        # Gemini has no system role, and multi-turn history uses
+        # "user"/"model" roles instead of "user"/"assistant".
+        has_history = any(msg['role'] == 'assistant' for msg in messages)
+        gemini_msg = []
+
+        for msg in messages:
+            role = msg['role']
+            content = msg['content']
+            if role == 'system':
+                continue
+            if has_history:
+                if role == 'assistant':
+                    role = "model"
+                gemini_msg.append({"role": role, "parts": [{"text": content}]})
+            elif role == 'user':
+                gemini_msg.append({"parts": [{"text": content}]})
+
+        return dict(contents=gemini_msg)
+
+    def do_chat(self, params: ApiChatParams) -> Dict:
+        params.load_config(self.model_names[0])
+        data = self.create_gemini_messages(messages=params.messages)
+        generation_config = dict(
+            temperature=params.temperature,
+            topK=1,
+            topP=1,
+            maxOutputTokens=4096,
+            stopSequences=[],
+        )
+        data['generationConfig'] = generation_config
+
+        url = ("https://generativelanguage.googleapis.com/v1beta/models/"
+               "gemini-pro:generateContent?key=" + params.api_key)
+        headers = {
+            'Content-Type': 'application/json',
+        }
+        if log_verbose:
+            logger.info(f'{self.__class__.__name__}:url: {url}')
+            logger.info(f'{self.__class__.__name__}:headers: {headers}')
+            logger.info(f'{self.__class__.__name__}:data: {data}')
+
+        text = ""
+        json_string = ""
+        timeout = httpx.Timeout(60.0)
+        client = get_httpx_client(timeout=timeout)
+        with client.stream("POST", url, headers=headers, json=data) as response:
+            for line in response.iter_lines():
+                line = line.strip()
+                if not line or "[DONE]" in line:
+                    continue
+
+                # generateContent returns a single JSON body spread over
+                # several lines; buffer until the accumulated string parses.
+                json_string += line
+                try:
+                    resp = json.loads(json_string)
+                except json.JSONDecodeError:
+                    continue
+
+                for candidate in resp.get('candidates', []):
+                    content = candidate.get('content', {})
+                    for part in content.get('parts', []):
+                        if 'text' in part:
+                            text += part['text']
+                    yield {
+                        "error_code": 0,
+                        "text": text,
+                    }
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="You are a helpful, respectful and honest assistant.",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.base_model_worker import app
+
+    worker = GeminiWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:21012",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=21012)