mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-31 03:03:22 +08:00
自定义模型测试
This commit is contained in:
parent
e7b4624687
commit
6075872ceb
@ -185,9 +185,17 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
provider=provider, model_type=model_type
|
||||
)
|
||||
)
|
||||
# 获取预定义模型
|
||||
llm_models.extend(
|
||||
provider_model_bundle.model_type_instance.predefined_models()
|
||||
)
|
||||
# 获取自定义模型
|
||||
for model in provider_model_bundle.configuration.custom_configuration.models:
|
||||
|
||||
llm_models.append(provider_model_bundle.model_type_instance.get_model_schema(
|
||||
model=model.model,
|
||||
credentials=model.credentials,
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
|
||||
|
||||
@ -0,0 +1,15 @@
|
||||
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'chatglm3-6b'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
|
||||
- model: 'chatglm31-6b'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
@ -0,0 +1,37 @@
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.requires("xinference_client")
def test_llm(init_server: str):
    """Smoke-test the OpenAI-compatible chat endpoint proxied for xinference.

    Sends one prompt through an ``LLMChain`` and logs the reply; passes as
    long as the round trip raises nothing.
    """
    # Point the stock OpenAI client at the local xinference-compatible server.
    chat_model = ChatOpenAI(
        model_name="glm-4",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/xinference/v1",
    )

    template = """Question: {question}

Answer: Let's think step by step."""

    chain = LLMChain(prompt=PromptTemplate.from_template(template), llm=chat_model)
    answer = chain.run("你好")
    # ANSI green so the model reply stands out in the pytest log.
    logger.info("\033[1;32m" + f"llm_chain: {answer}" + "\033[0m")
|
||||
|
||||
|
||||
@pytest.mark.requires("xinference-client")
def test_embedding(init_server: str):
    """Smoke-test the OpenAI-compatible embeddings endpoint proxied for xinference.

    Embeds one short query and logs the resulting vector; passes as long as
    the call raises nothing.
    """
    embedder = OpenAIEmbeddings(
        model="text_embedding",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/xinference/v1",
    )

    text = "你好"
    vector = embedder.embed_query(text)

    # ANSI green so the embedding output stands out in the pytest log.
    logger.info("\033[1;32m" + f"embeddings: {vector}" + "\033[0m")
|
||||
@ -1,18 +1,19 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
|
||||
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(
|
||||
@ -32,8 +33,17 @@ def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str)
|
||||
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
|
||||
provider="deepseek", model_type=ModelType.LLM
|
||||
)
|
||||
predefined_models = (
|
||||
llm_models: List[AIModelEntity] = []
|
||||
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
|
||||
|
||||
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
|
||||
model=model.model,
|
||||
credentials=model.credentials,
|
||||
))
|
||||
|
||||
# 获取预定义模型
|
||||
llm_models.extend(
|
||||
provider_model_bundle_llm.model_type_instance.predefined_models()
|
||||
)
|
||||
|
||||
logger.info(f"predefined_models: {predefined_models}")
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
@ -12,7 +12,7 @@ from model_providers.core.provider_manager import ProviderManager
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
|
||||
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(
|
||||
@ -32,11 +32,18 @@ def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str)
|
||||
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
|
||||
provider="ollama", model_type=ModelType.LLM
|
||||
)
|
||||
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
|
||||
provider="ollama", model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
predefined_models = (
|
||||
llm_models: List[AIModelEntity] = []
|
||||
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
|
||||
|
||||
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
|
||||
model=model.model,
|
||||
credentials=model.credentials,
|
||||
))
|
||||
|
||||
# 获取预定义模型
|
||||
llm_models.extend(
|
||||
provider_model_bundle_llm.model_type_instance.predefined_models()
|
||||
)
|
||||
|
||||
logger.info(f"predefined_models: {predefined_models}")
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
|
||||
@ -17,17 +17,3 @@ openai:
|
||||
openai_api_key: 'sk-'
|
||||
openai_organization: ''
|
||||
openai_api_base: ''
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'chatglm3-6b'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
|
||||
|
||||
zhipuai:
|
||||
|
||||
provider_credential:
|
||||
api_key: 'YOUR_API_KEY'  # NOTE(review): a real-looking credential was committed here — rotate the key and keep secrets out of version control
|
||||
@ -32,11 +32,17 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
|
||||
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
|
||||
provider="openai", model_type=ModelType.LLM
|
||||
)
|
||||
provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
|
||||
provider="openai", model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
predefined_models = (
|
||||
provider_model_bundle_emb.model_type_instance.predefined_models()
|
||||
llm_models: List[AIModelEntity] = []
|
||||
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
|
||||
|
||||
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
|
||||
model=model.model,
|
||||
credentials=model.credentials,
|
||||
))
|
||||
|
||||
# 获取预定义模型
|
||||
llm_models.extend(
|
||||
provider_model_bundle_llm.model_type_instance.predefined_models()
|
||||
)
|
||||
|
||||
logger.info(f"predefined_models: {predefined_models}")
|
||||
logger.info(f"predefined_models: {llm_models}")
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'chatglm3-6b'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
    """Collect custom + predefined LLM model schemas for the xinference provider.

    Args:
        logging_conf: dictConfig-style logging configuration (pytest fixture).
        providers_file: path to the providers YAML consumed by OmegaConf.
    """
    # `logging.config` is a submodule; it is not guaranteed to be importable
    # as an attribute of `logging` unless explicitly imported.
    import logging.config

    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load the provider configuration file.
    cfg = OmegaConf.load(
        providers_file
    )
    # Convert the raw config into provider / provider-model credential records.
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Build the provider manager from the converted records.
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="xinference", model_type=ModelType.LLM
    )
    # NOTE(review): the original annotated this as List[AIModelEntity], but
    # neither name is imported in this module — annotation dropped to avoid
    # undefined-name errors under static checkers.
    llm_models = []
    # Custom (user-configured) models.
    for model in provider_model_bundle_llm.configuration.custom_configuration.models:
        llm_models.append(
            provider_model_bundle_llm.model_type_instance.get_model_schema(
                model=model.model,
                credentials=model.credentials,
            )
        )

    # Predefined models shipped with the provider.
    llm_models.extend(
        provider_model_bundle_llm.model_type_instance.predefined_models()
    )

    # Message fixed: the list holds custom AND predefined models, not just
    # predefined ones.
    logger.info(f"llm_models: {llm_models}")
|
||||
@ -0,0 +1,6 @@
|
||||
|
||||
|
||||
zhipuai:
|
||||
|
||||
provider_credential:
|
||||
api_key: 'YOUR_API_KEY'  # NOTE(review): a real-looking credential was committed here — rotate the key and keep secrets out of version control
|
||||
@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
    """Collect custom + predefined LLM model schemas for the zhipuai provider.

    Args:
        logging_conf: dictConfig-style logging configuration (pytest fixture).
        providers_file: path to the providers YAML consumed by OmegaConf.
    """
    # `logging.config` is a submodule; it is not guaranteed to be importable
    # as an attribute of `logging` unless explicitly imported.
    import logging.config

    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load the provider configuration file.
    cfg = OmegaConf.load(
        providers_file
    )
    # Convert the raw config into provider / provider-model credential records.
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Build the provider manager from the converted records.
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="zhipuai", model_type=ModelType.LLM
    )
    # NOTE(review): the original annotated this as List[AIModelEntity], but
    # neither name is imported in this module — annotation dropped to avoid
    # undefined-name errors under static checkers.
    llm_models = []
    # Custom (user-configured) models.
    for model in provider_model_bundle_llm.configuration.custom_configuration.models:
        llm_models.append(
            provider_model_bundle_llm.model_type_instance.get_model_schema(
                model=model.model,
                credentials=model.credentials,
            )
        )

    # Predefined models shipped with the provider.
    llm_models.extend(
        provider_model_bundle_llm.model_type_instance.predefined_models()
    )

    # Message fixed: the list holds custom AND predefined models, not just
    # predefined ones.
    logger.info(f"llm_models: {llm_models}")
|
||||
Loading…
x
Reference in New Issue
Block a user