Update langchain-openai and tiktoken versions to support gpt4o

Print a summary after the init_database script finishes
liunux4odoo 2024-06-12 12:20:59 +08:00
parent 52e826d879
commit 10c43e87ac
4 changed files with 44 additions and 17 deletions

View File

@@ -166,9 +166,9 @@ def main():
             prune_folder_files(args.kb_name)
         end_time = datetime.now()
-        print(f"总计用时 {end_time-start_time}")
+        print(f"总计用时\t{end_time-start_time}\n")
     except Exception as e:
-        logger.error(e)
+        logger.error(e, exc_info=True)
         logger.warning("Caught KeyboardInterrupt! Setting stop event...")
     finally:
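
For reference, passing exc_info=True makes the standard-library logger append the full traceback to the record instead of only the exception message, which is what this change buys for debugging. A minimal, self-contained sketch (the logger name is illustrative, not taken from the script):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("init_database")  # illustrative name

try:
    1 / 0
except Exception as e:
    # logger.error(e) would log just "division by zero";
    # exc_info=True also emits the full stack trace.
    logger.error(e, exc_info=True)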

View File

@@ -1,3 +1,8 @@
+from datetime import datetime
+from dateutil.parser import parse
+import os
+from typing import Literal, List
+
 from chatchat.configs import (
     DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
     CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
@@ -7,7 +12,7 @@ from chatchat.server.knowledge_base.utils import (
     list_files_from_folder, files2docs_in_thread,
     KnowledgeFile
 )
-from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
+from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
 from chatchat.server.db.models.conversation_model import ConversationModel
 from chatchat.server.db.models.message_model import MessageModel
 from chatchat.server.db.repository.knowledge_file_repository import add_file_to_db  # ensure Models are imported
@@ -15,10 +20,6 @@ from chatchat.server.db.repository.knowledge_metadata_repository import add_summ
 from chatchat.server.db.base import Base, engine
 from chatchat.server.db.session import session_scope
-import os
-from dateutil.parser import parse
-from typing import Literal, List
-


 def create_tables():
     Base.metadata.create_all(bind=engine)
@@ -99,22 +100,26 @@ def folder2db(
         increment: create vector store and database info for local files that not existed in database only
     """

-    def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
-        for success, result in files2docs_in_thread(kb_files,
+    def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List:
+        result = []
+        for success, res in files2docs_in_thread(kb_files,
                                                     chunk_size=chunk_size,
                                                     chunk_overlap=chunk_overlap,
                                                     zh_title_enhance=zh_title_enhance):
             if success:
-                _, filename, docs = result
+                _, filename, docs = res
                 print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
                 kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
                 kb_file.splited_docs = docs
                 kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
+                result.append({"kb_name": kb_name, "file": filename, "docs": docs})
             else:
-                print(result)
+                print(res)
+        return result

     kb_names = kb_names or list_kbs_from_folder()
     for kb_name in kb_names:
+        start = datetime.now()
         kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
         if not kb.exists():
             kb.create_kb()
@@ -124,7 +129,7 @@ def folder2db(
             kb.clear_vs()
             kb.create_kb()
             kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
-            files2vs(kb_name, kb_files)
+            result = files2vs(kb_name, kb_files)
             kb.save_vector_store()
         # # Do not vectorize file contents; only store file metadata in the database.
         # # Since the database now keeps a lot of text-splitting information, storing bare file info has little value, so this feature was removed.
@@ -138,7 +143,7 @@ def folder2db(
         elif mode == "update_in_db":
             files = kb.list_files()
             kb_files = file_to_kbfile(kb_name, files)
-            files2vs(kb_name, kb_files)
+            result = files2vs(kb_name, kb_files)
             kb.save_vector_store()
         # Compare the local folder with the file list in the database and vectorize incrementally
         elif mode == "increment":
@@ -146,10 +151,32 @@ def folder2db(
             folder_files = list_files_from_folder(kb_name)
             files = list(set(folder_files) - set(db_files))
             kb_files = file_to_kbfile(kb_name, files)
-            files2vs(kb_name, kb_files)
+            result = files2vs(kb_name, kb_files)
             kb.save_vector_store()
         else:
             print(f"unsupported migrate mode: {mode}")
+        end = datetime.now()
+        kb_path = f"知识库路径\t{kb.kb_path}\n" if kb.vs_type() == SupportedVSType.FAISS else ""
+        file_count = len(kb_files)
+        success_count = len(result)
+        docs_count = sum([len(x['docs']) for x in result])
+        print("\n" + "-" * 100)
+        print(
+            (
+                f"知识库名称\t{kb_name}\n"
+                f"知识库类型\t{kb.vs_type()}\n"
+                f"向量模型:\t{kb.embed_model}\n"
+            )
+            + kb_path +
+            (
+                f"文件总数量\t{file_count}\n"
+                f"入库文件数\t{success_count}\n"
+                f"知识条目数\t{docs_count}\n"
+                f"用时\t\t{end-start}"
+            )
+        )
+        print("-" * 100 + "\n")
+    return result


 def prune_db_docs(kb_names: List[str]):
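
The summary block relies on files2vs now returning a list of {"kb_name", "file", "docs"} records, which folder2db passes back to its caller. A hypothetical invocation sketch, assuming this hunk belongs to the chatchat.server.knowledge_base.migrate module and using placeholder knowledge-base and model names:

from chatchat.server.knowledge_base.migrate import folder2db

# Rebuild one knowledge base and aggregate the per-file records
# that files2vs collects for the printed summary.
result = folder2db(
    kb_names=["samples"],             # placeholder KB folder name
    mode="recreate_vs",
    vs_type="faiss",                  # the 知识库路径 line only prints for FAISS
    embed_model="bge-large-zh-v1.5",  # placeholder embedding model
)
total_docs = sum(len(item["docs"]) for item in result)
print(f"{len(result)} files ingested, {total_docs} chunks in total")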

View File

@@ -19,7 +19,7 @@ model-providers = "^0.3.0"
 langchain = { version = "0.1.17", python = ">=3.8.1,<3.12,!=3.9.7" }
 langchainhub = "0.1.14"
 langchain-community = "0.0.36"
-langchain-openai = { version = "0.0.5", python = ">=3.8.1,<3.12,!=3.9.7" }
+langchain-openai = { version = "0.0.6", python = ">=3.8.1,<3.12,!=3.9.7" }
 langchain-experimental = "0.0.58"
 fastapi = "~0.109.2"
 sse_starlette = "~1.8.2"

View File

@@ -17,7 +17,7 @@ pydantic ="~2.6.4"
 omegaconf = "~2.0.6"
 # modle_runtime
 openai = "~1.13.3"
-tiktoken = ">=0.5.2"
+tiktoken = ">=0.7.0"
 pydub = "~0.25.1"
 boto3 = "~1.28.17"
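
The raised floor matters because tiktoken 0.7.0 is the first release to ship the o200k_base encoding that gpt-4o uses; earlier 0.5.x/0.6.x releases cannot resolve the model name. A quick check:

import tiktoken

# tiktoken >= 0.7.0 maps gpt-4o to the new o200k_base encoding;
# older releases raise KeyError for this model name.
enc = tiktoken.encoding_for_model("gpt-4o")
print(enc.name)                       # o200k_base
print(len(enc.encode("你好,世界")))  # token count under the new vocabulary
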
@@ -41,7 +41,7 @@ syrupy = "^4.0.2"
 requests-mock = "^1.11.0"
 langchain = { version = "0.1.17", python = ">=3.8.1,<3.12,!=3.9.7" }
-langchain-openai = { version = "0.0.5", python = ">=3.8.1,<3.12,!=3.9.7" }
+langchain-openai = { version = "0.0.6", python = ">=3.8.1,<3.12,!=3.9.7" }


 [tool.poetry.group.lint]
 optional = true
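
With both pyproject files pinned to langchain-openai 0.0.6 (which, per the commit message, is what gpt4o support requires), a gpt-4o chat model can be constructed as usual. A minimal sketch, assuming OPENAI_API_KEY is set in the environment:

from langchain_openai import ChatOpenAI

# Token counting for this model flows through tiktoken's o200k_base,
# hence the matching tiktoken >= 0.7.0 bump above.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
print(llm.invoke("Say hi in one word.").content)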