mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-06 23:15:53 +08:00)

Upgrade the ollama provider code to use the OpenAI protocol

parent 402153de09
commit b243b3cfbc
@@ -0,0 +1,60 @@
+from typing import Dict, List, Type
+
+import openai
+from httpx import Timeout
+
+from model_providers.core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+
+
+class _CommonOllama:
+    def _to_credential_kwargs(self, credentials: dict) -> dict:
+        """
+        Transform credentials to kwargs for model instance
+
+        :param credentials:
+        :return:
+        """
+        credentials_kwargs = {
+            "openai_api_key": "Empty",
+            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
+            "max_retries": 1,
+        }
+
+        if "openai_api_base" in credentials and credentials["openai_api_base"]:
+            credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/")
+            credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1"
+
+        return credentials_kwargs
+
+    @property
+    def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+            InvokeServerUnavailableError: [openai.InternalServerError],
+            InvokeRateLimitError: [openai.RateLimitError],
+            InvokeAuthorizationError: [
+                openai.AuthenticationError,
+                openai.PermissionDeniedError,
+            ],
+            InvokeBadRequestError: [
+                openai.BadRequestError,
+                openai.NotFoundError,
+                openai.UnprocessableEntityError,
+                openai.APIError,
+            ],
+        }
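What makes this upgrade possible is that an Ollama server exposes an OpenAI-compatible API under /v1, so the stock openai client can talk to it directly; the class above only assembles the client kwargs (a dummy API key, timeouts, and the /v1 base URL). A minimal standalone sketch of the same idea, assuming a local Ollama server at http://localhost:11434 with llama3 pulled:

import openai

# Ollama ignores the API key, but the openai client requires one,
# hence the "Empty" placeholder, mirroring the class above.
client = openai.OpenAI(base_url="http://localhost:11434/v1", api_key="Empty")

resp = client.chat.completions.create(
    model="llama3",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)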
File diff suppressed because it is too large
@@ -7,7 +7,7 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im
 logger = logging.getLogger(__name__)
 
 
-class OpenAIProvider(ModelProvider):
+class OllamaProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
@@ -11,7 +11,7 @@ help:
     en_US: How to integrate with Ollama
     zh_Hans: 如何集成 Ollama
   url:
-    en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
+    en_US: "ollama"
 supported_model_types:
   - llm
   - text-embedding
@@ -26,73 +26,13 @@ model_credential_schema:
       en_US: Enter your model name
       zh_Hans: 输入模型名称
   credential_form_schemas:
-    - variable: base_url
+    - variable: openai_api_base
       label:
-        zh_Hans: 基础 URL
-        en_US: Base URL
+        zh_Hans: API Base
+        en_US: API Base
       type: text-input
-      required: true
-      placeholder:
-        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
-        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
-    - variable: mode
-      show_on:
-        - variable: __model_type
-          value: llm
-      label:
-        zh_Hans: 模型类型
-        en_US: Completion mode
-      type: select
-      required: true
-      default: chat
-      placeholder:
-        zh_Hans: 选择对话类型
-        en_US: Select completion mode
-      options:
-        - value: completion
-          label:
-            en_US: Completion
-            zh_Hans: 补全
-        - value: chat
-          label:
-            en_US: Chat
-            zh_Hans: 对话
-    - variable: context_size
-      label:
-        zh_Hans: 模型上下文长度
-        en_US: Model context size
-      required: true
-      type: text-input
-      default: '4096'
-      placeholder:
-        zh_Hans: 在此输入您的模型上下文长度
-        en_US: Enter your Model context size
-    - variable: max_tokens
-      label:
-        zh_Hans: 最大 token 上限
-        en_US: Upper bound for max tokens
-      show_on:
-        - variable: __model_type
-          value: llm
-      default: '4096'
-      type: text-input
-      required: true
-    - variable: vision_support
-      label:
-        zh_Hans: 是否支持 Vision
-        en_US: Vision support
-      show_on:
-        - variable: __model_type
-          value: llm
-      default: 'false'
-      type: radio
-      required: false
-      options:
-        - value: 'true'
-          label:
-            en_US: "Yes"
-            zh_Hans: 是
-        - value: 'false'
-          label:
-            en_US: "No"
-            zh_Hans: 否
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Base
+        en_US: Enter your API Base
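With the form reduced to a single openai_api_base field, the value it collects is exactly what _CommonOllama._to_credential_kwargs (added above) consumes. A small sketch of that flow; the import path is an assumption, and the address is the sample one used in the test configs below:

# Hypothetical import path for the class added in this commit.
from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama

credentials = {"openai_api_base": "http://172.21.80.1:11434/"}
kwargs = _CommonOllama()._to_credential_kwargs(credentials)

# Trailing slash stripped, "/v1" appended:
#   kwargs["base_url"] == "http://172.21.80.1:11434/v1"
#   kwargs["openai_api_key"] == "Empty", kwargs["max_retries"] == 1
print(kwargs)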
@@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
             yield f"http://127.0.0.1:20000"
         finally:
             print("")
-            boot.destroy()
+            # boot.destroy()
 
     except SystemExit:
 
@@ -0,0 +1,10 @@
+
+ollama:
+  model_credential:
+    - model: 'llama3'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_base: 'http://172.21.80.1:11434'
+
+
+
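This YAML is the shape _to_custom_provide_configuration expects: a provider key (ollama) with a list of per-model credentials. A quick sketch of reading it back with OmegaConf, the same mechanism the provider-manager test further down uses via the providers_file fixture; the file name here is an assumption:

from omegaconf import OmegaConf

cfg = OmegaConf.load("model_providers.yaml")  # assumed save location
cred = cfg.ollama.model_credential[0]
print(cred.model)                              # 'llama3'
print(cred.model_type)                         # 'llm'
print(cred.model_credentials.openai_api_base)  # 'http://172.21.80.1:11434'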
@@ -0,0 +1,34 @@
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+import pytest
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.requires("openai")
+def test_llm(init_server: str):
+    llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1")
+    template = """Question: {question}
+
+Answer: Let's think step by step."""
+
+    prompt = PromptTemplate.from_template(template)
+
+    llm_chain = LLMChain(prompt=prompt, llm=llm)
+    responses = llm_chain.run("你好")
+    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+
+@pytest.mark.requires("openai")
+def test_embedding(init_server: str):
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
+
+    text = "你好"
+
+    query_result = embeddings.embed_query(text)
+
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
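Note that test_embedding above still points at the zhipuai route with an OpenAI model name. A hypothetical Ollama-route variant, assuming an embedding model such as nomic-embed-text has been pulled into the Ollama server, would follow the same pattern as test_llm:

from langchain_openai import OpenAIEmbeddings

# "nomic-embed-text" is an assumed Ollama embedding model, not part of this
# commit; init_server is the pytest fixture used by the tests above.
embeddings = OpenAIEmbeddings(
    model="nomic-embed-text",
    openai_api_key="YOUR_API_KEY",
    openai_api_base=f"{init_server}/ollama/v1",
)
query_result = embeddings.embed_query("你好")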
@@ -1,5 +1,10 @@
-openai:
-  provider_credential:
-    openai_api_key: 'sk-'
-    openai_organization: ''
-    openai_api_base: ''
+
+ollama:
+  model_credential:
+    - model: 'llama3'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_base: 'http://172.21.80.1:11434'
+
+
+
@@ -6,6 +6,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+
 @pytest.mark.requires("openai")
 def test_llm(init_server: str):
     llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
model-providers/tests/unit_tests/ollama/model_providers.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
+
+ollama:
+  model_credential:
+    - model: 'llama3'
+      model_type: 'llm'
+      model_credentials:
+        openai_api_base: 'http://172.21.80.1:11434'
+
+
+
@@ -0,0 +1,42 @@
+import asyncio
+import logging
+
+import pytest
+from omegaconf import OmegaConf
+
+from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
+from model_providers.core.model_manager import ModelManager
+from model_providers.core.model_runtime.entities.model_entities import ModelType
+from model_providers.core.provider_manager import ProviderManager
+
+logger = logging.getLogger(__name__)
+
+
+def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
+    logging.config.dictConfig(logging_conf)  # type: ignore
+    # Read the providers configuration file
+    cfg = OmegaConf.load(
+        providers_file
+    )
+    # Convert the configuration into provider records
+    (
+        provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict,
+    ) = _to_custom_provide_configuration(cfg)
+    # Create the provider manager
+    provider_manager = ProviderManager(
+        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
+    )
+
+    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
+        provider="ollama", model_type=ModelType.LLM
+    )
+    provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
+        provider="ollama", model_type=ModelType.TEXT_EMBEDDING
+    )
+    predefined_models = (
+        provider_model_bundle_llm.model_type_instance.predefined_models()
+    )
+
+    logger.info(f"predefined_models: {predefined_models}")