自定义模型测试

This commit is contained in:
glide-the 2024-05-19 19:54:00 +08:00
parent e7b4624687
commit 6075872ceb
11 changed files with 210 additions and 30 deletions

View File

@ -185,9 +185,17 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
provider=provider, model_type=model_type
)
)
# 获取预定义模型
llm_models.extend(
provider_model_bundle.model_type_instance.predefined_models()
)
# 获取自定义模型
for model in provider_model_bundle.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
except Exception as e:
logger.error(
f"Error while fetching models for provider: {provider}, model_type: {model_type}"

View File

@ -0,0 +1,15 @@
xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'
- model: 'chatglm31-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'

View File

@ -0,0 +1,37 @@
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("xinference_client")
def test_llm(init_server: str):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/xinference/v1")
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
llm_chain = LLMChain(prompt=prompt, llm=llm)
responses = llm_chain.run("你好")
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
@pytest.mark.requires("xinference-client")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1")
text = "你好"
query_result = embeddings.embed_query(text)
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

View File

@ -1,18 +1,19 @@
import asyncio
import logging
from typing import List
import pytest
from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
@ -32,8 +33,17 @@ def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str)
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="deepseek", model_type=ModelType.LLM
)
predefined_models = (
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}")
logger.info(f"predefined_models: {llm_models}")

View File

@ -12,7 +12,7 @@ from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
@ -32,11 +32,18 @@ def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str)
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="ollama", model_type=ModelType.LLM
)
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
provider="ollama", model_type=ModelType.TEXT_EMBEDDING
)
predefined_models = (
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}")
logger.info(f"predefined_models: {llm_models}")

View File

@ -17,17 +17,3 @@ openai:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'
zhipuai:
provider_credential:
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'

View File

@ -32,11 +32,17 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.LLM
)
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
provider="openai", model_type=ModelType.TEXT_EMBEDDING
)
predefined_models = (
provider_model_bundle_emb.model_type_instance.predefined_models()
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {predefined_models}")
logger.info(f"predefined_models: {llm_models}")

View File

@ -0,0 +1,9 @@
xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'

View File

@ -0,0 +1,48 @@
import asyncio
import logging
import pytest
from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="xinference", model_type=ModelType.LLM
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}")

View File

@ -0,0 +1,6 @@
zhipuai:
provider_credential:
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'

View File

@ -0,0 +1,48 @@
import asyncio
import logging
import pytest
from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
provider="zhipuai", model_type=ModelType.LLM
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}")