ollama: upgrade provider code to use the OpenAI protocol

This commit is contained in:
glide-the 2024-05-09 17:35:58 +08:00
parent 402153de09
commit b243b3cfbc
11 changed files with 1049 additions and 516 deletions
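The gist of the change: Ollama serves an OpenAI-compatible API under /v1, so the provider layer can reuse the openai SDK instead of a bespoke Ollama client. A minimal sketch of that idea, with the "llama3" model name and the 172.21.80.1:11434 base URL taken from the test configs in this commit; the exact client wiring inside the provider is not shown in this diff:

from openai import OpenAI

# Point the stock openai client at a local Ollama server's OpenAI-compatible /v1 route.
# Ollama ignores the API key, but the SDK requires a non-empty value.
client = OpenAI(api_key="Empty", base_url="http://172.21.80.1:11434/v1")

resp = client.chat.completions.create(
    model="llama3",
    messages=[{"role": "user", "content": "你好"}],
)
print(resp.choices[0].message.content)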

View File

@@ -0,0 +1,60 @@
from typing import Dict, List, Type

import openai
from httpx import Timeout

from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)


class _CommonOllama:
    def _to_credential_kwargs(self, credentials: dict) -> dict:
        """
        Transform credentials to kwargs for model instance

        :param credentials:
        :return:
        """
        credentials_kwargs = {
            "openai_api_key": "Empty",
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,
        }
        if "openai_api_base" in credentials and credentials["openai_api_base"]:
            credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/")
            credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1"

        return credentials_kwargs

    @property
    def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
            InvokeServerUnavailableError: [openai.InternalServerError],
            InvokeRateLimitError: [openai.RateLimitError],
            InvokeAuthorizationError: [
                openai.AuthenticationError,
                openai.PermissionDeniedError,
            ],
            InvokeBadRequestError: [
                openai.BadRequestError,
                openai.NotFoundError,
                openai.UnprocessableEntityError,
                openai.APIError,
            ],
        }
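A usage sketch (not part of this file) of how the returned kwargs could map onto the openai v1 client. The OpenAI() parameter names are the SDK's own; feeding them from _to_credential_kwargs is an assumption about the surrounding provider code:

from openai import OpenAI

credentials = {"openai_api_base": "http://172.21.80.1:11434/"}
kwargs = _CommonOllama()._to_credential_kwargs(credentials)
# kwargs now holds: openai_api_key="Empty", a 315s httpx Timeout, max_retries=1,
# and base_url="http://172.21.80.1:11434/v1" (trailing slash stripped, "/v1" appended)

client = OpenAI(
    api_key=kwargs["openai_api_key"],
    base_url=kwargs["base_url"],
    timeout=kwargs["timeout"],
    max_retries=kwargs["max_retries"],
)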

View File

@@ -7,7 +7,7 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im
logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
class OllamaProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

View File

@@ -11,7 +11,7 @@ help:
  en_US: How to integrate with Ollama
  zh_Hans: 如何集成 Ollama
  url:
    en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
    en_US: "ollama"
supported_model_types:
  - llm
  - text-embedding
@@ -26,73 +26,13 @@ model_credential_schema:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: base_url
    - variable: openai_api_base
      label:
        zh_Hans: 基础 URL
        en_US: Base URL
        zh_Hans: API Base
        en_US: API Base
      type: text-input
      required: true
      placeholder:
        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        zh_Hans: 模型类型
        en_US: Completion mode
      type: select
      required: true
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '4096'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: '4096'
      type: text-input
      required: true
    - variable: vision_support
      label:
        zh_Hans: 是否支持 Vision
        en_US: Vision support
      show_on:
        - variable: __model_type
          value: llm
      default: 'false'
      type: radio
      required: false
      options:
        - value: 'true'
          label:
            en_US: "Yes"
            zh_Hans: 是
        - value: 'false'
          label:
            en_US: "No"
            zh_Hans: 否
      placeholder:
        zh_Hans: 在此输入您的 API Base
        en_US: Enter your API Base

View File

@@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
            yield f"http://127.0.0.1:20000"
        finally:
            print("")
            boot.destroy()
            # boot.destroy()
    except SystemExit:

View File

@@ -0,0 +1,10 @@
ollama:
  model_credential:
    - model: 'llama3'
      model_type: 'llm'
      model_credentials:
        openai_api_base: 'http://172.21.80.1:11434'

View File

@@ -0,0 +1,34 @@
import logging

import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

logger = logging.getLogger(__name__)


@pytest.mark.requires("openai")
def test_llm(init_server: str):
    # Point langchain's OpenAI chat client at the local /ollama/v1 route served by the test server.
    llm = ChatOpenAI(
        model_name="llama3",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/ollama/v1",
    )
    template = """Question: {question}
Answer: Let's think step by step."""
    prompt = PromptTemplate.from_template(template)
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    responses = llm_chain.run("你好")
    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")


@pytest.mark.requires("openai")
def test_embedding(init_server: str):
    # Embedding test goes through the zhipuai route rather than ollama.
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-large",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/zhipuai/v1",
    )
    text = "你好"
    query_result = embeddings.embed_query(text)
    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
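For symmetry, a hedged sketch (not in this commit) of the same embedding test pointed at the ollama route. The model name "nomic-embed-text" and the assumption that an embedding-capable model is served by Ollama are hypothetical; check_embedding_ctx_length is disabled so langchain skips its OpenAI-specific token slicing:

@pytest.mark.requires("openai")
def test_ollama_embedding(init_server: str):
    # Assumed model name; requires an embedding model to be available on the Ollama server.
    embeddings = OpenAIEmbeddings(
        model="nomic-embed-text",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/ollama/v1",
        check_embedding_ctx_length=False,
    )
    query_result = embeddings.embed_query("你好")
    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")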

View File

@@ -1,5 +1,10 @@
openai:
  provider_credential:
    openai_api_key: 'sk-'
    openai_organization: ''
    openai_api_base: ''
ollama:
  model_credential:
    - model: 'llama3'
      model_type: 'llm'
      model_credentials:
        openai_api_base: 'http://172.21.80.1:11434'

View File

@@ -6,6 +6,7 @@ import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("openai")
def test_llm(init_server: str):
    llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")

View File

@@ -0,0 +1,10 @@
ollama:
  model_credential:
    - model: 'llama3'
      model_type: 'llm'
      model_credentials:
        openai_api_base: 'http://172.21.80.1:11434'

View File

@@ -0,0 +1,42 @@
import asyncio
import logging
import logging.config

import pytest
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load the providers configuration file
    cfg = OmegaConf.load(providers_file)
    # Convert the configuration into provider / provider-model records
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Build the provider manager from the converted records
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )
    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="ollama", model_type=ModelType.LLM
    )
    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
        provider="ollama", model_type=ModelType.TEXT_EMBEDDING
    )
    predefined_models = (
        provider_model_bundle_llm.model_type_instance.predefined_models()
    )
    logger.info(f"predefined_models: {predefined_models}")