from collections import namedtuple
from typing import Optional
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest
import unstructured_pytesseract
from lxml import etree
from pdf2image.exceptions import PDFPageCountError
from PIL import Image, UnidentifiedImageError
from unstructured_inference.inference.elements import EmbeddedTextRegion, TextRegion, TextRegions
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
from unstructured_inference.inference.layoutelement import (
    LayoutElement,
    LayoutElements,
)

from test_unstructured.unit_utils import example_doc_path
from unstructured.documents.elements import ElementType
from unstructured.partition.pdf_image import ocr
from unstructured.partition.pdf_image.pdf_image_utils import (
    convert_pdf_to_images,
    pad_element_bboxes,
)
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import (
    OCR_AGENT_PADDLE,
    OCR_AGENT_TESSERACT,
    Source,
)
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
    OCRAgentTesseract,
    zoom_image,
)


@pytest.mark.parametrize(
    ("is_image", "expected_error"),
    [
        (True, UnidentifiedImageError),
        (False, PDFPageCountError),
    ],
)
def test_process_data_with_ocr_invalid_file(is_image, expected_error):
    invalid_data = b"i am not a valid file"
    with pytest.raises(expected_error):
        _ = ocr.process_data_with_ocr(
            data=invalid_data,
            is_image=is_image,
            out_layout=DocumentLayout(),
            extracted_layout=[],
        )


@pytest.mark.parametrize("is_image", [True, False])
def test_process_file_with_ocr_invalid_filename(is_image):
    invalid_filename = "i am not a valid file name"
    with pytest.raises(FileNotFoundError):
        _ = ocr.process_file_with_ocr(
            filename=invalid_filename,
            is_image=is_image,
            out_layout=DocumentLayout(),
            extracted_layout=[],
        )


def test_supplement_page_layout_with_ocr_invalid_ocr():
    with pytest.raises(ValueError):
        _ = ocr.supplement_page_layout_with_ocr(
            page_layout=None, image=None, ocr_agent="invalid_ocr"
        )


def test_get_ocr_layout_from_image_tesseract(monkeypatch):
    monkeypatch.setattr(
        OCRAgentTesseract,
        "image_to_data_with_character_confidence_filter",
        lambda *args, **kwargs: pd.DataFrame(
            {
                "left": [10, 20, 30, 0],
                "top": [5, 15, 25, 0],
                "width": [15, 25, 35, 0],
                "height": [10, 20, 30, 0],
                "text": ["Hello", "World", "!", ""],
            },
        ),
    )

    image = Image.new("RGB", (100, 100))

    ocr_agent = OCRAgentTesseract()
    ocr_layout = ocr_agent.get_layout_from_image(image)

    expected_layout = TextRegions(
        element_coords=np.array([[10.0, 5, 25, 15], [20, 15, 45, 35], [30, 25, 65, 55]]),
        texts=np.array(["Hello", "World", "!"]),
        sources=np.array([Source.OCR_TESSERACT] * 3),
    )

    assert ocr_layout.texts.tolist() == expected_layout.texts.tolist()
    np.testing.assert_array_equal(ocr_layout.element_coords, expected_layout.element_coords)
    np.testing.assert_array_equal(ocr_layout.sources, expected_layout.sources)


def mock_ocr(*args, **kwargs):
    return [
        [
            (
                [(10, 5), (25, 5), (25, 15), (10, 15)],
                ["Hello"],
            ),
        ],
        [
            (
                [(20, 15), (45, 15), (45, 35), (20, 35)],
                ["World"],
            ),
        ],
        [
            (
                [(30, 25), (65, 25), (65, 55), (30, 55)],
                ["!"],
            ),
        ],
        [
            (
                [(0, 0), (0, 0), (0, 0), (0, 0)],
                [""],
            ),
        ],
    ]


def monkeypatch_load_agent(*args):
    class MockAgent:
        def __init__(self):
            self.ocr = mock_ocr

    return MockAgent()
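
# NOTE: mock_ocr above mimics PaddleOCR's raw output shape: for each detected region, a quad of
# four corner points plus the recognized text. The test below expects the agent to convert each
# quad to an axis-aligned bbox and to drop the region whose text is empty (the all-zero quad).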
def test_get_ocr_layout_from_image_paddle(monkeypatch):
    monkeypatch.setattr(
        OCRAgentPaddle,
        "load_agent",
        monkeypatch_load_agent,
    )

    image = Image.new("RGB", (100, 100))
    ocr_layout = OCRAgentPaddle().get_layout_from_image(image)

    expected_layout = TextRegions(
        element_coords=np.array([[10.0, 5, 25, 15], [20, 15, 45, 35], [30, 25, 65, 55]]),
        texts=np.array(["Hello", "World", "!"]),
        sources=np.array([Source.OCR_PADDLE] * 3),
    )

    assert ocr_layout.texts.tolist() == expected_layout.texts.tolist()
    np.testing.assert_array_equal(ocr_layout.element_coords, expected_layout.element_coords)
    np.testing.assert_array_equal(ocr_layout.sources, expected_layout.sources)


def test_get_ocr_text_from_image_tesseract(monkeypatch):
    monkeypatch.setattr(
        unstructured_pytesseract,
        "image_to_string",
        lambda *args, **kwargs: "Hello World",
    )
    image = Image.new("RGB", (100, 100))

    ocr_agent = OCRAgentTesseract()
    ocr_text = ocr_agent.get_text_from_image(image)

    assert ocr_text == "Hello World"


def test_get_ocr_text_from_image_paddle(monkeypatch):
    monkeypatch.setattr(
        OCRAgentPaddle,
        "load_agent",
        monkeypatch_load_agent,
    )

    image = Image.new("RGB", (100, 100))

    ocr_agent = OCRAgentPaddle()
    ocr_text = ocr_agent.get_text_from_image(image)

    assert ocr_text == "Hello\n\nWorld\n\n!"


@pytest.fixture()
def google_vision_text_annotation():
    from google.cloud.vision import (
        Block,
        BoundingPoly,
        Page,
        Paragraph,
        Symbol,
        TextAnnotation,
        Vertex,
        Word,
    )

    breaks = TextAnnotation.DetectedBreak.BreakType
    symbols_hello = [Symbol(text=c) for c in "Hello"] + [
        Symbol(
            property=TextAnnotation.TextProperty(
                detected_break=TextAnnotation.DetectedBreak(type_=breaks.SPACE)
            )
        )
    ]
    symbols_world = [Symbol(text=c) for c in "World!"] + [
        Symbol(
            property=TextAnnotation.TextProperty(
                detected_break=TextAnnotation.DetectedBreak(type_=breaks.LINE_BREAK)
            )
        )
    ]
    words = [Word(symbols=symbols_hello), Word(symbols=symbols_world)]
    bounding_box = BoundingPoly(
        vertices=[Vertex(x=0, y=0), Vertex(x=0, y=10), Vertex(x=10, y=10), Vertex(x=10, y=0)]
    )
    paragraphs = [Paragraph(words=words, bounding_box=bounding_box)]
    blocks = [Block(paragraphs=paragraphs)]
    pages = [Page(blocks=blocks)]
    return TextAnnotation(text="Hello World!", pages=pages)


@pytest.fixture()
def google_vision_client(google_vision_text_annotation):
    Response = namedtuple("Response", "full_text_annotation")

    class FakeGoogleVisionClient:
        def document_text_detection(self, image, image_context):
            return Response(full_text_annotation=google_vision_text_annotation)

    class OCRAgentFakeGoogleVision(OCRAgentGoogleVision):
        def __init__(self, language: Optional[str] = None):
            self.client = FakeGoogleVisionClient()
            self.language = language

    return OCRAgentFakeGoogleVision()


def test_get_ocr_from_image_google_vision(google_vision_client):
    image = Image.new("RGB", (100, 100))

    ocr_agent = google_vision_client
    ocr_text = ocr_agent.get_text_from_image(image)

    assert ocr_text == "Hello World!"


def test_get_layout_from_image_google_vision(google_vision_client):
    image = Image.new("RGB", (100, 100))

    ocr_agent = google_vision_client
    regions = ocr_agent.get_layout_from_image(image)

    assert len(regions) == 1
    assert regions.texts[0] == "Hello World!"
    assert all(source == Source.OCR_GOOGLEVISION for source in regions.sources)
    assert regions.x1[0] == 0
    assert regions.y1[0] == 0
    assert regions.x2[0] == 10
    assert regions.y2[0] == 10
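
# The Google Vision fixture contains a single Paragraph with a 10x10 BoundingPoly, so both the
# layout test above and the layout-elements test below should yield exactly one region/element.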
def test_get_layout_elements_from_image_google_vision(google_vision_client):
    image = Image.new("RGB", (100, 100))

    ocr_agent = google_vision_client
    layout_elements = ocr_agent.get_layout_elements_from_image(image)
    assert len(layout_elements) == 1


@pytest.fixture()
def mock_ocr_regions():
    return TextRegions.from_list(
        [
            EmbeddedTextRegion.from_coords(10, 10, 90, 90, text="0", source=None),
            EmbeddedTextRegion.from_coords(200, 200, 300, 300, text="1", source=None),
            EmbeddedTextRegion.from_coords(500, 320, 600, 350, text="3", source=None),
        ]
    )


@pytest.fixture()
def mock_out_layout(mock_embedded_text_regions):
    return LayoutElements.from_list(
        [
            LayoutElement(
                text="",
                source=None,
                type="Text",
                bbox=r.bbox,
            )
            for r in mock_embedded_text_regions
        ]
    )


def test_aggregate_ocr_text_by_block():
    expected = "A Unified Toolkit"
    ocr_layout = [
        TextRegion.from_coords(0, 0, 20, 20, "A"),
        TextRegion.from_coords(50, 50, 150, 150, "Unified"),
        TextRegion.from_coords(150, 150, 300, 250, "Toolkit"),
        TextRegion.from_coords(200, 250, 300, 350, "Deep"),
    ]
    region = TextRegion.from_coords(0, 0, 250, 350, "")

    text = ocr.aggregate_ocr_text_by_block(ocr_layout, region, 0.5)
    assert text == expected


@pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0])
def test_zoom_image(zoom):
    image = Image.new("RGB", (100, 100))
    width, height = image.size
    new_image = zoom_image(image, zoom)
    new_w, new_h = new_image.size
    if zoom <= 0:
        zoom = 1
    assert new_w == np.round(width * zoom, 0)
    assert new_h == np.round(height * zoom, 0)


@pytest.fixture()
def mock_layout(mock_embedded_text_regions):
    return LayoutElements.from_list(
        [
            LayoutElement(text=r.text, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox)
            for r in mock_embedded_text_regions
        ]
    )


def test_supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions):
    ocr_elements = [
        LayoutElement(text=r.text, source=None, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox)
        for r in mock_ocr_regions.as_list()
    ]

    final_layout = ocr.supplement_layout_with_ocr_elements(
        mock_layout, mock_ocr_regions
    ).as_list()

    # Check if the final layout contains the original layout elements
    for element in mock_layout.as_list():
        assert element in final_layout

    # Check if the final layout contains the OCR-derived elements
    assert any(ocr_element in final_layout for ocr_element in ocr_elements)

    # Check if the OCR-derived elements that are subregions of layout elements are removed
    for element in mock_layout.as_list():
        for ocr_element in ocr_elements:
            if ocr_element.bbox.is_almost_subregion_of(
                element.bbox,
                env_config.OCR_LAYOUT_SUBREGION_THRESHOLD,
            ):
                assert ocr_element not in final_layout


def test_merge_out_layout_with_ocr_layout(mock_out_layout, mock_ocr_regions):
    ocr_elements = [
        LayoutElement(text=r.text, source=None, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox)
        for r in mock_ocr_regions.as_list()
    ]
    input_layout_elements = mock_out_layout.as_list()

    final_layout = ocr.merge_out_layout_with_ocr_layout(
        mock_out_layout,
        mock_ocr_regions,
    ).as_list()

    # Check if the out layout's text attribute is updated with aggregated OCR text
    assert final_layout[0].text == mock_ocr_regions.texts[2]

    # Check if the final layout contains both original elements and OCR-derived elements
    # The first element's text is modified by the ocr regions so it won't be the same as the input
    assert all(element in final_layout for element in input_layout_elements[1:])
    assert final_layout[0].bbox == input_layout_elements[0].bbox
    assert any(element in final_layout for element in ocr_elements)
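
# pad_element_bboxes should expand (positive padding) or shrink (negative padding) the bbox
# symmetrically on all four sides, and it must not mutate the element passed in.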
@pytest.mark.parametrize(
    ("padding", "expected_bbox"),
    [
        (5, (5, 15, 35, 45)),
        (-3, (13, 23, 27, 37)),
        (2.5, (7.5, 17.5, 32.5, 42.5)),
        (-1.5, (11.5, 21.5, 28.5, 38.5)),
    ],
)
def test_pad_element_bboxes(padding, expected_bbox):
    element = LayoutElement.from_coords(
        x1=10,
        y1=20,
        x2=30,
        y2=40,
        text="",
        source=None,
        type=ElementType.UNCATEGORIZED_TEXT,
    )
    expected_original_element_bbox = (10, 20, 30, 40)

    padded_element = pad_element_bboxes(element, padding)

    padded_element_bbox = (
        padded_element.bbox.x1,
        padded_element.bbox.y1,
        padded_element.bbox.x2,
        padded_element.bbox.y2,
    )
    assert padded_element_bbox == expected_bbox

    # make sure the original element has not changed
    original_element_bbox = (element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2)
    assert original_element_bbox == expected_original_element_bbox


@pytest.fixture()
def table_element():
    table = LayoutElement.from_coords(
        x1=10, y1=20, x2=50, y2=70, text="I am a table", type="Table"
    )
    return table


@pytest.fixture()
def mock_ocr_layout():
    return TextRegions.from_list(
        [
            TextRegion.from_coords(x1=15, y1=25, x2=35, y2=45, text="Token1"),
            TextRegion.from_coords(x1=40, y1=30, x2=45, y2=50, text="Token2"),
        ]
    )


def test_supplement_element_with_table_extraction():
    from unstructured_inference.models import tables

    tables.load_agent()

    image = next(convert_pdf_to_images(example_doc_path("pdf/single_table.pdf")))
    elements = LayoutElements(
        element_coords=np.array([[215.00109863, 731.89996338, 1470.07739258, 972.83129883]]),
        texts=np.array(["foo"]),
        sources=np.array(["yolox_sg"]),
        element_class_ids=np.array([0]),
        element_class_id_map={0: "Table"},
    )
    supplemented = ocr.supplement_element_with_table_extraction(
        elements=elements,
        image=image,
        tables_agent=tables.tables_agent,
        ocr_agent=ocr.OCRAgent.get_agent(language="eng"),
    )
    assert supplemented.text_as_html[0].startswith("<table>")


def test_get_table_tokens(mock_ocr_layout):
    with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
        ocr_agent = OCRAgent.get_agent(language="eng")
        table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent)
        expected_tokens = [
            {
                "bbox": [15, 25, 35, 45],
                "text": "Token1",
                "span_num": 0,
                "line_num": 0,
                "block_num": 0,
            },
            {
                "bbox": [40, 30, 45, 50],
                "text": "Token2",
                "span_num": 1,
                "line_num": 0,
                "block_num": 0,
            },
        ]
        assert table_tokens == expected_tokens


def test_auto_zoom_not_exceed_tesseract_limit(monkeypatch):
    monkeypatch.setenv("TESSERACT_MIN_TEXT_HEIGHT", "1000")
    monkeypatch.setenv("TESSERACT_OPTIMUM_TEXT_HEIGHT", "100000")
    monkeypatch.setattr(
        OCRAgentTesseract,
        "image_to_data_with_character_confidence_filter",
        lambda *args, **kwargs: pd.DataFrame(
            {
                "left": [10, 20, 30, 0],
                "top": [5, 15, 25, 0],
                "width": [15, 25, 35, 0],
                "height": [10, 20, 30, 0],
                "text": ["Hello", "World", "!", ""],
            },
        ),
    )
    image = Image.new("RGB", (1000, 1000))
    ocr_agent = OCRAgentTesseract()
    # tests that the code can run instead of oom and OCR results make sense
    assert ocr_agent.get_layout_from_image(image).texts.tolist() == [
        "Hello",
        "World",
        "!",
    ]


def test_merge_out_layout_with_cid_code(mock_out_layout, mock_ocr_regions):
    # the code should ignore this invalid text and use ocr region's text
    mock_out_layout.texts = mock_out_layout.texts.astype(object)
    mock_out_layout.texts[0] = "(cid:10)(cid:5)?"
    ocr_elements = [
        LayoutElement(text=r.text, source=None, type=ElementType.UNCATEGORIZED_TEXT, bbox=r.bbox)
        for r in mock_ocr_regions.as_list()
    ]
    input_layout_elements = mock_out_layout.as_list()

    # TODO (yao): refactor the tests to check the array data structure directly instead of
    # converting them into lists first (this includes other tests in this file)
    final_layout = ocr.merge_out_layout_with_ocr_layout(
        mock_out_layout, mock_ocr_regions
    ).as_list()

    # Check if the out layout's text attribute is updated with aggregated OCR text
    assert final_layout[0].text == mock_ocr_regions.texts[2]

    # Check if the final layout contains both original elements and OCR-derived elements
    assert all(element in final_layout for element in input_layout_elements[1:])
    assert any(element in final_layout for element in ocr_elements)
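
# Helper for the hOCR tests below: builds a minimal fragment containing one ocrx_word span with
# an ocrx_cinfo child span per character (each carrying the x_conf value the confidence filter
# reads), wrapped in a root element that declares the namespace from `namespace_map` so
# namespace-aware lxml lookups can locate the spans.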
def _create_hocr_word_span(
    characters: list[tuple[str, str]], word_bbox: tuple[int, int, int, int], namespace_map: dict
) -> etree.Element:
    namespace = next(iter(namespace_map.values()))
    word_span = [
        f'<div xmlns="{namespace}">\n',
        (
            '<span class="ocrx_word" title="'
            f"bbox {word_bbox[0]} {word_bbox[1]} {word_bbox[2]} {word_bbox[3]}"
            '; x_wconf 64">'
        ),
    ]
    for char, x_conf in characters:
        word_span.append(
            f'<span class="ocrx_cinfo" title="x_bboxes 0 0 0 0; x_conf {x_conf}">{char}</span>'
        )
    word_span.append("</span>")
    word_span.append("</div>")
    root = etree.fromstring("\n".join(word_span))
    return root


def test_extract_word_from_hocr():
    characters = [
        ("w", "99.0"),
        ("o", "98.5"),
        ("r", "97.5"),
        ("d", "96.0"),
        ("!", "50.0"),
        ("@", "45.0"),
    ]
    word_bbox = (10, 9, 70, 22)
    agent = OCRAgentTesseract()
    word_span = _create_hocr_word_span(characters, word_bbox, agent.hocr_namespace)

    text = agent.extract_word_from_hocr(word_span, 0.0)
    assert text == "word!@"

    text = agent.extract_word_from_hocr(word_span, 0.960)
    assert text == "word"

    text = agent.extract_word_from_hocr(word_span, 0.990)
    assert text == "w"

    text = agent.extract_word_from_hocr(word_span, 0.999)
    assert text == ""


def test_hocr_to_dataframe():
    characters = [
        ("w", "99.0"),
        ("o", "98.5"),
        ("r", "97.5"),
        ("d", "96.0"),
        ("!", "50.0"),
        ("@", "45.0"),
    ]
    word_bbox = (10, 9, 70, 22)
    agent = OCRAgentTesseract()
    hocr = etree.tostring(_create_hocr_word_span(characters, word_bbox, agent.hocr_namespace))
    df = agent.hocr_to_dataframe(hocr=hocr, character_confidence_threshold=0.960)

    assert df.shape == (1, 5)
    assert df["left"].iloc[0] == 10
    assert df["top"].iloc[0] == 9
    assert df["width"].iloc[0] == 60
    assert df["height"].iloc[0] == 13
    assert df["text"].iloc[0] == "word"


def test_hocr_to_dataframe_when_no_prediction_empty_df():
    df = OCRAgentTesseract().hocr_to_dataframe(hocr="")

    assert df.shape == (0, 5)
    assert "left" in df.columns
    assert "top" in df.columns
    assert "width" in df.columns
    assert "height" in df.columns
    assert "text" in df.columns


@pytest.fixture
def mock_page(mock_ocr_layout, mock_layout):
    mock_page = MagicMock(PageLayout)
    mock_page.elements_array = mock_layout
    return mock_page


def test_supplement_layout_with_ocr(mock_ocr_get_instance, mocker, mock_page, mock_ocr_layout):
    from unstructured.partition.pdf_image.ocr import OCRAgent

    mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)

    ocr.supplement_page_layout_with_ocr(
        mock_page,
        Image.new("RGB", (100, 100)),
        infer_table_structure=True,
        ocr_agent=OCR_AGENT_TESSERACT,
        ocr_languages="eng",
        table_ocr_agent=OCR_AGENT_PADDLE,
    )

    assert mock_ocr_get_instance.call_args_list[0][1] == {
        "language": "eng",
        "ocr_agent_module": OCR_AGENT_TESSERACT,
    }
    assert mock_ocr_get_instance.call_args_list[1][1] == {
        "language": "en",
        "ocr_agent_module": OCR_AGENT_PADDLE,
    }
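
# Note the expected kwargs in both tests: "eng" is passed through for tesseract, while the
# paddle agent is requested with "en", i.e. the language code is converted to each agent's own
# convention before the (patched) get-instance hook is called.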
def test_pass_down_agents(mock_ocr_get_instance, mocker, mock_page, mock_ocr_layout):
    from unstructured.partition.pdf_image.ocr import OCRAgent, PILImage

    mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
    mocker.patch.object(PILImage, "open", return_value=Image.new("RGB", (100, 100)))
    doc = MagicMock(DocumentLayout)
    doc.pages = [mock_page]

    ocr.process_file_with_ocr(
        "foo",
        doc,
        [],
        infer_table_structure=True,
        is_image=True,
        ocr_agent=OCR_AGENT_PADDLE,
        ocr_languages="eng",
        table_ocr_agent=OCR_AGENT_TESSERACT,
    )

    assert mock_ocr_get_instance.call_args_list[0][1] == {
        "language": "en",
        "ocr_agent_module": OCR_AGENT_PADDLE,
    }
    assert mock_ocr_get_instance.call_args_list[1][1] == {
        "language": "eng",
        "ocr_agent_module": OCR_AGENT_TESSERACT,
    }