diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 6be2848f..cf9b3e68 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -55,6 +55,9 @@ METAPHOR_API_KEY = "" # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 ZH_TITLE_ENHANCE = False +# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。 +# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度 +PDF_OCR_THRESHOLD = (0.6, 0.6) # 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。 KB_INFO = { diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py index 51778b89..5c480cff 100644 --- a/document_loaders/mypdfloader.py +++ b/document_loaders/mypdfloader.py @@ -1,5 +1,6 @@ from typing import List from langchain.document_loaders.unstructured import UnstructuredFileLoader +from configs import PDF_OCR_THRESHOLD from document_loaders.ocr import get_ocr import tqdm @@ -15,7 +16,6 @@ class RapidOCRPDFLoader(UnstructuredFileLoader): b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") for i, page in enumerate(doc): - # 更新描述 b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) # 立即显示进度条更新结果 @@ -24,14 +24,20 @@ class RapidOCRPDFLoader(UnstructuredFileLoader): text = page.get_text("") resp += text + "\n" - img_list = page.get_images() + img_list = page.get_image_info(xrefs=True) for img in img_list: - pix = fitz.Pixmap(doc, img[0]) - img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) - result, _ = ocr(img_array) - if result: - ocr_result = [line[1] for line in result] - resp += "\n".join(ocr_result) + if xref := img.get("xref"): + bbox = img["bbox"] + # 检查图片尺寸是否超过设定的阈值 + if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0] + or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]): + continue + pix = fitz.Pixmap(doc, xref) + img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) + result, _ = ocr(img_array) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) # 更新进度 b_unit.update(1)