fix markdown header split (#1825) (#3324)

Author: Sumkor, 2024-03-15 07:17:53 +08:00 (committed by GitHub)
parent 9b5367a23b
commit 4bdb69baf3
2 changed files with 16 additions and 9 deletions

Changed file 1 of 2:

@@ -95,8 +95,6 @@ class KBService(ABC):
         """
         if docs:
             custom_docs = True
-            for doc in docs:
-                doc.metadata.setdefault("source", kb_file.filename)
         else:
             docs = kb_file.file2text()
             custom_docs = False
@@ -105,6 +103,7 @@ class KBService(ABC):
             # change metadata["source"] to a relative path
             for doc in docs:
                 try:
+                    doc.metadata.setdefault("source", kb_file.filename)
                     source = doc.metadata.get("source", "")
                     if os.path.isabs(source):
                         rel_path = Path(source).relative_to(self.doc_path)
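
Why the setdefault moves: documents that come back from a markdown header splitter typically carry only header metadata and no "source" key, so the backfill has to happen in the same loop that rewrites the path for every document, not only for caller-supplied docs. A minimal sketch of the behaviour the moved line relies on (file names and metadata values here are illustrative only):

    from langchain.docstore.document import Document

    # A document as a header splitter might return it: header metadata only.
    doc = Document(page_content="## Intro\nsome text", metadata={"Header 2": "Intro"})

    doc.metadata.setdefault("source", "test.md")   # sets "source" because the key is missing
    doc.metadata.setdefault("source", "other.md")  # no-op, "source" already present

    print(doc.metadata)  # {'Header 2': 'Intro', 'source': 'test.md'}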

Changed file 2 of 2:

@@ -14,13 +14,13 @@ import importlib
 from server.text_splitter import zh_title_enhance as func_zh_title_enhance
 import langchain_community.document_loaders
 from langchain.docstore.document import Document
-from langchain.text_splitter import TextSplitter
+from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
 from pathlib import Path
 from server.utils import run_in_thread_pool, run_in_process_pool
 import json
 from typing import List, Union, Dict, Tuple, Generator
 import chardet
-from langchain_community.document_loaders import JSONLoader
+from langchain_community.document_loaders import JSONLoader, TextLoader
 
 
 def validate_kb_name(knowledge_base_id: str) -> bool:
@@ -88,6 +88,7 @@ def list_files_from_folder(kb_name: str):
 LOADER_DICT = {"UnstructuredHTMLLoader": ['.html', '.htm'],
                "MHTMLLoader": ['.mhtml'],
+               "TextLoader": ['.md'],
                "UnstructuredMarkdownLoader": ['.md'],
                "JSONLoader": [".json"],
                "JSONLinesLoader": [".jsonl"],
@@ -199,8 +200,8 @@ def make_text_splitter(
     try:
         if splitter_name == "MarkdownHeaderTextSplitter":  # special handling for MarkdownHeaderTextSplitter
             headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
-            text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
-                headers_to_split_on=headers_to_split_on)
+            text_splitter = MarkdownHeaderTextSplitter(
+                headers_to_split_on=headers_to_split_on, strip_headers=False)
         else:
             try:  ## prefer the user-defined text_splitter
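
For reference, a minimal, self-contained use of MarkdownHeaderTextSplitter showing what strip_headers=False changes; the headers_to_split_on values mirror the shape expected from text_splitter_dict but are assumptions, not the project's actual configuration:

    from langchain.text_splitter import MarkdownHeaderTextSplitter

    headers = [("#", "head1"), ("##", "head2")]  # assumed configuration
    md = "# Title\n\nintro text\n\n## Section\n\nbody text\n"

    kept = MarkdownHeaderTextSplitter(headers_to_split_on=headers, strip_headers=False).split_text(md)
    stripped = MarkdownHeaderTextSplitter(headers_to_split_on=headers).split_text(md)  # default strips headers

    print(kept[0].page_content)      # starts with '# Title' -- the header line is retained
    print(stripped[0].page_content)  # body text only -- the header line is dropped
    print(kept[0].metadata)          # {'head1': 'Title'}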
@@ -292,7 +293,11 @@ class KnowledgeFile:
             loader = get_loader(loader_name=self.document_loader_name,
                                 file_path=self.filepath,
                                 loader_kwargs=self.loader_kwargs)
-            self.docs = loader.load()
+            if isinstance(loader, TextLoader):
+                loader.encoding = "utf8"
+                self.docs = loader.load()
+            else:
+                self.docs = loader.load()
         return self.docs
 
     def docs2texts(
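
The isinstance branch exists because TextLoader opens files with the platform's default encoding unless told otherwise, which can raise UnicodeDecodeError on UTF-8 markdown under a non-UTF-8 locale (e.g. Windows). A brief sketch of the equivalent direct usage (the path is illustrative):

    from langchain_community.document_loaders import TextLoader

    loader = TextLoader("docs/sample.md", encoding="utf8")  # same effect as setting loader.encoding afterwards
    docs = loader.load()  # a single Document holding the raw markdown text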
@@ -375,7 +380,6 @@ def files2docs_in_thread(
     the generator yields status, (kb_name, file_name, docs | error)
     '''
     kwargs_list = []
-
     for i, file in enumerate(files):
         kwargs = {}
@@ -405,8 +409,12 @@ if __name__ == "__main__":
     from pprint import pprint
 
     kb_file = KnowledgeFile(
-        filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
+        filename="E:\\LLM\\Data\\Test.md",
         knowledge_base_name="samples")
     # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
+    kb_file.text_splitter_name = "MarkdownHeaderTextSplitter"
     docs = kb_file.file2docs()
     # pprint(docs[-1])
+    texts = kb_file.docs2texts(docs)
+    for text in texts:
+        print(text)
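
Each item the loop prints is a Document whose metadata holds only the matched headers (no "source" key), which is what the setdefault moved in the first changed file backfills. A self-contained approximation of what this __main__ block exercises, with a made-up markdown string standing in for Test.md:

    from langchain.text_splitter import MarkdownHeaderTextSplitter

    # Stand-in for the contents of Test.md (the real test file is not in the repo).
    md_text = "# Guide\n\nsome intro\n\n## Install\n\ninstall steps\n"

    splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "head1"), ("##", "head2")],  # assumed configuration
        strip_headers=False,
    )
    for text in splitter.split_text(md_text):
        print(text)  # metadata like {'head1': 'Guide', 'head2': 'Install'}; no "source" key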