Fix Markdown header splitting (#1825) (#3324)

This commit is contained in:
Sumkor 2024-03-15 07:17:53 +08:00 committed by GitHub
parent 9b5367a23b
commit 4bdb69baf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 9 deletions

View File

@ -95,8 +95,6 @@ class KBService(ABC):
"""
if docs:
custom_docs = True
for doc in docs:
doc.metadata.setdefault("source", kb_file.filename)
else:
docs = kb_file.file2text()
custom_docs = False
@ -105,6 +103,7 @@ class KBService(ABC):
# Convert metadata["source"] to a relative path
for doc in docs:
try:
doc.metadata.setdefault("source", kb_file.filename)
source = doc.metadata.get("source", "")
if os.path.isabs(source):
rel_path = Path(source).relative_to(self.doc_path)

View File

@ -14,13 +14,13 @@ import importlib
from server.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
from pathlib import Path
from server.utils import run_in_thread_pool, run_in_process_pool
import json
from typing import List, Union, Dict, Tuple, Generator
import chardet
from langchain_community.document_loaders import JSONLoader
from langchain_community.document_loaders import JSONLoader, TextLoader
def validate_kb_name(knowledge_base_id: str) -> bool:
@ -88,6 +88,7 @@ def list_files_from_folder(kb_name: str):
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html', '.htm'],
"MHTMLLoader": ['.mhtml'],
"TextLoader": ['.md'],
"UnstructuredMarkdownLoader": ['.md'],
"JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"],
@ -199,8 +200,8 @@ def make_text_splitter(
try:
if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定
headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on)
text_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on, strip_headers=False)
else:
try: ## 优先使用用户自定义的text_splitter
@ -292,7 +293,11 @@ class KnowledgeFile:
loader = get_loader(loader_name=self.document_loader_name,
file_path=self.filepath,
loader_kwargs=self.loader_kwargs)
self.docs = loader.load()
if isinstance(loader, TextLoader):
loader.encoding = "utf8"
self.docs = loader.load()
else:
self.docs = loader.load()
return self.docs
def docs2texts(
@ -375,7 +380,6 @@ def files2docs_in_thread(
The generator yields status, (kb_name, file_name, docs | error)
'''
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
@ -405,8 +409,12 @@ if __name__ == "__main__":
from pprint import pprint
kb_file = KnowledgeFile(
filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
filename="E:\\LLM\\Data\\Test.md",
knowledge_base_name="samples")
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
kb_file.text_splitter_name = "MarkdownHeaderTextSplitter"
docs = kb_file.file2docs()
# pprint(docs[-1])
texts = kb_file.docs2texts(docs)
for text in texts:
print(text)