From f639fbc19e99ee3cafe1802f36115a43c0faaf62 Mon Sep 17 00:00:00 2001 From: ceerrep Date: Tue, 25 Feb 2025 14:11:39 +0800 Subject: [PATCH] feat: basic api key support --- ktransformers/server/api/openai/endpoints/chat.py | 3 +++ ktransformers/server/args.py | 1 + ktransformers/server/config/config.py | 1 + 3 files changed, 5 insertions(+) diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py index f84538a..4e91279 100644 --- a/ktransformers/server/api/openai/endpoints/chat.py +++ b/ktransformers/server/api/openai/endpoints/chat.py @@ -25,6 +25,9 @@ async def chat_completion(request:Request,create:ChatCompletionCreate): input_message = [json.loads(m.model_dump_json()) for m in create.messages] + if Config().api_key != '': + assert request.headers.get('Authorization', '').split()[-1] == Config().api_key + if create.stream: async def inner(): chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time())) diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index a9df65b..82bde07 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -10,6 +10,7 @@ class ArgumentParser: parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers") parser.add_argument("--host", type=str, default=self.cfg.server_ip) parser.add_argument("--port", type=int, default=self.cfg.server_port) + parser.add_argument("--api_key", type=str, default=self.cfg.api_key) parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_certfile", type=str) parser.add_argument("--web", type=bool, default=self.cfg.mount_web) diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index 7dc9921..73b92f9 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -69,6 +69,7 @@ class Config(metaclass=Singleton): self.server: dict = cfg.get("server", {}) self.server_ip = self.server.get("ip", "0.0.0.0") self.server_port = self.server.get("port", 9016) + self.api_key = self.server.get("api_key", "") # db configs self.db_configs: dict = cfg.get("db", {})