diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py
--- a/document_loaders/mypdfloader.py
+++ b/document_loaders/mypdfloader.py
@@ -1,11 +1,35 @@
 from typing import List
 from langchain.document_loaders.unstructured import UnstructuredFileLoader
 from document_loaders.ocr import get_ocr
+import cv2
+from PIL import Image
+import numpy as np
 import tqdm
 
 
 class RapidOCRPDFLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
+        def rotate_img(img, angle):
+            """Rotate an image about its centre without cropping the corners.
+
+            img    -- image as a numpy array whose first two axes are (h, w)
+            angle  -- rotation angle in degrees; positive is counter-clockwise
+            return -- the rotated image on a canvas enlarged to contain it
+            """
+            h, w = img.shape[:2]
+            rotate_center = (w / 2, h / 2)
+            # getRotationMatrix2D(center, angle, scale): a positive angle
+            # rotates counter-clockwise; scale 1.0 keeps the original size.
+            M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
+            # Bounding box of the rotated image (M[0, 0] = cos, M[0, 1] = sin).
+            new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
+            new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
+            # Shift the transform so the result is centred on the new canvas.
+            M[0, 2] += (new_w - w) / 2
+            M[1, 2] += (new_h - h) / 2
+
+            return cv2.warpAffine(img, M, (new_w, new_h))
+
         def pdf2text(filepath):
             import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆
             import numpy as np
@@ -27,8 +51,15 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
                 img_list = page.get_images()
                 for img in img_list:
                     pix = fitz.Pixmap(doc, img[0])
                     img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
+                    if int(page.rotation) != 0:
+                        # Page is rotated: turn the image by 360 - rotation degrees
+                        # (counter-clockwise in OpenCV terms) before running OCR.
+                        tmp_img = Image.fromarray(img_array)
+                        ori_img = cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR)
+                        rot_img = rotate_img(img=ori_img, angle=360 - page.rotation)
+                        img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR)
                     result, _ = ocr(img_array)
                     if result:
                         ocr_result = [line[1] for line in result]
                         resp += "\n".join(ocr_result)