mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00

Ollama code upgrade: use the OpenAI protocol

commit b243b3cfbc (parent 402153de09)
@@ -0,0 +1,60 @@
from typing import Dict, List, Type

import openai
from httpx import Timeout

from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)


class _CommonOllama:
    def _to_credential_kwargs(self, credentials: dict) -> dict:
        """
        Transform credentials to kwargs for model instance

        :param credentials:
        :return:
        """
        credentials_kwargs = {
            "openai_api_key": "Empty",
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,
        }

        if "openai_api_base" in credentials and credentials["openai_api_base"]:
            credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/")
            credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1"

        return credentials_kwargs

    @property
    def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
            InvokeServerUnavailableError: [openai.InternalServerError],
            InvokeRateLimitError: [openai.RateLimitError],
            InvokeAuthorizationError: [
                openai.AuthenticationError,
                openai.PermissionDeniedError,
            ],
            InvokeBadRequestError: [
                openai.BadRequestError,
                openai.NotFoundError,
                openai.UnprocessableEntityError,
                openai.APIError,
            ],
        }
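For orientation, a minimal usage sketch (not part of the commit): the kwargs produced by _to_credential_kwargs can be passed straight to an OpenAI-protocol client such as langchain_openai's ChatOpenAI. The base URL and model name below are assumptions for a local Ollama server.

    from langchain_openai import ChatOpenAI

    # Build OpenAI-style kwargs from Ollama credentials: placeholder api key,
    # base_url rewritten to end in /v1, httpx Timeout, and a single retry.
    kwargs = _CommonOllama()._to_credential_kwargs({"openai_api_base": "http://localhost:11434"})

    # ChatOpenAI accepts openai_api_key, base_url, timeout and max_retries,
    # so the kwargs can be splatted in directly.
    llm = ChatOpenAI(model="llama3", **kwargs)
    # llm.invoke("hello")  # requires a running Ollama instance serving llama3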
File diff suppressed because it is too large
@@ -7,7 +7,7 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im

 logger = logging.getLogger(__name__)


-class OpenAIProvider(ModelProvider):
+class OllamaProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
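The body of validate_provider_credentials is not shown in this hunk. A hypothetical sketch of what such a check could look like when routed through the OpenAI protocol (the function below is an assumption for illustration, not the repository's actual implementation):

    import openai

    def _validate_ollama_credentials(credentials: dict) -> None:
        # Reuse the OpenAI-protocol kwargs built by _CommonOllama.
        kwargs = _CommonOllama()._to_credential_kwargs(credentials)
        client = openai.OpenAI(
            api_key=kwargs["openai_api_key"],
            base_url=kwargs.get("base_url"),
            timeout=kwargs["timeout"],
            max_retries=kwargs["max_retries"],
        )
        # Cheap round-trip: raises openai.APIConnectionError, AuthenticationError, etc.,
        # which _invoke_error_mapping then translates into the unified Invoke* errors.
        client.models.list()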
@@ -11,7 +11,7 @@ help:
    en_US: How to integrate with Ollama
    zh_Hans: 如何集成 Ollama
  url:
    en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
    en_US: "ollama"
supported_model_types:
  - llm
  - text-embedding
@@ -26,73 +26,13 @@ model_credential_schema:
       en_US: Enter your model name
       zh_Hans: 输入模型名称
   credential_form_schemas:
-    - variable: base_url
+    - variable: openai_api_base
       label:
-        zh_Hans: 基础 URL
-        en_US: Base URL
+        zh_Hans: API Base
+        en_US: API Base
       type: text-input
       required: true
       placeholder:
-        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
-        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
-    - variable: mode
-      show_on:
-        - variable: __model_type
-          value: llm
-      label:
-        zh_Hans: 模型类型
-        en_US: Completion mode
-      type: select
-      required: true
-      default: chat
-      placeholder:
-        zh_Hans: 选择对话类型
-        en_US: Select completion mode
-      options:
-        - value: completion
-          label:
-            en_US: Completion
-            zh_Hans: 补全
-        - value: chat
-          label:
-            en_US: Chat
-            zh_Hans: 对话
-    - variable: context_size
-      label:
-        zh_Hans: 模型上下文长度
-        en_US: Model context size
-      required: true
-      type: text-input
-      default: '4096'
-      placeholder:
-        zh_Hans: 在此输入您的模型上下文长度
-        en_US: Enter your Model context size
-    - variable: max_tokens
-      label:
-        zh_Hans: 最大 token 上限
-        en_US: Upper bound for max tokens
-      show_on:
-        - variable: __model_type
-          value: llm
-      default: '4096'
-      type: text-input
-      required: true
-    - variable: vision_support
-      label:
-        zh_Hans: 是否支持 Vision
-        en_US: Vision support
-      show_on:
-        - variable: __model_type
-          value: llm
-      default: 'false'
-      type: radio
-      required: false
-      options:
-        - value: 'true'
-          label:
-            en_US: "Yes"
-            zh_Hans: 是
-        - value: 'false'
-          label:
-            en_US: "No"
-            zh_Hans: 否
+        zh_Hans: 在此输入您的 API Base
+        en_US: Enter your API Base
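With the slimmed-down schema, only openai_api_base is collected from the user; everything else rides on the standard OpenAI protocol. A minimal sketch of a direct OpenAI-protocol call against an Ollama server (host, port and model name are assumptions taken from the examples elsewhere in this diff):

    from openai import OpenAI

    # Ollama exposes an OpenAI-compatible endpoint under /v1; the api key is unused.
    client = OpenAI(api_key="Empty", base_url="http://192.168.1.100:11434/v1")
    resp = client.chat.completions.create(
        model="llama3",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.choices[0].message.content)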
@@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
             yield f"http://127.0.0.1:20000"
         finally:
             print("")
-            boot.destroy()
+            # boot.destroy()

     except SystemExit:
@@ -0,0 +1,10 @@

ollama:
  model_credential:
    - model: 'llama3'
      model_type: 'llm'
      model_credentials:
        openai_api_base: 'http://172.21.80.1:11434'
@@ -0,0 +1,34 @@
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging

logger = logging.getLogger(__name__)


@pytest.mark.requires("openai")
def test_llm(init_server: str):
    llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1")
    template = """Question: {question}

    Answer: Let's think step by step."""

    prompt = PromptTemplate.from_template(template)

    llm_chain = LLMChain(prompt=prompt, llm=llm)
    responses = llm_chain.run("你好")
    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")


@pytest.mark.requires("openai")
def test_embedding(init_server: str):
    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
                                  openai_api_key="YOUR_API_KEY",
                                  openai_api_base=f"{init_server}/zhipuai/v1")

    text = "你好"

    query_result = embeddings.embed_query(text)

    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
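Side note: LLMChain and .run() are deprecated in recent LangChain releases. A sketch of an equivalent test using runnable composition (prompt | llm), reusing the imports, logger, model name and server URL from the module above; this is an alternative for illustration, not the committed test:

    @pytest.mark.requires("openai")
    def test_llm_lcel(init_server: str):
        llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY",
                         openai_api_base=f"{init_server}/ollama/v1")
        prompt = PromptTemplate.from_template("Question: {question}\n\nAnswer: Let's think step by step.")
        chain = prompt | llm  # runnable sequence replacing LLMChain
        response = chain.invoke({"question": "你好"})
        # response is an AIMessage; the generated text is in .content
        logger.info("\033[1;32m" + f"lcel chain: {response.content}" + "\033[0m")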
@@ -1,5 +1,10 @@
 openai:
   provider_credential:
     openai_api_key: 'sk-'
     openai_organization: ''
     openai_api_base: ''
+
+ollama:
+  model_credential:
+    - model: 'llama3'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_base: 'http://172.21.80.1:11434'
@@ -6,6 +6,7 @@ import logging

logger = logging.getLogger(__name__)


@pytest.mark.requires("openai")
def test_llm(init_server: str):
    llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
model-providers/tests/unit_tests/ollama/model_providers.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@

ollama:
  model_credential:
    - model: 'llama3'
      model_type: 'llm'
      model_credentials:
        openai_api_base: 'http://172.21.80.1:11434'
@@ -0,0 +1,42 @@
import asyncio
import logging

import pytest
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Read the providers configuration file
    cfg = OmegaConf.load(
        providers_file
    )
    # Convert the configuration into provider records
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Create the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="ollama", model_type=ModelType.LLM
    )
    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
        provider="ollama", model_type=ModelType.TEXT_EMBEDDING
    )
    predefined_models = (
        provider_model_bundle_llm.model_type_instance.predefined_models()
    )

    logger.info(f"predefined_models: {predefined_models}")