diff --git a/chatchat-server/chatchat/model_loaders/init_server.py b/chatchat-server/chatchat/model_loaders/init_server.py index ef909af6..78e50965 100644 --- a/chatchat-server/chatchat/model_loaders/init_server.py +++ b/chatchat-server/chatchat/model_loaders/init_server.py @@ -44,8 +44,8 @@ def init_server(model_platforms_shard: Dict, provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager) model_platforms_shard['provider_platforms'] = provider_platforms - - boot.serve(logging_conf=logging_conf) + boot.logging_conf(logging_conf=logging_conf) + boot.run() async def pool_join_thread(): await boot.join() diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py index 3f9918fe..fa4797fc 100644 --- a/model-providers/model_providers/__main__.py +++ b/model-providers/model_providers/__main__.py @@ -36,7 +36,8 @@ if __name__ == "__main__": .build() ) boot.set_app_event(started_event=None) - boot.serve(logging_conf=logging_conf) + boot.logging_conf(logging_conf=logging_conf) + boot.run() async def pool_join_thread(): await boot.join() diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index e7784a83..514d8467 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -18,6 +18,7 @@ from typing import ( cast, ) +import uvicorn from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware from sse_starlette import EventSourceResponse @@ -67,8 +68,13 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): self._port = port self._router = APIRouter() self._app = FastAPI() + self._logging_conf = None + self._server = None self._server_thread = None + def logging_conf(self,logging_conf: Optional[dict] = None): + self._logging_conf = logging_conf + @classmethod def from_config(cls, cfg=None): host = cfg.get("host", "127.0.0.1") @@ -79,7 +85,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) return cls(host=host, port=port) - def serve(self, logging_conf: Optional[dict] = None): + def run(self): self._app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -125,18 +131,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): self._app.include_router(self._router) config = Config( - app=self._app, host=self._host, port=self._port, log_config=logging_conf + app=self._app, host=self._host, port=self._port, log_config=self._logging_conf ) - server = Server(config) + self._server = Server(config) def run_server(): - server.run() + self._server.shutdown_timeout = 2 # 设置为2秒 + + self._server.run() self._server_thread = threading.Thread(target=run_server) self._server_thread.start() - async def join(self): - await self._server_thread.join() + def destroy(self): + + logger.info("Shutting down server") + self._server.should_exit = True # 设置退出标志 + self._server.shutdown() # 停止服务器 + + self.join() + + def join(self): + self._server_thread.join() + def set_app_event(self, started_event: mp.Event = None): @self._app.on_event("startup") @@ -273,10 +290,11 @@ def run( cfg=cfg.get("run_openai_api", {}) ) api.set_app_event(started_event=started_event) - api.serve(logging_conf=logging_conf) + api.logging_conf(logging_conf=logging_conf) + api.run() async def pool_join_thread(): - await api.join() + api.join() asyncio.run(pool_join_thread()) except SystemExit: diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py index 4f7d4be3..4b27cb9b 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py @@ -24,6 +24,11 @@ class _CommonZhipuaiAI: else credentials["zhipuai_api_key"] if "zhipuai_api_key" in credentials else None, + "api_base": credentials["api_base"] + if "api_key" in credentials + else credentials["zhipuai_api_base"] + if "zhipuai_api_base" in credentials + else None, } return credentials_kwargs diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py index d0ec7987..b987e6fc 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -195,7 +195,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if stop: extra_model_kwargs["stop"] = stop - client = ZhipuAI(api_key=credentials_kwargs["api_key"]) + client = ZhipuAI(base_url=credentials_kwargs["api_base"], + api_key=credentials_kwargs["api_key"]) if len(prompt_messages) == 0: raise ValueError("At least one message is required") diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index cd870b11..9e12b53c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -43,7 +43,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI(api_key=credentials_kwargs["api_key"]) + client = ZhipuAI(base_url=credentials_kwargs["api_base"], + api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) @@ -84,7 +85,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI(api_key=credentials_kwargs["api_key"]) + client = ZhipuAI(base_url=credentials_kwargs["api_base"], + api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml index 303a5491..c4e526ac 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml @@ -29,3 +29,12 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 APIKey en_US: Enter your APIKey + - variable: api_base + label: + zh_Hans: API Base + en_US: API Base + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base \ No newline at end of file diff --git a/model-providers/pyproject.toml b/model-providers/pyproject.toml index 859006de..ba4d923d 100644 --- a/model-providers/pyproject.toml +++ b/model-providers/pyproject.toml @@ -34,8 +34,8 @@ pytest-asyncio = "^0.21.1" grandalf = "^0.8" pytest-profiling = "^1.7.0" responses = "^0.25.0" - - +langchain = "0.1.5" +langchain-openai = "0.0.5" [tool.poetry.group.lint] optional = true diff --git a/model-providers/tests/server_unit_test/conftest.py b/model-providers/tests/conftest.py similarity index 77% rename from model-providers/tests/server_unit_test/conftest.py rename to model-providers/tests/conftest.py index eea02a65..a4508b81 100644 --- a/model-providers/tests/server_unit_test/conftest.py +++ b/model-providers/tests/conftest.py @@ -6,6 +6,7 @@ from typing import Dict, List, Sequence import pytest from pytest import Config, Function, Parser +from model_providers import BootstrapWebBuilder from model_providers.core.utils.utils import ( get_config_dict, get_log_file, @@ -102,3 +103,41 @@ def logging_conf() -> dict: 122, 111, ) + +@pytest.fixture +def providers_file(request) -> str: + from pathlib import Path + import os + # 当前执行目录 + # 获取当前测试文件的路径 + test_file_path = Path(str(request.fspath)).parent + print("test_file_path:",test_file_path) + return os.path.join(test_file_path,"model_providers.yaml") + + +@pytest.fixture +@pytest.mark.requires("fastapi") +def init_server(logging_conf: dict, providers_file: str) -> None: + try: + boot = ( + BootstrapWebBuilder() + .model_providers_cfg_path( + model_providers_cfg_path=providers_file + ) + .host(host="127.0.0.1") + .port(port=20000) + .build() + ) + boot.set_app_event(started_event=None) + boot.logging_conf(logging_conf=logging_conf) + boot.run() + + try: + yield f"http://127.0.0.1:20000" + finally: + print("") + boot.destroy() + + except SystemExit: + + raise diff --git a/model-providers/tests/openai_providers_test/model_providers.yaml b/model-providers/tests/openai_providers_test/model_providers.yaml new file mode 100644 index 00000000..b98d2924 --- /dev/null +++ b/model-providers/tests/openai_providers_test/model_providers.yaml @@ -0,0 +1,5 @@ +openai: + provider_credential: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' diff --git a/model-providers/tests/openai_providers_test/test_openai_service.py b/model-providers/tests/openai_providers_test/test_openai_service.py new file mode 100644 index 00000000..6cdb2731 --- /dev/null +++ b/model-providers/tests/openai_providers_test/test_openai_service.py @@ -0,0 +1,22 @@ +from langchain.chains import LLMChain +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI +import pytest +import logging + +logger = logging.getLogger(__name__) + +@pytest.mark.requires("openai") +def test_llm(init_server: str): + llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1") + template = """Question: {question} + + Answer: Let's think step by step.""" + + prompt = PromptTemplate.from_template(template) + + llm_chain = LLMChain(prompt=prompt, llm=llm) + responses = llm_chain.run("你好") + logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m") + + diff --git a/model-providers/tests/server_unit_test/test_init_server.py b/model-providers/tests/server_unit_test/test_init_server.py deleted file mode 100644 index 96210b89..00000000 --- a/model-providers/tests/server_unit_test/test_init_server.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio -import logging - -import pytest - -from model_providers import BootstrapWebBuilder - -logger = logging.getLogger(__name__) - - -@pytest.mark.requires("fastapi") -def test_init_server(logging_conf: dict) -> None: - try: - boot = ( - BootstrapWebBuilder() - .model_providers_cfg_path( - model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" - "/model_providers.yaml" - ) - .host(host="127.0.0.1") - .port(port=20000) - .build() - ) - boot.set_app_event(started_event=None) - boot.serve(logging_conf=logging_conf) - - async def pool_join_thread(): - await boot.join() - - asyncio.run(pool_join_thread()) - except SystemExit: - logger.info("SystemExit raised, exiting") - raise diff --git a/model-providers/tests/zhipuai_providers_test/model_providers.yaml b/model-providers/tests/zhipuai_providers_test/model_providers.yaml new file mode 100644 index 00000000..dd9a6fa7 --- /dev/null +++ b/model-providers/tests/zhipuai_providers_test/model_providers.yaml @@ -0,0 +1,4 @@ +zhipuai: + provider_credential: + api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1' + api_base: 'https://test.bigmodel.cn/stage-api/paas/v4' \ No newline at end of file diff --git a/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py b/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py new file mode 100644 index 00000000..94992e3b --- /dev/null +++ b/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py @@ -0,0 +1,42 @@ +from langchain.chains import LLMChain +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI +import pytest +import logging + +logger = logging.getLogger(__name__) + +@pytest.mark.requires("zhipuai") +def test_llm(init_server: str): + llm = ChatOpenAI( + + model_name="glm-4", + openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1") + template = """Question: {question} + + Answer: Let's think step by step.""" + + prompt = PromptTemplate.from_template(template) + + llm_chain = LLMChain(prompt=prompt, llm=llm) + responses = llm_chain.run("你好") + logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m") + + +@pytest.mark.requires("zhipuai") +def test_embedding(init_server: str): + llm = ChatOpenAI( + + model_name="glm-4", + openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1") + template = """Question: {question} + + Answer: Let's think step by step.""" + + prompt = PromptTemplate.from_template(template) + + llm_chain = LLMChain(prompt=prompt, llm=llm) + responses = llm_chain.run("你好") + logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m") + +