不同平台兼容测试用例

This commit is contained in:
glide-the 2024-04-16 14:19:36 +08:00
parent 99a2be6970
commit f30bdf01fc
14 changed files with 164 additions and 49 deletions

View File

@ -44,8 +44,8 @@ def init_server(model_platforms_shard: Dict,
provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
model_platforms_shard['provider_platforms'] = provider_platforms
boot.serve(logging_conf=logging_conf)
boot.logging_conf(logging_conf=logging_conf)
boot.run()
async def pool_join_thread():
await boot.join()

View File

@ -36,7 +36,8 @@ if __name__ == "__main__":
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
boot.logging_conf(logging_conf=logging_conf)
boot.run()
async def pool_join_thread():
await boot.join()

View File

@ -18,6 +18,7 @@ from typing import (
cast,
)
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
@ -67,8 +68,13 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._port = port
self._router = APIRouter()
self._app = FastAPI()
self._logging_conf = None
self._server = None
self._server_thread = None
def logging_conf(self,logging_conf: Optional[dict] = None):
self._logging_conf = logging_conf
@classmethod
def from_config(cls, cfg=None):
host = cfg.get("host", "127.0.0.1")
@ -79,7 +85,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
return cls(host=host, port=port)
def serve(self, logging_conf: Optional[dict] = None):
def run(self):
self._app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@ -125,18 +131,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._app.include_router(self._router)
config = Config(
app=self._app, host=self._host, port=self._port, log_config=logging_conf
app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
)
server = Server(config)
self._server = Server(config)
def run_server():
server.run()
self._server.shutdown_timeout = 2 # 设置为2秒
self._server.run()
self._server_thread = threading.Thread(target=run_server)
self._server_thread.start()
async def join(self):
await self._server_thread.join()
def destroy(self):
logger.info("Shutting down server")
self._server.should_exit = True # 设置退出标志
self._server.shutdown() # 停止服务器
self.join()
def join(self):
self._server_thread.join()
def set_app_event(self, started_event: mp.Event = None):
@self._app.on_event("startup")
@ -273,10 +290,11 @@ def run(
cfg=cfg.get("run_openai_api", {})
)
api.set_app_event(started_event=started_event)
api.serve(logging_conf=logging_conf)
api.logging_conf(logging_conf=logging_conf)
api.run()
async def pool_join_thread():
await api.join()
api.join()
asyncio.run(pool_join_thread())
except SystemExit:

View File

@ -24,6 +24,11 @@ class _CommonZhipuaiAI:
else credentials["zhipuai_api_key"]
if "zhipuai_api_key" in credentials
else None,
"api_base": credentials["api_base"]
if "api_key" in credentials
else credentials["zhipuai_api_base"]
if "zhipuai_api_base" in credentials
else None,
}
return credentials_kwargs

View File

@ -195,7 +195,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if stop:
extra_model_kwargs["stop"] = stop
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
if len(prompt_messages) == 0:
raise ValueError("At least one message is required")

View File

@ -43,7 +43,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
@ -84,7 +85,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try:
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
# call embedding model
self.embed_documents(

View File

@ -29,3 +29,12 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 APIKey
en_US: Enter your APIKey
- variable: api_base
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base

View File

@ -34,8 +34,8 @@ pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
langchain = "0.1.5"
langchain-openai = "0.0.5"
[tool.poetry.group.lint]
optional = true

View File

@ -6,6 +6,7 @@ from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
from model_providers import BootstrapWebBuilder
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
@ -102,3 +103,41 @@ def logging_conf() -> dict:
122,
111,
)
@pytest.fixture
def providers_file(request) -> str:
from pathlib import Path
import os
# 当前执行目录
# 获取当前测试文件的路径
test_file_path = Path(str(request.fspath)).parent
print("test_file_path:",test_file_path)
return os.path.join(test_file_path,"model_providers.yaml")
@pytest.fixture
@pytest.mark.requires("fastapi")
def init_server(logging_conf: dict, providers_file: str) -> None:
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=providers_file
)
.host(host="127.0.0.1")
.port(port=20000)
.build()
)
boot.set_app_event(started_event=None)
boot.logging_conf(logging_conf=logging_conf)
boot.run()
try:
yield f"http://127.0.0.1:20000"
finally:
print("")
boot.destroy()
except SystemExit:
raise

View File

@ -0,0 +1,5 @@
openai:
provider_credential:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''

View File

@ -0,0 +1,22 @@
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
llm_chain = LLMChain(prompt=prompt, llm=llm)
responses = llm_chain.run("你好")
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")

View File

@ -1,33 +0,0 @@
import asyncio
import logging
import pytest
from model_providers import BootstrapWebBuilder
logger = logging.getLogger(__name__)
@pytest.mark.requires("fastapi")
def test_init_server(logging_conf: dict) -> None:
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
"/model_providers.yaml"
)
.host(host="127.0.0.1")
.port(port=20000)
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
async def pool_join_thread():
await boot.join()
asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")
raise

View File

@ -0,0 +1,4 @@
zhipuai:
provider_credential:
api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1'
api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'

View File

@ -0,0 +1,42 @@
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("zhipuai")
def test_llm(init_server: str):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
llm_chain = LLMChain(prompt=prompt, llm=llm)
responses = llm_chain.run("你好")
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
@pytest.mark.requires("zhipuai")
def test_embedding(init_server: str):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
llm_chain = LLMChain(prompt=prompt, llm=llm)
responses = llm_chain.run("你好")
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")