mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-26 16:53:36 +08:00
支持deepseek客户端
This commit is contained in:
parent
b243b3cfbc
commit
5bd8a4ed8e
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 89 KiB |
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 89 KiB |
@ -0,0 +1,59 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import openai
|
||||
from httpx import Timeout
|
||||
|
||||
from model_providers.core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
class _CommonDeepseek:
|
||||
def _to_credential_kwargs(self, credentials: dict) -> dict:
|
||||
"""
|
||||
Transform credentials to kwargs for model instance
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials["api_key"],
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
if "base_url" in credentials and credentials["base_url"]:
|
||||
credentials_kwargs["base_url"] = credentials["base_url"]
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||
InvokeRateLimitError: [openai.RateLimitError],
|
||||
InvokeAuthorizationError: [
|
||||
openai.AuthenticationError,
|
||||
openai.PermissionDeniedError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
openai.BadRequestError,
|
||||
openai.NotFoundError,
|
||||
openai.UnprocessableEntityError,
|
||||
openai.APIError,
|
||||
],
|
||||
}
|
||||
@ -0,0 +1,18 @@
|
||||
import logging
|
||||
|
||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeepseekProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
pass
|
||||
@ -0,0 +1,44 @@
|
||||
provider: deepseek
|
||||
label:
|
||||
en_US: Deepseek
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
background: "#F9FAFB"
|
||||
help:
|
||||
title:
|
||||
en_US: How to integrate with Deepseek
|
||||
zh_Hans: 如何集成 Deepseek
|
||||
url:
|
||||
en_US: "deepseek"
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: base_url
|
||||
label:
|
||||
zh_Hans: API Base
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
||||
File diff suppressed because it is too large
Load Diff
@ -27,9 +27,8 @@ class _CommonOllama:
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
if "openai_api_base" in credentials and credentials["openai_api_base"]:
|
||||
credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/")
|
||||
credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1"
|
||||
if "base_url" in credentials and credentials["base_url"]:
|
||||
credentials_kwargs["base_url"] = credentials["base_url"]
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import List, Optional, Union, cast
|
||||
from decimal import Decimal
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI, Stream
|
||||
@ -39,7 +40,7 @@ from model_providers.core.model_runtime.entities.model_entities import (
|
||||
FetchFrom,
|
||||
I18nObject,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceConfig, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule, ParameterType,
|
||||
)
|
||||
from model_providers.core.model_runtime.errors.validate import (
|
||||
CredentialsValidateFailedError,
|
||||
@ -1116,47 +1117,223 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity:
|
||||
"""
|
||||
OpenAI supports fine-tuning of their models. This method returns the schema of the base model
|
||||
but renamed to the fine-tuned model name.
|
||||
Get customizable model schema.
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
|
||||
:return: model schema
|
||||
"""
|
||||
if not model.startswith("ft:"):
|
||||
base_model = model
|
||||
else:
|
||||
# get base_model
|
||||
base_model = model.split(":")[1]
|
||||
extras = {}
|
||||
|
||||
# get model schema
|
||||
models = self.predefined_models()
|
||||
model_map = {model.model: model for model in models}
|
||||
if base_model not in model_map:
|
||||
raise ValueError(f"Base model {base_model} not found")
|
||||
|
||||
base_model_schema = model_map[base_model]
|
||||
|
||||
base_model_schema_features = base_model_schema.features or []
|
||||
base_model_schema_model_properties = base_model_schema.model_properties or {}
|
||||
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
|
||||
if "vision_support" in credentials and credentials["vision_support"] == "true":
|
||||
extras["features"] = [ModelFeature.VISION]
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(zh_Hans=model, en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
features=[feature for feature in base_model_schema_features],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
key: property
|
||||
for key, property in base_model_schema_model_properties.items()
|
||||
ModelPropertyKey.MODE: credentials.get("mode"),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(
|
||||
credentials.get("context_size", 4096)
|
||||
),
|
||||
},
|
||||
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
|
||||
pricing=base_model_schema.pricing,
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
use_template=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="The temperature of the model. "
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"
|
||||
),
|
||||
default=0.8,
|
||||
min=0,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
use_template=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
|
||||
"more diverse text, while a lower value (e.g., 0.5) will generate more "
|
||||
"focused and conservative text. (Default: 0.9)"
|
||||
),
|
||||
default=0.9,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
label=I18nObject(en_US="Top K"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Reduces the probability of generating nonsense. "
|
||||
"A higher value (e.g. 100) will give more diverse answers, "
|
||||
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
|
||||
),
|
||||
default=40,
|
||||
min=1,
|
||||
max=100,
|
||||
),
|
||||
ParameterRule(
|
||||
name="repeat_penalty",
|
||||
label=I18nObject(en_US="Repeat Penalty"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Sets how strongly to penalize repetitions. "
|
||||
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
|
||||
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
|
||||
),
|
||||
default=1.1,
|
||||
min=-2,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_predict",
|
||||
use_template="max_tokens",
|
||||
label=I18nObject(en_US="Num Predict"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Maximum number of tokens to predict when generating text. "
|
||||
"(Default: 128, -1 = infinite generation, -2 = fill context)"
|
||||
),
|
||||
default=128,
|
||||
min=-2,
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
),
|
||||
ParameterRule(
|
||||
name="mirostat",
|
||||
label=I18nObject(en_US="Mirostat sampling"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Enable Mirostat sampling for controlling perplexity. "
|
||||
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
|
||||
),
|
||||
default=0,
|
||||
min=0,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name="mirostat_eta",
|
||||
label=I18nObject(en_US="Mirostat Eta"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Influences how quickly the algorithm responds to feedback from "
|
||||
"the generated text. A lower learning rate will result in slower adjustments, "
|
||||
"while a higher learning rate will make the algorithm more responsive. "
|
||||
"(Default: 0.1)"
|
||||
),
|
||||
default=0.1,
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="mirostat_tau",
|
||||
label=I18nObject(en_US="Mirostat Tau"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Controls the balance between coherence and diversity of the output. "
|
||||
"A lower value will result in more focused and coherent text. (Default: 5.0)"
|
||||
),
|
||||
default=5.0,
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_ctx",
|
||||
label=I18nObject(en_US="Size of context window"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets the size of the context window used to generate the next token. "
|
||||
"(Default: 2048)"
|
||||
),
|
||||
default=2048,
|
||||
min=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_gpu",
|
||||
label=I18nObject(en_US="Num GPU"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="The number of layers to send to the GPU(s). "
|
||||
"On macOS it defaults to 1 to enable metal support, 0 to disable."
|
||||
),
|
||||
default=1,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_thread",
|
||||
label=I18nObject(en_US="Num Thread"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets the number of threads to use during computation. "
|
||||
"By default, Ollama will detect this for optimal performance. "
|
||||
"It is recommended to set this value to the number of physical CPU cores "
|
||||
"your system has (as opposed to the logical number of cores)."
|
||||
),
|
||||
min=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="repeat_last_n",
|
||||
label=I18nObject(en_US="Repeat last N"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets how far back for the model to look back to prevent repetition. "
|
||||
"(Default: 64, 0 = disabled, -1 = num_ctx)"
|
||||
),
|
||||
default=64,
|
||||
min=-1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="tfs_z",
|
||||
label=I18nObject(en_US="TFS Z"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Tail free sampling is used to reduce the impact of less probable tokens "
|
||||
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
|
||||
"while a value of 1.0 disables this setting. (default: 1)"
|
||||
),
|
||||
default=1,
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets the random number seed to use for generation. Setting this to "
|
||||
"a specific number will make the model generate the same text for "
|
||||
"the same prompt. (Default: 0)"
|
||||
),
|
||||
default=0,
|
||||
),
|
||||
ParameterRule(
|
||||
name="format",
|
||||
label=I18nObject(en_US="Format"),
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in."
|
||||
" Currently the only accepted value is json."
|
||||
),
|
||||
options=["json"],
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
output=Decimal(credentials.get("output_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
**extras,
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -27,7 +27,7 @@ model_credential_schema:
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
|
||||
- variable: openai_api_base
|
||||
- variable: base_url
|
||||
label:
|
||||
zh_Hans: API Base
|
||||
en_US: API Base
|
||||
@ -35,4 +35,4 @@ model_credential_schema:
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
||||
en_US: Enter your API Base
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
|
||||
deepseek:
|
||||
model_credential:
|
||||
- model: 'deepseek-chat'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
base_url: 'https://api.deepseek.com'
|
||||
api_key: 'sk-dcb625fcbc1e497d80b7b9493b51d758'
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_llm(init_server: str):
|
||||
llm = ChatOpenAI(model_name="deepseek-chat", openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", openai_api_base=f"{init_server}/deepseek/v1")
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
|
||||
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
||||
responses = llm_chain.run("你好")
|
||||
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
|
||||
@ -4,7 +4,7 @@ ollama:
|
||||
- model: 'llama3'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
openai_api_base: 'http://172.21.80.1:11434'
|
||||
base_url: 'http://172.21.80.1:11434/v1'
|
||||
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
|
||||
deepseek:
|
||||
model_credential:
|
||||
- model: 'deepseek-chat'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
base_url: 'https://api.deepseek.com'
|
||||
api_key: 'sk-dcb625fcbc1e497d80b7b9493b51d758'
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
||||
from model_providers.core.model_manager import ModelManager
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
|
||||
logging.config.dictConfig(logging_conf) # type: ignore
|
||||
# 读取配置文件
|
||||
cfg = OmegaConf.load(
|
||||
providers_file
|
||||
)
|
||||
# 转换配置文件
|
||||
(
|
||||
provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict,
|
||||
) = _to_custom_provide_configuration(cfg)
|
||||
# 创建模型管理器
|
||||
provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||
)
|
||||
|
||||
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
|
||||
provider="deepseek", model_type=ModelType.LLM
|
||||
)
|
||||
predefined_models = (
|
||||
provider_model_bundle_llm.model_type_instance.predefined_models()
|
||||
)
|
||||
|
||||
logger.info(f"predefined_models: {predefined_models}")
|
||||
@ -4,7 +4,7 @@ ollama:
|
||||
- model: 'llama3'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
openai_api_base: 'http://172.21.80.1:11434'
|
||||
base_url: 'http://172.21.80.1:11434/v1'
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user