mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
Compatibility test cases for different platforms
This commit is contained in:
parent 99a2be6970
commit f30bdf01fc
@@ -44,8 +44,8 @@ def init_server(model_platforms_shard: Dict,
     provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
     model_platforms_shard['provider_platforms'] = provider_platforms

-    boot.serve(logging_conf=logging_conf)
+    boot.logging_conf(logging_conf=logging_conf)
+    boot.run()

     async def pool_join_thread():
         await boot.join()
@@ -36,7 +36,8 @@ if __name__ == "__main__":
         .build()
     )
     boot.set_app_event(started_event=None)
-    boot.serve(logging_conf=logging_conf)
+    boot.logging_conf(logging_conf=logging_conf)
+    boot.run()

     async def pool_join_thread():
         await boot.join()
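Both hunks above make the same substitution: the single boot.serve(logging_conf=...) call is split into boot.logging_conf(...) followed by boot.run(). A minimal sketch of the resulting startup sequence, assuming a BootstrapWebBuilder configured as elsewhere in this diff (config path, host, and port are illustrative):

# Sketch only: the split lifecycle introduced by this commit.
boot = (
    BootstrapWebBuilder()
    .model_providers_cfg_path(model_providers_cfg_path="model_providers.yaml")
    .host(host="127.0.0.1")
    .port(port=20000)
    .build()
)
boot.set_app_event(started_event=None)
boot.logging_conf(logging_conf=logging_conf)  # store the log config first
boot.run()                                    # then start the background server
boot.join()                                   # join() becomes synchronous below

Note that both call sites keep the async pool_join_thread() wrapper with `await boot.join()` as unchanged context; if boot exposes the same now-synchronous join() introduced later in this commit, those awaits would fail with a TypeError at runtime.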
@@ -18,6 +18,7 @@ from typing import (
     cast,
 )
+
 import uvicorn
 from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette import EventSourceResponse
@@ -67,8 +68,13 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         self._port = port
         self._router = APIRouter()
         self._app = FastAPI()
+        self._logging_conf = None
+        self._server = None
         self._server_thread = None
+
+    def logging_conf(self, logging_conf: Optional[dict] = None):
+        self._logging_conf = logging_conf

     @classmethod
     def from_config(cls, cfg=None):
         host = cfg.get("host", "127.0.0.1")
@@ -79,7 +85,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         )
         return cls(host=host, port=port)

-    def serve(self, logging_conf: Optional[dict] = None):
+    def run(self):
         self._app.add_middleware(
             CORSMiddleware,
             allow_origins=["*"],
@@ -125,18 +131,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         self._app.include_router(self._router)

         config = Config(
-            app=self._app, host=self._host, port=self._port, log_config=logging_conf
+            app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
         )
-        server = Server(config)
+        self._server = Server(config)

         def run_server():
-            server.run()
+            self._server.shutdown_timeout = 2  # set to 2 seconds
+            self._server.run()

         self._server_thread = threading.Thread(target=run_server)
         self._server_thread.start()

-    async def join(self):
-        await self._server_thread.join()
+    def destroy(self):
+        logger.info("Shutting down server")
+        self._server.should_exit = True  # set the exit flag
+        self._server.shutdown()  # stop the server
+        self.join()
+
+    def join(self):
+        self._server_thread.join()

     def set_app_event(self, started_event: mp.Event = None):
         @self._app.on_event("startup")
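The run()/destroy() pair above drives a uvicorn Server on a background thread and stops it via its should_exit flag. Stripped of the bootstrap class, the pattern reduces to this self-contained sketch (app and port are illustrative):

import threading

from fastapi import FastAPI
from uvicorn import Config, Server

app = FastAPI()
server = Server(Config(app=app, host="127.0.0.1", port=20000))

thread = threading.Thread(target=server.run)  # serve without blocking the caller
thread.start()

# ... exercise http://127.0.0.1:20000 here ...

server.should_exit = True  # uvicorn's event loop checks this flag and winds down
thread.join()              # wait for the server thread to exit

One caveat: uvicorn's Server.shutdown() is an async method, so the synchronous self._server.shutdown() call in destroy() above produces an un-awaited coroutine; the should_exit flag is what actually stops the server.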
@@ -273,10 +290,11 @@ def run(
             cfg=cfg.get("run_openai_api", {})
         )
         api.set_app_event(started_event=started_event)
-        api.serve(logging_conf=logging_conf)
+        api.logging_conf(logging_conf=logging_conf)
+        api.run()

         async def pool_join_thread():
-            await api.join()
+            api.join()

         asyncio.run(pool_join_thread())
     except SystemExit:
@@ -24,6 +24,11 @@ class _CommonZhipuaiAI:
             else credentials["zhipuai_api_key"]
             if "zhipuai_api_key" in credentials
             else None,
+            "api_base": credentials["api_base"]
+            if "api_key" in credentials
+            else credentials["zhipuai_api_base"]
+            if "zhipuai_api_base" in credentials
+            else None,
         }

         return credentials_kwargs
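The new api_base entry chains Python conditional expressions, which associate right-to-left. Spelled out as an equivalent standalone helper (a hypothetical name, for illustration only):

# Hypothetical helper equivalent to the chained conditional above.
def resolve_api_base(credentials: dict):
    if "api_key" in credentials:  # note: the guard tests api_key, not api_base
        return credentials["api_base"]
    if "zhipuai_api_base" in credentials:
        return credentials["zhipuai_api_base"]
    return None

Because the first guard checks for "api_key" rather than "api_base", credentials carrying an api_key but no api_base would raise a KeyError here; whether that mismatch is intentional is not clear from the diff.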
@@ -195,7 +195,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if stop:
             extra_model_kwargs["stop"] = stop

-        client = ZhipuAI(api_key=credentials_kwargs["api_key"])
+        client = ZhipuAI(base_url=credentials_kwargs["api_base"],
+                         api_key=credentials_kwargs["api_key"])

         if len(prompt_messages) == 0:
             raise ValueError("At least one message is required")
@@ -43,7 +43,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         :return: embeddings result
         """
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        client = ZhipuAI(api_key=credentials_kwargs["api_key"])
+        client = ZhipuAI(base_url=credentials_kwargs["api_base"],
+                         api_key=credentials_kwargs["api_key"])

         embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
@@ -84,7 +85,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         try:
             # transform credentials to kwargs for model instance
             credentials_kwargs = self._to_credential_kwargs(credentials)
-            client = ZhipuAI(api_key=credentials_kwargs["api_key"])
+            client = ZhipuAI(base_url=credentials_kwargs["api_base"],
+                             api_key=credentials_kwargs["api_key"])

             # call embedding model
             self.embed_documents(
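All three call sites gain the same two-argument construction. Assuming the zhipuai v2 SDK, whose client accepts both api_key and base_url, direct usage looks roughly like this (values are illustrative, taken from the test configuration later in this commit):

from zhipuai import ZhipuAI

client = ZhipuAI(
    base_url="https://test.bigmodel.cn/stage-api/paas/v4",  # custom endpoint
    api_key="YOUR_API_KEY",
)
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "你好"}],
)
print(response.choices[0].message.content)

If _to_credential_kwargs resolves api_base to None, this relies on the SDK accepting base_url=None and falling back to its default endpoint.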
@@ -29,3 +29,12 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 APIKey
         en_US: Enter your APIKey
+    - variable: api_base
+      label:
+        zh_Hans: API Base
+        en_US: API Base
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Base
+        en_US: Enter your API Base
@@ -34,8 +34,8 @@ pytest-asyncio = "^0.21.1"
 grandalf = "^0.8"
 pytest-profiling = "^1.7.0"
 responses = "^0.25.0"
-
-
+langchain = "0.1.5"
+langchain-openai = "0.0.5"

 [tool.poetry.group.lint]
 optional = true
@@ -6,6 +6,7 @@ from typing import Dict, List, Sequence
 import pytest
 from pytest import Config, Function, Parser

+from model_providers import BootstrapWebBuilder
 from model_providers.core.utils.utils import (
     get_config_dict,
     get_log_file,
@@ -102,3 +103,41 @@ def logging_conf() -> dict:
         122,
         111,
     )
+
+
+@pytest.fixture
+def providers_file(request) -> str:
+    from pathlib import Path
+    import os
+
+    # resolve the directory of the current test file
+    test_file_path = Path(str(request.fspath)).parent
+    print("test_file_path:", test_file_path)
+    return os.path.join(test_file_path, "model_providers.yaml")
+
+
+@pytest.fixture
+@pytest.mark.requires("fastapi")
+def init_server(logging_conf: dict, providers_file: str) -> None:
+    try:
+        boot = (
+            BootstrapWebBuilder()
+            .model_providers_cfg_path(
+                model_providers_cfg_path=providers_file
+            )
+            .host(host="127.0.0.1")
+            .port(port=20000)
+            .build()
+        )
+        boot.set_app_event(started_event=None)
+        boot.logging_conf(logging_conf=logging_conf)
+        boot.run()
+
+        try:
+            yield f"http://127.0.0.1:20000"
+        finally:
+            print("")
+            boot.destroy()
+
+    except SystemExit:
+        raise
@@ -0,0 +1,5 @@
+openai:
+  provider_credential:
+    openai_api_key: 'sk-'
+    openai_organization: ''
+    openai_api_base: ''
@@ -0,0 +1,22 @@
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI
+import pytest
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.requires("openai")
+def test_llm(init_server: str):
+    llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
+    template = """Question: {question}
+
+    Answer: Let's think step by step."""
+
+    prompt = PromptTemplate.from_template(template)
+
+    llm_chain = LLMChain(prompt=prompt, llm=llm)
+    responses = llm_chain.run("你好")
+    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
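The test drives LLMChain.run, which the pinned LangChain 0.1 line still supports but already supersedes with runnable composition. An equivalent sketch in that newer style, under the same local-endpoint assumption:

from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    openai_api_key="YOUR_API_KEY",
    openai_api_base="http://127.0.0.1:20000/openai/v1",  # the fixture's URL
)
prompt = PromptTemplate.from_template(
    "Question: {question}\n\nAnswer: Let's think step by step."
)
chain = prompt | llm  # LCEL pipe composition instead of LLMChain
print(chain.invoke({"question": "你好"}).content)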
@@ -1,33 +0,0 @@
-import asyncio
-import logging
-
-import pytest
-
-from model_providers import BootstrapWebBuilder
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.requires("fastapi")
-def test_init_server(logging_conf: dict) -> None:
-    try:
-        boot = (
-            BootstrapWebBuilder()
-            .model_providers_cfg_path(
-                model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
-                "/model_providers.yaml"
-            )
-            .host(host="127.0.0.1")
-            .port(port=20000)
-            .build()
-        )
-        boot.set_app_event(started_event=None)
-        boot.serve(logging_conf=logging_conf)
-
-        async def pool_join_thread():
-            await boot.join()
-
-        asyncio.run(pool_join_thread())
-    except SystemExit:
-        logger.info("SystemExit raised, exiting")
-        raise
@@ -0,0 +1,4 @@
+zhipuai:
+  provider_credential:
+    api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1'
+    api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
@@ -0,0 +1,42 @@
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI
+import pytest
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.requires("zhipuai")
+def test_llm(init_server: str):
+    llm = ChatOpenAI(
+        model_name="glm-4",
+        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
+    template = """Question: {question}
+
+    Answer: Let's think step by step."""
+
+    prompt = PromptTemplate.from_template(template)
+
+    llm_chain = LLMChain(prompt=prompt, llm=llm)
+    responses = llm_chain.run("你好")
+    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+
+@pytest.mark.requires("zhipuai")
+def test_embedding(init_server: str):
+    llm = ChatOpenAI(
+        model_name="glm-4",
+        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
+    template = """Question: {question}
+
+    Answer: Let's think step by step."""
+
+    prompt = PromptTemplate.from_template(template)
+
+    llm_chain = LLMChain(prompt=prompt, llm=llm)
+    responses = llm_chain.run("你好")
+    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
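As committed, test_embedding duplicates the chat-completion body of test_llm verbatim. A sketch of what an actual embedding round-trip against the same endpoint could look like, using langchain-openai's embeddings wrapper (the zhipuai model name "embedding-2" is an assumption):

from langchain_openai import OpenAIEmbeddings

def embedding_roundtrip(init_server: str) -> None:
    # Hypothetical embedding check against the proxied zhipuai provider.
    emb = OpenAIEmbeddings(
        model="embedding-2",  # assumed zhipuai embedding model name
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/zhipuai/v1",
    )
    vector = emb.embed_query("你好")
    assert len(vector) > 0  # a non-empty vector indicates a successful call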