更新 langchain-openai 和 tiktoken 版本以支持 gpt4o

init_database 脚本完成后打印摘要
This commit is contained in:
liunux4odoo 2024-06-12 12:20:59 +08:00
parent 52e826d879
commit 10c43e87ac
4 changed files with 44 additions and 17 deletions

View File

@ -166,9 +166,9 @@ def main():
prune_folder_files(args.kb_name) prune_folder_files(args.kb_name)
end_time = datetime.now() end_time = datetime.now()
print(f"总计用时 {end_time-start_time}") print(f"总计用时\t{end_time-start_time}\n")
except Exception as e: except Exception as e:
logger.error(e) logger.error(e, exc_info=True)
logger.warning("Caught KeyboardInterrupt! Setting stop event...") logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally: finally:

View File

@ -1,3 +1,8 @@
from datetime import datetime
from dateutil.parser import parse
import os
from typing import Literal, List
from chatchat.configs import ( from chatchat.configs import (
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
@ -7,7 +12,7 @@ from chatchat.server.knowledge_base.utils import (
list_files_from_folder, files2docs_in_thread, list_files_from_folder, files2docs_in_thread,
KnowledgeFile KnowledgeFile
) )
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from chatchat.server.db.models.conversation_model import ConversationModel from chatchat.server.db.models.conversation_model import ConversationModel
from chatchat.server.db.models.message_model import MessageModel from chatchat.server.db.models.message_model import MessageModel
from chatchat.server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported from chatchat.server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
@ -15,10 +20,6 @@ from chatchat.server.db.repository.knowledge_metadata_repository import add_summ
from chatchat.server.db.base import Base, engine from chatchat.server.db.base import Base, engine
from chatchat.server.db.session import session_scope from chatchat.server.db.session import session_scope
import os
from dateutil.parser import parse
from typing import Literal, List
def create_tables(): def create_tables():
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@ -99,22 +100,26 @@ def folder2db(
increment: create vector store and database info for local files that not existed in database only increment: create vector store and database info for local files that not existed in database only
""" """
def files2vs(kb_name: str, kb_files: List[KnowledgeFile]): def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List:
for success, result in files2docs_in_thread(kb_files, result = []
for success, res in files2docs_in_thread(kb_files,
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance): zh_title_enhance=zh_title_enhance):
if success: if success:
_, filename, docs = result _, filename, docs = res
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档") print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb_file.splited_docs = docs kb_file.splited_docs = docs
kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True) kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
result.append({"kb_name": kb_name, "file": filename, "docs": docs})
else: else:
print(result) print(res)
return result
kb_names = kb_names or list_kbs_from_folder() kb_names = kb_names or list_kbs_from_folder()
for kb_name in kb_names: for kb_name in kb_names:
start = datetime.now()
kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model) kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
if not kb.exists(): if not kb.exists():
kb.create_kb() kb.create_kb()
@ -124,7 +129,7 @@ def folder2db(
kb.clear_vs() kb.clear_vs()
kb.create_kb() kb.create_kb()
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
files2vs(kb_name, kb_files) result = files2vs(kb_name, kb_files)
kb.save_vector_store() kb.save_vector_store()
# # 不做文件内容的向量化,仅将文件元信息存到数据库 # # 不做文件内容的向量化,仅将文件元信息存到数据库
# # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。 # # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。
@ -138,7 +143,7 @@ def folder2db(
elif mode == "update_in_db": elif mode == "update_in_db":
files = kb.list_files() files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files) kb_files = file_to_kbfile(kb_name, files)
files2vs(kb_name, kb_files) result = files2vs(kb_name, kb_files)
kb.save_vector_store() kb.save_vector_store()
# 对比本地目录与数据库中的文件列表,进行增量向量化 # 对比本地目录与数据库中的文件列表,进行增量向量化
elif mode == "increment": elif mode == "increment":
@ -146,10 +151,32 @@ def folder2db(
folder_files = list_files_from_folder(kb_name) folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files)) files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files) kb_files = file_to_kbfile(kb_name, files)
files2vs(kb_name, kb_files) result = files2vs(kb_name, kb_files)
kb.save_vector_store() kb.save_vector_store()
else: else:
print(f"unsupported migrate mode: {mode}") print(f"unsupported migrate mode: {mode}")
end = datetime.now()
kb_path = f"知识库路径\t{kb.kb_path}\n" if kb.vs_type()==SupportedVSType.FAISS else ""
file_count = len(kb_files)
success_count = len(result)
docs_count = sum([len(x['docs']) for x in result])
print("\n" + "-" * 100)
print(
(
f"知识库名称\t{kb_name}\n"
f"知识库类型\t{kb.vs_type()}\n"
f"向量模型:\t{kb.embed_model}\n"
)
+kb_path+
(
f"文件总数量\t{file_count}\n"
f"入库文件数\t{success_count}\n"
f"知识条目数\t{docs_count}\n"
f"用时\t\t{end-start}"
)
)
print("-" * 100 + "\n")
return result
def prune_db_docs(kb_names: List[str]): def prune_db_docs(kb_names: List[str]):

View File

@ -19,7 +19,7 @@ model-providers = "^0.3.0"
langchain = { version = "0.1.17", python = ">=3.8.1,<3.12,!=3.9.7" } langchain = { version = "0.1.17", python = ">=3.8.1,<3.12,!=3.9.7" }
langchainhub = "0.1.14" langchainhub = "0.1.14"
langchain-community = "0.0.36" langchain-community = "0.0.36"
langchain-openai = { version = "0.0.5", python = ">=3.8.1,<3.12,!=3.9.7" } langchain-openai = { version = "0.0.6", python = ">=3.8.1,<3.12,!=3.9.7" }
langchain-experimental = "0.0.58" langchain-experimental = "0.0.58"
fastapi = "~0.109.2" fastapi = "~0.109.2"
sse_starlette = "~1.8.2" sse_starlette = "~1.8.2"

View File

@ -17,7 +17,7 @@ pydantic ="~2.6.4"
omegaconf = "~2.0.6" omegaconf = "~2.0.6"
# modle_runtime # modle_runtime
openai = "~1.13.3" openai = "~1.13.3"
tiktoken = ">=0.5.2" tiktoken = ">=0.7.0"
pydub = "~0.25.1" pydub = "~0.25.1"
boto3 = "~1.28.17" boto3 = "~1.28.17"
@ -41,7 +41,7 @@ syrupy = "^4.0.2"
requests-mock = "^1.11.0" requests-mock = "^1.11.0"
langchain = { version = "0.1.17", python = ">=3.8.1,<3.12,!=3.9.7" } langchain = { version = "0.1.17", python = ">=3.8.1,<3.12,!=3.9.7" }
langchain-openai = { version = "0.0.5", python = ">=3.8.1,<3.12,!=3.9.7" } langchain-openai = { version = "0.0.6", python = ">=3.8.1,<3.12,!=3.9.7" }
[tool.poetry.group.lint] [tool.poetry.group.lint]
optional = true optional = true