Modify configuration loading behavior

glide-the 2024-05-07 20:15:56 +08:00
parent c79316c7c8
commit e6b97f13cb
15 changed files with 60 additions and 80 deletions
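
Two patterns repeat across the files below: values previously imported from individual config submodules (chatchat.configs.model_config, chatchat.configs.basic_config, chatchat.configs.server_config) are now imported from the chatchat.configs package root, and the shared logger that chatchat.configs used to export is replaced in each file with the standard library's logging module. A minimal sketch of the re-export layout the new import style implies; the actual __init__.py in the repository may organize this differently:

# chatchat/configs/__init__.py (hypothetical sketch)
# Re-export submodule values at the package root so call sites can write
# `from chatchat.configs import X` without knowing which submodule defines X.
from .basic_config import MEDIA_PATH, BASE_TEMP_DIR, log_verbose
from .model_config import DEFAULT_EMBEDDING_MODEL, DEFAULT_LLM_MODEL, LLM_MODEL_CONFIG
from .server_config import OPEN_CROSS_DOMAIN, API_SERVER, WEBUI_SERVER

__all__ = [
    "MEDIA_PATH", "BASE_TEMP_DIR", "log_verbose",
    "DEFAULT_EMBEDDING_MODEL", "DEFAULT_LLM_MODEL", "LLM_MODEL_CONFIG",
    "OPEN_CROSS_DOMAIN", "API_SERVER", "WEBUI_SERVER",
]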

View File

@@ -2,7 +2,7 @@ import sys
 sys.path.append("chatchat")
 from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                                     folder2db, prune_db_docs, prune_folder_files)
-from chatchat.configs.model_config import DEFAULT_EMBEDDING_MODEL
+from chatchat.configs import DEFAULT_EMBEDDING_MODEL
 from datetime import datetime

View File

@@ -1,6 +1,8 @@
-from chatchat.configs import logger
+import logging
 from chatchat.server.utils import get_tool_config
+logger = logging.getLogger(__name__)
 class ModelContainer:
     def __init__(self):
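
Note that this file uses logging.getLogger(__name__), while most other files in this commit use the bare logging.getLogger(). The two are not equivalent; a small illustration of the difference:

import logging

# Named logger: lives in the logger hierarchy under the module's dotted path,
# so levels and handlers can be configured per package.
module_logger = logging.getLogger(__name__)

# Bare getLogger(): returns the root logger itself, so records go straight
# to whatever handlers are attached at the root.
root_logger = logging.getLogger()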

View File

@@ -10,21 +10,21 @@ from chatchat.server.utils import get_tool_config
 from .tools_registry import regist_tool, BaseToolOutput
 import openai
-from chatchat.configs.basic_config import MEDIA_PATH
+from chatchat.configs import MEDIA_PATH
 from chatchat.server.utils import MsgType
 def get_image_model_config() -> dict:
-    from chatchat.configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
-    model = LLM_MODEL_CONFIG.get("image_model")
-    if model:
-        name = list(model.keys())[0]
-        if config := ONLINE_LLM_MODEL.get(name):
-            config = {**list(model.values())[0], **config}
-            config.setdefault("model_name", name)
-            return config
+    # from chatchat.configs import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
+    # TODO: the ONLINE_LLM_MODEL config has been removed; this logic needs reworking
+    # model = LLM_MODEL_CONFIG.get("image_model")
+    # if model:
+    #     name = list(model.keys())[0]
+    #     if config := ONLINE_LLM_MODEL.get(name):
+    #         config = {**list(model.values())[0], **config}
+    #         config.setdefault("model_name", name)
+    #         return config
+    pass
 @regist_tool(title="文生图", return_direct=True)
 def text2images(
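
With ONLINE_LLM_MODEL gone, the lookup above is disabled and the function implicitly returns None. A hypothetical sketch of how it might be reworked against LLM_MODEL_CONFIG alone, keeping the key names from the commented-out code (assumptions, not the project's final design):

from typing import Optional

from chatchat.configs import LLM_MODEL_CONFIG

def get_image_model_config() -> Optional[dict]:
    # LLM_MODEL_CONFIG["image_model"] is assumed to map one model name
    # to its settings dict, as in the disabled code above.
    model = LLM_MODEL_CONFIG.get("image_model")
    if not model:
        return None
    name, config = next(iter(model.items()))
    config = dict(config)  # copy so the shared config dict is not mutated
    config.setdefault("model_name", name)
    return config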

View File

@@ -16,9 +16,13 @@ from openai.types.file_object import FileObject
 from sse_starlette.sse import EventSourceResponse
 from .api_schemas import *
-from chatchat.configs import logger, BASE_TEMP_DIR, log_verbose
+from chatchat.configs import BASE_TEMP_DIR, log_verbose
 from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
+import logging
+logger = logging.getLogger()
 DEFAULT_API_CONCURRENCIES = 5  # default maximum number of concurrent requests per model
 model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {}  # key: (model_name, platform)
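
A minimal sketch (an assumption about intended use, not code from this commit) of how a semaphore map keyed by (model_name, platform) caps concurrency per model:

import asyncio
from typing import Awaitable, Callable, Dict, Tuple

DEFAULT_API_CONCURRENCIES = 5
model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {}

async def call_with_limit(model_name: str, platform: str,
                          do_call: Callable[[], Awaitable]):
    # Lazily create one semaphore per (model, platform) pair, then hold a
    # slot for the duration of the call; at most 5 calls run concurrently
    # per pair.
    key = (model_name, platform)
    sem = model_semaphores.setdefault(key, asyncio.Semaphore(DEFAULT_API_CONCURRENCIES))
    async with sem:
        return await do_call()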

View File

@@ -9,7 +9,7 @@ from starlette.responses import RedirectResponse
 import uvicorn
 from chatchat.configs import VERSION, MEDIA_PATH, CHATCHAT_ROOT
-from chatchat.configs.server_config import OPEN_CROSS_DOMAIN
+from chatchat.configs import OPEN_CROSS_DOMAIN
 from chatchat.server.api_server.chat_routes import chat_router
 from chatchat.server.api_server.kb_routes import kb_router
 from chatchat.server.api_server.openai_routes import openai_router

View File

@@ -4,9 +4,11 @@ from typing import List
 from fastapi import APIRouter, Request, Body
-from chatchat.configs import logger
 from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
+import logging
+logger = logging.getLogger()
 tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])

View File

@@ -13,7 +13,7 @@ from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
-from chatchat.configs.model_config import LLM_MODEL_CONFIG
+from chatchat.configs import LLM_MODEL_CONFIG
 from chatchat.server.agent.agent_factory.agents_registry import agents_registry
 from chatchat.server.agent.container import container
 from chatchat.server.api_server.api_schemas import OpenAIChatOutput

View File

@@ -1,8 +1,13 @@
 from fastapi import Body
-from chatchat.configs import logger, log_verbose
+from chatchat.configs import log_verbose
 from chatchat.server.utils import BaseResponse
 from chatchat.server.db.repository import feedback_message_to_db
+import logging
+logger = logging.getLogger()
 def chat_feedback(message_id: str = Body("", max_length=32, description="chat message id"),
                   score: int = Body(0, max=100, description="user rating out of 100; higher is better"),
                   reason: str = Body("", description="reason for the rating, e.g. factually incorrect")

View File

@@ -1,9 +1,12 @@
 from functools import lru_cache
 from chatchat.server.pydantic_v2 import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
-from chatchat.configs import logger, log_verbose
 from typing import List, Tuple, Dict, Union
+import logging
+logger = logging.getLogger()
 class History(BaseModel):
     """

View File

@@ -7,7 +7,11 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
 from chatchat.server.knowledge_base.utils import KnowledgeFile
 from chatchat.server.utils import get_Embeddings
 from elasticsearch import Elasticsearch, BadRequestError
-from chatchat.configs import logger, kbs_config, KB_ROOT_PATH
+from chatchat.configs import kbs_config, KB_ROOT_PATH
+import logging
+logger = logging.getLogger()
 class ESKBService(KBService):

View File

@@ -3,7 +3,6 @@ from typing import List, Optional
 from langchain.schema.language_model import BaseLanguageModel
 from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
-from chatchat.configs import (logger)
 from langchain.chains import StuffDocumentsChain, LLMChain
 from langchain.prompts import PromptTemplate
@@ -14,6 +13,10 @@ from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain,
 import sys
 import asyncio
+import logging
+logger = logging.getLogger()
 class SummaryAdapter:
     _OVERLAP_SIZE: int

View File

@@ -5,7 +5,6 @@ from chatchat.configs import (
     CHUNK_SIZE,
     OVERLAP_SIZE,
     ZH_TITLE_ENHANCE,
-    logger,
     log_verbose,
     text_splitter_dict,
     TEXT_SPLITTER_NAME,
@@ -22,6 +21,10 @@ from typing import List, Union, Dict, Tuple, Generator
 import chardet
 from langchain_community.document_loaders import JSONLoader, TextLoader
+import logging
+logger = logging.getLogger()
 def validate_kb_name(knowledge_base_id: str) -> bool:
     # check for unexpected characters or path-traversal keywords
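
A sketch of the kind of check that comment describes; the real implementation may test more patterns:

def validate_kb_name(knowledge_base_id: str) -> bool:
    # Reject path-traversal attempts such as "../../etc".
    if "../" in knowledge_base_id:
        return False
    return True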

View File

@@ -1,51 +0,0 @@
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Tuple
-)
-import sys
-import logging
-logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
-    import tiktoken
-class MinxChatOpenAI:
-    @staticmethod
-    def import_tiktoken() -> Any:
-        try:
-            import tiktoken
-        except ImportError:
-            raise ValueError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate get_token_ids. "
-                "Please install it with `pip install tiktoken`."
-            )
-        return tiktoken
-    @staticmethod
-    def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
-        tiktoken_ = MinxChatOpenAI.import_tiktoken()
-        if self.tiktoken_model_name is not None:
-            model = self.tiktoken_model_name
-        else:
-            model = self.model_name
-            if model == "gpt-3.5-turbo":
-                # gpt-3.5-turbo may change over time.
-                # Returning num tokens assuming gpt-3.5-turbo-0301.
-                model = "gpt-3.5-turbo-0301"
-            elif model == "gpt-4":
-                # gpt-4 may change over time.
-                # Returning num tokens assuming gpt-4-0314.
-                model = "gpt-4-0314"
-        # Returns the number of tokens used by a list of messages.
-        try:
-            encoding = tiktoken_.encoding_for_model(model)
-        except Exception as e:
-            logger.warning("Warning: model not found. Using cl100k_base encoding.")
-            model = "cl100k_base"
-            encoding = tiktoken_.get_encoding(model)
-        return model, encoding
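
The deleted helper wrapped two things: a guarded tiktoken import and a model-name-to-encoding lookup with a cl100k_base fallback. Both can be expressed directly against tiktoken; a minimal equivalent sketch:

import tiktoken  # fails with ImportError at import time if not installed

def get_encoding_for(model_name: str) -> tiktoken.Encoding:
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown model name: fall back to the generic cl100k_base encoding,
        # mirroring the deleted get_encoding_model fallback.
        return tiktoken.get_encoding("cl100k_base")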

View File

@@ -24,13 +24,15 @@ from typing import (
     Tuple,
     Literal,
 )
 import logging
-from chatchat.configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
+from chatchat.configs import (log_verbose, HTTPX_DEFAULT_TIMEOUT,
                               DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE,
                               MODEL_PLATFORMS)
 from chatchat.server.pydantic_v2 import BaseModel, Field
 from chatchat.server.minx_chat_openai import MinxChatOpenAI  # TODO: still used?
+import logging
+logger = logging.getLogger()
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -490,7 +492,7 @@ def MakeFastAPIOffline(
 def api_address() -> str:
-    from chatchat.configs.server_config import API_SERVER
+    from chatchat.configs import API_SERVER
     host = API_SERVER["host"]
     if host == "0.0.0.0":
@@ -500,7 +502,7 @@ def api_address() -> str:
 def webui_address() -> str:
-    from chatchat.configs.server_config import WEBUI_SERVER
+    from chatchat.configs import WEBUI_SERVER
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]
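
Both address helpers follow the same shape: read host/port from a config dict now imported from the package root, and rewrite the bind-all address for client use. A condensed sketch of that pattern (the 127.0.0.1 substitution is an assumption based on the truncated hunk above):

def make_address(server_config: dict) -> str:
    host = server_config["host"]
    if host == "0.0.0.0":
        # 0.0.0.0 is only meaningful as a bind address; clients need a
        # concrete host, so point them at the local loopback instead.
        host = "127.0.0.1"
    return f"http://{host}:{server_config['port']}"

Usage would then be make_address(API_SERVER) or make_address(WEBUI_SERVER).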

View File

@@ -13,7 +13,7 @@ from chatchat.configs import (
     ZH_TITLE_ENHANCE,
     VECTOR_SEARCH_TOP_K,
     HTTPX_DEFAULT_TIMEOUT,
-    logger, log_verbose,
+    log_verbose,
 )
 import httpx
 import contextlib
@@ -22,7 +22,10 @@ import os
 from io import BytesIO
 from chatchat.server.utils import set_httpx_config, api_address, get_httpx_client
 from pprint import pprint
+import logging
+logger = logging.getLogger()
 set_httpx_config()