mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
配置的加载行为修改
This commit is contained in:
parent
c79316c7c8
commit
e6b97f13cb
@ -2,7 +2,7 @@ import sys
|
||||
sys.path.append("chatchat")
|
||||
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||
folder2db, prune_db_docs, prune_folder_files)
|
||||
from chatchat.configs.model_config import DEFAULT_EMBEDDING_MODEL
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from chatchat.configs import logger
|
||||
import logging
|
||||
from chatchat.server.utils import get_tool_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelContainer:
|
||||
def __init__(self):
|
||||
|
||||
@ -10,21 +10,21 @@ from chatchat.server.utils import get_tool_config
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
import openai
|
||||
|
||||
from chatchat.configs.basic_config import MEDIA_PATH
|
||||
from chatchat.configs import MEDIA_PATH
|
||||
from chatchat.server.utils import MsgType
|
||||
|
||||
|
||||
def get_image_model_config() -> dict:
|
||||
from chatchat.configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
|
||||
|
||||
model = LLM_MODEL_CONFIG.get("image_model")
|
||||
if model:
|
||||
name = list(model.keys())[0]
|
||||
if config := ONLINE_LLM_MODEL.get(name):
|
||||
config = {**list(model.values())[0], **config}
|
||||
config.setdefault("model_name", name)
|
||||
return config
|
||||
|
||||
# from chatchat.configs import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
|
||||
# TODO ONLINE_LLM_MODEL的配置被删除,此处业务需要修改
|
||||
# model = LLM_MODEL_CONFIG.get("image_model")
|
||||
# if model:
|
||||
# name = list(model.keys())[0]
|
||||
# if config := ONLINE_LLM_MODEL.get(name):
|
||||
# config = {**list(model.values())[0], **config}
|
||||
# config.setdefault("model_name", name)
|
||||
# return config
|
||||
pass
|
||||
|
||||
@regist_tool(title="文生图", return_direct=True)
|
||||
def text2images(
|
||||
|
||||
@ -16,9 +16,13 @@ from openai.types.file_object import FileObject
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .api_schemas import *
|
||||
from chatchat.configs import logger, BASE_TEMP_DIR, log_verbose
|
||||
from chatchat.configs import BASE_TEMP_DIR, log_verbose
|
||||
from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数
|
||||
model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {} # key: (model_name, platform)
|
||||
|
||||
@ -9,7 +9,7 @@ from starlette.responses import RedirectResponse
|
||||
import uvicorn
|
||||
|
||||
from chatchat.configs import VERSION, MEDIA_PATH, CHATCHAT_ROOT
|
||||
from chatchat.configs.server_config import OPEN_CROSS_DOMAIN
|
||||
from chatchat.configs import OPEN_CROSS_DOMAIN
|
||||
from chatchat.server.api_server.chat_routes import chat_router
|
||||
from chatchat.server.api_server.kb_routes import kb_router
|
||||
from chatchat.server.api_server.openai_routes import openai_router
|
||||
|
||||
@ -4,9 +4,11 @@ from typing import List
|
||||
|
||||
from fastapi import APIRouter, Request, Body
|
||||
|
||||
from chatchat.configs import logger
|
||||
from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from langchain.chains import LLMChain
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from chatchat.configs.model_config import LLM_MODEL_CONFIG
|
||||
from chatchat.configs import LLM_MODEL_CONFIG
|
||||
from chatchat.server.agent.agent_factory.agents_registry import agents_registry
|
||||
from chatchat.server.agent.container import container
|
||||
from chatchat.server.api_server.api_schemas import OpenAIChatOutput
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
from fastapi import Body
|
||||
from chatchat.configs import logger, log_verbose
|
||||
from chatchat.configs import log_verbose
|
||||
from chatchat.server.utils import BaseResponse
|
||||
from chatchat.server.db.repository import feedback_message_to_db
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def chat_feedback(message_id: str = Body("", max_length=32, description="聊天记录id"),
|
||||
score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"),
|
||||
reason: str = Body("", description="用户评分理由,比如不符合事实等")
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from functools import lru_cache
|
||||
from chatchat.server.pydantic_v2 import BaseModel, Field
|
||||
from langchain.prompts.chat import ChatMessagePromptTemplate
|
||||
from chatchat.configs import logger, log_verbose
|
||||
from typing import List, Tuple, Dict, Union
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class History(BaseModel):
|
||||
"""
|
||||
|
||||
@ -7,7 +7,11 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from elasticsearch import Elasticsearch, BadRequestError
|
||||
from chatchat.configs import logger, kbs_config, KB_ROOT_PATH
|
||||
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class ESKBService(KBService):
|
||||
|
||||
@ -3,7 +3,6 @@ from typing import List, Optional
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from chatchat.configs import (logger)
|
||||
from langchain.chains import StuffDocumentsChain, LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
@ -14,6 +13,10 @@ from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain,
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class SummaryAdapter:
|
||||
_OVERLAP_SIZE: int
|
||||
|
||||
@ -5,7 +5,6 @@ from chatchat.configs import (
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger,
|
||||
log_verbose,
|
||||
text_splitter_dict,
|
||||
TEXT_SPLITTER_NAME,
|
||||
@ -22,6 +21,10 @@ from typing import List, Union, Dict, Tuple, Generator
|
||||
import chardet
|
||||
from langchain_community.document_loaders import JSONLoader, TextLoader
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
# 检查是否包含预期外的字符或路径攻击关键字
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Tuple
|
||||
)
|
||||
import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tiktoken
|
||||
|
||||
|
||||
class MinxChatOpenAI:
|
||||
|
||||
@staticmethod
|
||||
def import_tiktoken() -> Any:
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to calculate get_token_ids. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
return tiktoken
|
||||
|
||||
@staticmethod
|
||||
def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
|
||||
tiktoken_ = MinxChatOpenAI.import_tiktoken()
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
else:
|
||||
model = self.model_name
|
||||
if model == "gpt-3.5-turbo":
|
||||
# gpt-3.5-turbo may change over time.
|
||||
# Returning num tokens assuming gpt-3.5-turbo-0301.
|
||||
model = "gpt-3.5-turbo-0301"
|
||||
elif model == "gpt-4":
|
||||
# gpt-4 may change over time.
|
||||
# Returning num tokens assuming gpt-4-0314.
|
||||
model = "gpt-4-0314"
|
||||
# Returns the number of tokens used by a list of messages.
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
except Exception as e:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
return model, encoding
|
||||
@ -24,13 +24,15 @@ from typing import (
|
||||
Tuple,
|
||||
Literal,
|
||||
)
|
||||
import logging
|
||||
|
||||
from chatchat.configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
|
||||
from chatchat.configs import (log_verbose, HTTPX_DEFAULT_TIMEOUT,
|
||||
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE,
|
||||
MODEL_PLATFORMS)
|
||||
from chatchat.server.pydantic_v2 import BaseModel, Field
|
||||
from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used?
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
@ -490,7 +492,7 @@ def MakeFastAPIOffline(
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
from chatchat.configs.server_config import API_SERVER
|
||||
from chatchat.configs import API_SERVER
|
||||
|
||||
host = API_SERVER["host"]
|
||||
if host == "0.0.0.0":
|
||||
@ -500,7 +502,7 @@ def api_address() -> str:
|
||||
|
||||
|
||||
def webui_address() -> str:
|
||||
from chatchat.configs.server_config import WEBUI_SERVER
|
||||
from chatchat.configs import WEBUI_SERVER
|
||||
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
|
||||
@ -13,7 +13,7 @@ from chatchat.configs import (
|
||||
ZH_TITLE_ENHANCE,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
HTTPX_DEFAULT_TIMEOUT,
|
||||
logger, log_verbose,
|
||||
log_verbose,
|
||||
)
|
||||
import httpx
|
||||
import contextlib
|
||||
@ -22,7 +22,10 @@ import os
|
||||
from io import BytesIO
|
||||
from chatchat.server.utils import set_httpx_config, api_address, get_httpx_client
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
set_httpx_config()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user