From 448e79576c848a542c342eaf372647d0db32e079 Mon Sep 17 00:00:00 2001 From: Art Jiang Date: Mon, 15 Apr 2024 17:32:48 -0700 Subject: [PATCH] Adding Claude 3 API support (#3340) --------- Co-authored-by: Jiang, Fengyi --- configs/model_config.py.example | 11 +++ configs/server_config.py.example | 3 + server/model_workers/__init__.py | 3 +- server/model_workers/claude.py | 130 +++++++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 server/model_workers/claude.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index c42066d3..a0753663 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -125,6 +125,17 @@ ONLINE_LLM_MODEL = { "api_key": "", "provider": "GeminiWorker", }, + # Claude API : https://www.anthropic.com/api + # Available models: + # Claude 3 Opus: claude-3-opus-20240229 + # Claude 3 Sonnet: claude-3-sonnet-20240229 + # Claude 3 Haiku: claude-3-haiku-20240307 + "claude-api": { + "api_key": "", + "version": "2023-06-01", + "model_name": "claude-3-opus-20240229", + "provider": "ClaudeWorker", + } } diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 9bbb8b49..56b63680 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -127,6 +127,9 @@ FSCHAT_MODEL_WORKERS = { "gemini-api": { "port": 21010, }, + "claude-api": { + "port": 21011, + }, } FSCHAT_CONTROLLER = { diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py index d0320f41..6991fd0b 100644 --- a/server/model_workers/__init__.py +++ b/server/model_workers/__init__.py @@ -8,4 +8,5 @@ from .qwen import QwenWorker from .baichuan import BaiChuanWorker from .azure import AzureWorker from .tiangong import TianGongWorker -from .gemini import GeminiWorker \ No newline at end of file +from .gemini import GeminiWorker +from .claude import ClaudeWorker \ No newline at end of file diff --git a/server/model_workers/claude.py 
import json
import sys
from typing import Dict, List

import httpx
import uvicorn
from fastchat.conversation import Conversation

from configs import log_verbose, logger
from server.model_workers.base import *
from server.utils import get_httpx_client


class ClaudeWorker(ApiModelWorker):
    """fastchat model worker that forwards chat requests to the Anthropic
    Claude Messages API (POST https://api.anthropic.com/v1/messages)."""

    def __init__(
        self,
        *,
        controller_addr: str = None,
        worker_addr: str = None,
        model_names: List[str] = ["claude-api"],  # mutable default kept for parity with sibling workers
        version: str = "2023-06-01",
        **kwargs,
    ):
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        kwargs.setdefault("context_len", 1024)
        super().__init__(**kwargs)
        # Sent as the "anthropic-version" request header (the original
        # hard-coded the header value and ignored this attribute).
        self.version = version

    def create_claude_messages(self, params: ApiChatParams) -> Dict:
        """Build the JSON request body for Claude's /v1/messages endpoint.

        System messages are skipped: Claude only accepts "user"/"assistant"
        roles inside ``messages``.  Fixes over the original:
        * assistant turns are no longer renamed to "model" — that is
          Gemini's role vocabulary and the Claude API rejects it;
        * ``"stream": True`` is set, since do_chat parses the SSE stream.
        """
        claude_msg = {
            "model": params.model_name,
            # NOTE(review): max_tokens is the *output* budget; reusing
            # context_len (default 1024) mirrors the original behavior.
            "max_tokens": params.context_len,
            "stream": True,
            "messages": [],
        }
        for msg in params.messages:
            if msg["role"] == "system":
                continue
            claude_msg["messages"].append(
                {"role": msg["role"], "content": msg["content"]})
        return claude_msg

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a chat completion from Claude.

        Yields dicts ``{"error_code": int, "text": str}`` where ``text`` is
        the cumulative completion so far; a non-zero ``error_code`` carries
        a failure message instead.
        """
        data = self.create_claude_messages(params)
        url = "https://api.anthropic.com/v1/messages"
        headers = {
            "anthropic-version": self.version,  # was hard-coded "2023-06-01"
            "anthropic-beta": "messages-2023-12-15",
            "Content-Type": "application/json",
            "x-api-key": params.api_key,
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        text = ""
        # Fix: the original built a client with a 60s timeout, then
        # immediately overwrote it with a default-timeout client, and never
        # closed either one.
        with get_httpx_client(timeout=httpx.Timeout(60.0)) as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                if response.status_code != 200:
                    response.read()  # must drain a streamed response before .text
                    logger.error(f"Failed to get response: {response.text}")
                    yield {
                        "error_code": response.status_code,
                        "text": "Failed to communicate with Claude API.",
                    }
                    return
                for line in response.iter_lines():
                    line = line.strip()
                    # SSE frames arrive as "event: <type>" / "data: {json}".
                    # Fix: the original passed every raw line to json.loads,
                    # which fails on "event:" lines and on the "data:" prefix.
                    if not line.startswith("data:"):
                        continue
                    try:
                        event = json.loads(line[len("data:"):].strip())
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to decode JSON: {e}; line: {line}")
                        continue
                    event_type = event.get("type")
                    if event_type == "content_block_delta":
                        text += event.get("delta", {}).get("text", "")
                        yield {"error_code": 0, "text": text}
                    elif event_type == "message_stop":
                        # Message complete.  Other event types (message_start,
                        # ping, content_block_start/stop) are silently ignored
                        # instead of being reported as API errors, as the
                        # original's misplaced else-branch did.
                        yield {"error_code": 0, "text": text}
                        return

    def get_embeddings(self, params):
        # Claude exposes no embeddings endpoint; placeholder kept for
        # interface parity with the other workers in this package.
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: List[Dict[str, str]] = None,
                           model_path: str = None) -> Conversation:
        """Return the default fastchat conversation template for Claude."""
        if conv_template is None:
            conv_template = [
                {"role": "user", "content": "Hello there."},
                {"role": "assistant", "content": "Hi, I'm Claude. How can I help you?"},
                {"role": "user", "content": "Can you explain LLMs in plain English?"},
            ]
        return Conversation(
            name=self.model_names[0],
            system_message="You are Claude, a helpful, respectful, and honest assistant.",
            messages=conv_template,
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    # Stand-alone launch for local testing; ports match server_config.
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    worker = ClaudeWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21011",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21011)