mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 15:38:27 +08:00
【功能新增】增强对PPT、DOC知识库文件的OCR识别 (#2013)
* 【功能新增】增强对PPT、DOC文件的OCR识别 * 【功能新增】增强对PPT、DOC文件的OCR识别 * Update mydocloader.py --------- Co-authored-by: zR <2448370773@qq.com>
This commit is contained in:
parent
e615932e7e
commit
75ff268e88
@ -1,2 +1,4 @@
|
|||||||
from .mypdfloader import RapidOCRPDFLoader
|
from .mypdfloader import RapidOCRPDFLoader
|
||||||
from .myimgloader import RapidOCRLoader
|
from .myimgloader import RapidOCRLoader
|
||||||
|
from .mydocloader import RapidOCRDocLoader
|
||||||
|
from .mypptloader import RapidOCRPPTLoader
|
||||||
|
|||||||
71
document_loaders/mydocloader.py
Normal file
71
document_loaders/mydocloader.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||||
|
from typing import List
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class RapidOCRDocLoader(UnstructuredFileLoader):
    """Load a Word document, adding OCR text for pictures embedded in paragraphs.

    The document is flattened to plain text (body paragraphs, OCR output of
    inline pictures, and table cell paragraphs, in document order) and then
    handed to ``unstructured``'s ``partition_text``.
    """

    def _get_elements(self) -> List:
        """Return the elements produced by ``partition_text`` for this file.

        Reads ``self.file_path`` and forwards ``self.unstructured_kwargs`` to
        ``partition_text``.
        """

        def doc2text(filepath):
            """Return the text of the document at *filepath*, one item per line.

            Raises:
                ValueError: if block iteration is started on something that is
                    neither a document nor a table cell.
            """
            # Heavy / optional dependencies are imported lazily so that merely
            # importing this module does not require them.
            from docx.table import _Cell, Table
            from docx.oxml.table import CT_Tbl
            from docx.oxml.text.paragraph import CT_P
            from docx.text.paragraph import Paragraph
            from docx import Document, ImagePart
            from PIL import Image
            from io import BytesIO
            import numpy as np
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            doc = Document(filepath)
            # Every chunk already ends with "\n"; joining once at the end
            # avoids rebuilding a growing string on each append.
            chunks = []

            def iter_block_items(parent):
                """Yield the paragraphs and tables directly under *parent*, in order."""
                # ``docx.Document`` is a factory function; the isinstance check
                # needs the actual class from ``docx.document``.
                from docx.document import Document as DocumentClass

                if isinstance(parent, DocumentClass):
                    parent_elm = parent.element.body
                elif isinstance(parent, _Cell):
                    parent_elm = parent._tc
                else:
                    raise ValueError("RapidOCRDocLoader parse fail")

                for child in parent_elm.iterchildren():
                    if isinstance(child, CT_P):
                        yield Paragraph(child, parent)
                    elif isinstance(child, CT_Tbl):
                        yield Table(child, parent)

            # The context manager closes the progress bar even if OCR raises.
            with tqdm.tqdm(total=len(doc.paragraphs) + len(doc.tables),
                           desc="RapidOCRDocLoader block index: 0") as b_unit:
                for i, block in enumerate(iter_block_items(doc)):
                    b_unit.set_description(
                        "RapidOCRDocLoader block index: {}".format(i))
                    b_unit.refresh()
                    if isinstance(block, Paragraph):
                        chunks.append(block.text.strip() + "\n")
                        # All pictures anchored in this paragraph.
                        for picture in block._element.xpath('.//pic:pic'):
                            # Relationship ids of the embedded image data.
                            for img_id in picture.xpath('.//a:blip/@r:embed'):
                                # Resolve the relationship id to its package part.
                                part = doc.part.related_parts[img_id]
                                if isinstance(part, ImagePart):
                                    pil_image = Image.open(BytesIO(part._blob))
                                    result, _ = ocr(np.array(pil_image))
                                    if result:
                                        # Terminate every OCR line with "\n" so
                                        # the last one does not run into the
                                        # text of the following block.
                                        chunks.extend(
                                            line[1] + "\n" for line in result)
                    elif isinstance(block, Table):
                        for row in block.rows:
                            for cell in row.cells:
                                for paragraph in cell.paragraphs:
                                    chunks.append(paragraph.text.strip() + "\n")
                    b_unit.update(1)
            return "".join(chunks)

        text = doc2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual smoke test: load the sample document and print the result.
    sample_loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
    print(sample_loader.load())
|
||||||
59
document_loaders/mypptloader.py
Normal file
59
document_loaders/mypptloader.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||||
|
from typing import List
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class RapidOCRPPTLoader(UnstructuredFileLoader):
    """Load a PowerPoint presentation, adding OCR text for picture shapes.

    Each slide is flattened to plain text (text frames, table cells, OCR
    output of pictures, and the contents of group shapes) and the combined
    text is handed to ``unstructured``'s ``partition_text``.
    """

    def _get_elements(self) -> List:
        """Return the elements produced by ``partition_text`` for this file.

        Reads ``self.file_path`` and forwards ``self.unstructured_kwargs`` to
        ``partition_text``.
        """

        def ppt2text(filepath):
            """Return the text of the presentation at *filepath*, one item per line."""
            # Heavy / optional dependencies are imported lazily so that merely
            # importing this module does not require them.
            from pptx import Presentation
            from pptx.enum.shapes import MSO_SHAPE_TYPE
            from PIL import Image
            import numpy as np
            from io import BytesIO
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            prs = Presentation(filepath)
            # Every chunk already ends with "\n"; joining once at the end
            # avoids rebuilding a growing string on each append.
            chunks = []

            def extract_text(shape):
                """Append the text carried by *shape* (recursing into groups) to ``chunks``."""
                if shape.has_text_frame:
                    chunks.append(shape.text.strip() + "\n")
                if shape.has_table:
                    for row in shape.table.rows:
                        for cell in row.cells:
                            for paragraph in cell.text_frame.paragraphs:
                                chunks.append(paragraph.text.strip() + "\n")
                if shape.shape_type == MSO_SHAPE_TYPE.PICTURE:  # value 13
                    pil_image = Image.open(BytesIO(shape.image.blob))
                    result, _ = ocr(np.array(pil_image))
                    if result:
                        # Terminate every OCR line with "\n" so the last one
                        # does not run into the text of the following shape.
                        chunks.extend(line[1] + "\n" for line in result)
                elif shape.shape_type == MSO_SHAPE_TYPE.GROUP:  # value 6
                    for child_shape in shape.shapes:
                        extract_text(child_shape)

            def reading_order(shape):
                """Sort key: top-to-bottom, then left-to-right."""
                # NOTE(review): python-pptx can report None for top/left when a
                # shape has no explicit position; fall back to 0 so sorting
                # does not fail comparing None with an int.
                return (shape.top or 0, shape.left or 0)

            # The context manager closes the progress bar even if OCR raises.
            with tqdm.tqdm(total=len(prs.slides),
                           desc="RapidOCRPPTLoader slide index: 1") as b_unit:
                # Walk every slide in presentation order.
                for slide_number, slide in enumerate(prs.slides, start=1):
                    b_unit.set_description(
                        "RapidOCRPPTLoader slide index: {}".format(slide_number))
                    b_unit.refresh()
                    for shape in sorted(slide.shapes, key=reading_order):
                        extract_text(shape)
                    b_unit.update(1)
            return "".join(chunks)

        text = ppt2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual smoke test: load the sample presentation and print the result.
    sample_loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
    print(sample_loader.load())
|
||||||
@ -91,9 +91,14 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
|||||||
"JSONLoader": [".json"],
|
"JSONLoader": [".json"],
|
||||||
"JSONLinesLoader": [".jsonl"],
|
"JSONLinesLoader": [".jsonl"],
|
||||||
"CSVLoader": [".csv"],
|
"CSVLoader": [".csv"],
|
||||||
# "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持
|
# "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
|
||||||
"RapidOCRPDFLoader": [".pdf"],
|
"RapidOCRPDFLoader": [".pdf"],
|
||||||
|
"RapidOCRDocLoader": ['.docx', '.doc'],
|
||||||
|
"RapidOCRPPTLoader": ['.ppt', '.pptx', ],
|
||||||
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
|
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
|
||||||
|
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
|
||||||
|
'.rtf', '.txt', '.xml',
|
||||||
|
'.epub', '.odt','.tsv'],
|
||||||
"UnstructuredEmailLoader": ['.eml', '.msg'],
|
"UnstructuredEmailLoader": ['.eml', '.msg'],
|
||||||
"UnstructuredEPubLoader": ['.epub'],
|
"UnstructuredEPubLoader": ['.epub'],
|
||||||
"UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
|
"UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
|
||||||
@ -109,7 +114,6 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
|||||||
"UnstructuredXMLLoader": ['.xml'],
|
"UnstructuredXMLLoader": ['.xml'],
|
||||||
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
|
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
|
||||||
"EverNoteLoader": ['.enex'],
|
"EverNoteLoader": ['.enex'],
|
||||||
"UnstructuredFileLoader": ['.txt'],
|
|
||||||
}
|
}
|
||||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||||
|
|
||||||
@ -141,15 +145,14 @@ def get_LoaderClass(file_extension):
|
|||||||
if file_extension in extensions:
|
if file_extension in extensions:
|
||||||
return LoaderClass
|
return LoaderClass
|
||||||
|
|
||||||
|
|
||||||
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
|
||||||
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||||
'''
|
'''
|
||||||
根据loader_name和文件路径或内容返回文档加载器。
|
根据loader_name和文件路径或内容返回文档加载器。
|
||||||
'''
|
'''
|
||||||
loader_kwargs = loader_kwargs or {}
|
loader_kwargs = loader_kwargs or {}
|
||||||
try:
|
try:
|
||||||
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
|
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
|
||||||
|
"RapidOCRDocLoader", "RapidOCRPPTLoader"]:
|
||||||
document_loaders_module = importlib.import_module('document_loaders')
|
document_loaders_module = importlib.import_module('document_loaders')
|
||||||
else:
|
else:
|
||||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
|
|||||||
BIN
tests/samples/ocr_test.docx
Normal file
BIN
tests/samples/ocr_test.docx
Normal file
Binary file not shown.
BIN
tests/samples/ocr_test.pptx
Normal file
BIN
tests/samples/ocr_test.pptx
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user