diff --git a/document_loaders/__init__.py b/document_loaders/__init__.py index a4d6b28d..88cfeae8 100644 --- a/document_loaders/__init__.py +++ b/document_loaders/__init__.py @@ -1,2 +1,4 @@ from .mypdfloader import RapidOCRPDFLoader -from .myimgloader import RapidOCRLoader \ No newline at end of file +from .myimgloader import RapidOCRLoader +from .mydocloader import RapidOCRDocLoader +from .mypptloader import RapidOCRPPTLoader diff --git a/document_loaders/mydocloader.py b/document_loaders/mydocloader.py new file mode 100644 index 00000000..7f5462a2 --- /dev/null +++ b/document_loaders/mydocloader.py @@ -0,0 +1,71 @@ +from langchain.document_loaders.unstructured import UnstructuredFileLoader +from typing import List +import tqdm + + +class RapidOCRDocLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def doc2text(filepath): + from docx.table import _Cell, Table + from docx.oxml.table import CT_Tbl + from docx.oxml.text.paragraph import CT_P + from docx.text.paragraph import Paragraph + from docx import Document, ImagePart + from PIL import Image + from io import BytesIO + import numpy as np + from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR() + doc = Document(filepath) + resp = "" + + def iter_block_items(parent): + from docx.document import Document + if isinstance(parent, Document): + parent_elm = parent.element.body + elif isinstance(parent, _Cell): + parent_elm = parent._tc + else: + raise ValueError("RapidOCRDocLoader parse fail") + + for child in parent_elm.iterchildren(): + if isinstance(child, CT_P): + yield Paragraph(child, parent) + elif isinstance(child, CT_Tbl): + yield Table(child, parent) + + b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables), + desc="RapidOCRDocLoader block index: 0") + for i, block in enumerate(iter_block_items(doc)): + b_unit.set_description( + "RapidOCRDocLoader block index: {}".format(i)) + b_unit.refresh() + if isinstance(block, Paragraph): + resp += block.text.strip() + "\n" + images = block._element.xpath('.//pic:pic') # 获取所有图片 + for image in images: + for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id + part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片 + if isinstance(part, ImagePart): + image = Image.open(BytesIO(part._blob)) + result, _ = ocr(np.array(image)) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + elif isinstance(block, Table): + for row in block.rows: + for cell in row.cells: + for paragraph in cell.paragraphs: + resp += paragraph.text.strip() + "\n" + b_unit.update(1) + return resp + + text = doc2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == '__main__': + loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx") + docs = loader.load() + print(docs) diff --git a/document_loaders/mypptloader.py b/document_loaders/mypptloader.py new file mode 100644 index 00000000..f14d0728 --- /dev/null +++ b/document_loaders/mypptloader.py @@ -0,0 +1,59 @@ +from langchain.document_loaders.unstructured import UnstructuredFileLoader +from typing import List +import tqdm + + +class RapidOCRPPTLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def ppt2text(filepath): + from pptx import Presentation + from PIL import Image + import numpy as np + from io import BytesIO + from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR() + prs = Presentation(filepath) + resp = "" + + def extract_text(shape): + nonlocal resp + if shape.has_text_frame: + resp += shape.text.strip() + "\n" + if shape.has_table: + for row in shape.table.rows: + for cell in row.cells: + for paragraph in cell.text_frame.paragraphs: + resp += paragraph.text.strip() + "\n" + if shape.shape_type == 13: # 13 表示图片 + image = Image.open(BytesIO(shape.image.blob)) + result, _ = ocr(np.array(image)) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + elif shape.shape_type == 6: # 6 表示组合 + for child_shape in shape.shapes: + extract_text(child_shape) + + b_unit = tqdm.tqdm(total=len(prs.slides), + desc="RapidOCRPPTLoader slide index: 1") + # 遍历所有幻灯片 + for slide_number, slide in enumerate(prs.slides, start=1): + b_unit.set_description( + "RapidOCRPPTLoader slide index: {}".format(slide_number)) + b_unit.refresh() + sorted_shapes = sorted(slide.shapes, + key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历 + for shape in sorted_shapes: + extract_text(shape) + b_unit.update(1) + return resp + + text = ppt2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == '__main__': + loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx") + docs = loader.load() + print(docs) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 85d0f0c5..24e9b479 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -91,9 +91,14 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], "JSONLoader": [".json"], "JSONLinesLoader": [".jsonl"], "CSVLoader": [".csv"], - # "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持 + # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv "RapidOCRPDFLoader": [".pdf"], + "RapidOCRDocLoader": ['.docx', '.doc'], + "RapidOCRPPTLoader": ['.ppt', '.pptx', ], "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], + "UnstructuredFileLoader": ['.eml', '.msg', '.rst', + '.rtf', '.txt', '.xml', + '.epub', '.odt','.tsv'], "UnstructuredEmailLoader": ['.eml', '.msg'], "UnstructuredEPubLoader": ['.epub'], "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'], @@ -109,7 +114,6 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], "UnstructuredXMLLoader": ['.xml'], "UnstructuredPowerPointLoader": ['.ppt', '.pptx'], "EverNoteLoader": ['.enex'], - "UnstructuredFileLoader": ['.txt'], } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] @@ -141,15 +145,14 @@ def get_LoaderClass(file_extension): if file_extension in extensions: return LoaderClass - -# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化 def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): ''' 根据loader_name和文件路径或内容返回文档加载器。 ''' loader_kwargs = loader_kwargs or {} try: - if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]: + if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader", + "RapidOCRDocLoader", "RapidOCRPPTLoader"]: document_loaders_module = importlib.import_module('document_loaders') else: document_loaders_module = importlib.import_module('langchain.document_loaders') diff --git a/tests/samples/ocr_test.docx b/tests/samples/ocr_test.docx new file mode 100644 index 00000000..25039c5e Binary files /dev/null and b/tests/samples/ocr_test.docx differ diff --git a/tests/samples/ocr_test.pptx b/tests/samples/ocr_test.pptx new file mode 100644 index 00000000..1d27f66f Binary files /dev/null and b/tests/samples/ocr_test.pptx differ