feat: allow passing down of ocr agent and table agent (#3954)

This PR allows passing down both `ocr_agent` and `table_ocr_agent` as
parameters to specify the `OCRAgent` class for the page and tables, if
any, respectively. Both default to using `tesseract`, consistent
with the present default behavior.

We used to rely on env variables to specify the agents, but the OS
environment can be changed at runtime outside of the caller's control.
Passing the variables down as parameters ensures that the specification
is independent of env changes.

## testing

Use `example-docs/img/layout-parser-paper-with-table.jpg` and run
partition with two different settings. Note that this test requires the
`paddleocr` extra.

```python
from unstructured.partition.auto import partition
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCR_AGENT_PADDLE
elements = partition(f, strategy="hi_res", skip_infer_table_types=[], ocr_agent=OCR_AGENT_TESSERACT, table_ocr_agent=OCR_AGENT_PADDLE)
elements_alt = partition(f, strategy="hi_res", skip_infer_table_types=[], ocr_agent=OCR_AGENT_PADDLE, table_ocr_agent=OCR_AGENT_TESSERACT)
```

We should see both runs finish, with slight differences in the table
element's text attribute.
This commit is contained in:
Yao You 2025-03-11 11:36:31 -05:00 committed by GitHub
parent 0001a33dba
commit 8759b0aac9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 122 additions and 18 deletions

View File

@ -218,6 +218,7 @@ jobs:
sudo apt-get install -y tesseract-ocr tesseract-ocr-kor
tesseract --version
make install-${{ matrix.extra }}
[[ ${{ matrix.extra }} == "pdf-image" ]] && make install-paddleocr
make test-extra-${{ matrix.extra }} CI=true
setup_ingest:

View File

@ -1,8 +1,9 @@
## 0.16.26-dev2
## 0.16.26-dev3
### Enhancements
- **Add support for images in html partitioner** `<img>` tags will now be parsed as `Image` elements. When `extract_image_block_types` includes `Image` and `extract_image_block_to_payload`=True then the `image_base64` will be included for images that specify the base64 data (rather than url) as the source.
- **Use kwargs instead of env to specify `ocr_agent` and `table_ocr_agent`** for `hi_res` strategy.
### Features

View File

@ -22,7 +22,7 @@ install-base: install-base-pip-packages install-nltk-models
install: install-base-pip-packages install-dev install-nltk-models install-test install-huggingface install-all-docs
.PHONY: install-ci
install-ci: install-base-pip-packages install-nltk-models install-huggingface install-all-docs install-test install-pandoc
install-ci: install-base-pip-packages install-nltk-models install-huggingface install-all-docs install-test install-pandoc install-paddleocr
.PHONY: install-base-ci
install-base-ci: install-base-pip-packages install-nltk-models install-test install-pandoc
@ -80,6 +80,10 @@ install-odt:
install-pypandoc:
${PYTHON} -m pip install -r requirements/extra-pandoc.txt
.PHONY: install-paddleocr
install-paddleocr:
${PYTHON} -m pip install -r requirements/extra-paddleocr.txt
.PHONY: install-markdown
install-markdown:
${PYTHON} -m pip install -r requirements/extra-markdown.txt

View File

@ -1,6 +1,6 @@
from collections import namedtuple
from typing import Optional
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import numpy as np
import pandas as pd
@ -10,7 +10,7 @@ from bs4 import BeautifulSoup, Tag
from pdf2image.exceptions import PDFPageCountError
from PIL import Image, UnidentifiedImageError
from unstructured_inference.inference.elements import EmbeddedTextRegion, TextRegion, TextRegions
from unstructured_inference.inference.layout import DocumentLayout
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
from unstructured_inference.inference.layoutelement import (
LayoutElement,
LayoutElements,
@ -25,6 +25,8 @@ from unstructured.partition.pdf_image.pdf_image_utils import (
)
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import (
OCR_AGENT_PADDLE,
OCR_AGENT_TESSERACT,
Source,
)
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
@ -66,12 +68,10 @@ def test_process_file_with_ocr_invalid_filename(is_image):
)
def test_supplement_page_layout_with_ocr_invalid_ocr(monkeypatch):
monkeypatch.setenv("OCR_AGENT", "invalid_ocr")
def test_supplement_page_layout_with_ocr_invalid_ocr():
with pytest.raises(ValueError):
_ = ocr.supplement_page_layout_with_ocr(
page_layout=None,
image=None,
page_layout=None, image=None, ocr_agent="invliad_ocr"
)
@ -610,3 +610,53 @@ def test_hocr_to_dataframe_when_no_prediction_empty_df():
assert "width" in df.columns
assert "text" in df.columns
assert "text" in df.columns
@pytest.fixture
def mock_page(mock_ocr_layout, mock_layout):
    """A mocked ``PageLayout`` whose ``elements_array`` is the mock layout."""
    page = MagicMock(PageLayout)
    page.elements_array = mock_layout
    return page
def test_supplement_layout_with_ocr(mocker, mock_page):
    """Page OCR and table OCR each get their own agent instance and language."""
    from unstructured.partition.pdf_image.ocr import OCRAgent

    mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
    get_instance_spy = mocker.spy(OCRAgent, "get_instance")

    ocr.supplement_page_layout_with_ocr(
        mock_page,
        Image.new("RGB", (100, 100)),
        infer_table_structure=True,
        ocr_agent=OCR_AGENT_TESSERACT,
        ocr_languages="eng",
        table_ocr_agent=OCR_AGENT_PADDLE,
    )

    # First instantiation is the page agent (tesseract language code),
    # second is the table agent (paddle language code).
    page_call_kwargs = get_instance_spy.call_args_list[0][1]
    table_call_kwargs = get_instance_spy.call_args_list[1][1]
    assert page_call_kwargs == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
    assert table_call_kwargs == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
def test_pass_down_agents(mocker, mock_page):
    """`process_file_with_ocr` forwards both agent choices to `get_instance`."""
    from unstructured.partition.pdf_image.ocr import OCRAgent, PILImage

    mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
    mocker.patch.object(PILImage, "open", return_value=Image.new("RGB", (100, 100)))
    get_instance_spy = mocker.spy(OCRAgent, "get_instance")

    document = MagicMock(DocumentLayout)
    document.pages = [mock_page]

    ocr.process_file_with_ocr(
        "foo",
        document,
        [],
        infer_table_structure=True,
        is_image=True,
        ocr_agent=OCR_AGENT_PADDLE,
        ocr_languages="eng",
        table_ocr_agent=OCR_AGENT_TESSERACT,
    )

    # Page agent (paddle) is created first, table agent (tesseract) second.
    page_call_kwargs = get_instance_spy.call_args_list[0][1]
    table_call_kwargs = get_instance_spy.call_args_list[1][1]
    assert page_call_kwargs == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
    assert table_call_kwargs == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}

View File

@ -39,6 +39,8 @@ from unstructured.partition import pdf, strategies
from unstructured.partition.pdf_image import ocr, pdfminer_processing
from unstructured.partition.pdf_image.pdfminer_processing import get_uris_from_annots
from unstructured.partition.utils.constants import (
OCR_AGENT_PADDLE,
OCR_AGENT_TESSERACT,
SORT_MODE_BASIC,
SORT_MODE_DONT,
SORT_MODE_XY_CUT,
@ -1585,3 +1587,20 @@ def test_partition_pdf_with_password(
file=spooled_temp_file, strategy=strategy, password="password"
)
_test(result)
def test_partition_pdf_with_specified_ocr_agents(mocker):
    """End-to-end: `partition_pdf` kwargs select the page and table OCR agents."""
    from unstructured.partition.pdf_image.ocr import OCRAgent

    get_instance_spy = mocker.spy(OCRAgent, "get_instance")

    pdf.partition_pdf(
        filename=example_doc_path("pdf/layout-parser-paper-with-table.pdf"),
        strategy=PartitionStrategy.HI_RES,
        infer_table_structure=True,
        ocr_agent=OCR_AGENT_TESSERACT,
        table_ocr_agent=OCR_AGENT_PADDLE,
    )

    # Page agent first (tesseract language code), table agent second (paddle code).
    page_call_kwargs = get_instance_spy.call_args_list[0][1]
    table_call_kwargs = get_instance_spy.call_args_list[1][1]
    assert page_call_kwargs == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
    assert table_call_kwargs == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}

View File

@ -1 +1 @@
__version__ = "0.16.26-dev2" # pragma: no cover
__version__ = "0.16.26-dev3" # pragma: no cover

View File

@ -54,7 +54,6 @@ from unstructured.partition.common.common import (
from unstructured.partition.common.lang import (
check_language_args,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)
from unstructured.partition.common.metadata import get_last_modified_date
from unstructured.partition.pdf_image.analysis.layout_dump import (
@ -88,7 +87,7 @@ from unstructured.partition.strategies import determine_pdf_or_image_strategy, v
from unstructured.partition.text import element_from_text
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import (
OCR_AGENT_PADDLE,
OCR_AGENT_TESSERACT,
SORT_MODE_BASIC,
SORT_MODE_DONT,
SORT_MODE_XY_CUT,
@ -273,6 +272,8 @@ def partition_pdf_or_image(
pdfminer_char_margin: Optional[float] = None,
pdfminer_line_overlap: Optional[float] = None,
pdfminer_word_margin: Optional[float] = 0.185,
ocr_agent: str = OCR_AGENT_TESSERACT,
table_ocr_agent: str = OCR_AGENT_TESSERACT,
**kwargs: Any,
) -> list[Element]:
"""Parses a pdf or image document into a list of interpreted elements."""
@ -332,8 +333,6 @@ def partition_pdf_or_image(
file.seek(0)
ocr_languages = prepare_languages_for_tesseract(languages)
if env_config.OCR_AGENT == OCR_AGENT_PADDLE:
ocr_languages = tesseract_to_paddle_language(ocr_languages)
if strategy == PartitionStrategy.HI_RES:
# NOTE(robinson): Catches a UserWarning that occurs when detection is called
@ -359,6 +358,8 @@ def partition_pdf_or_image(
form_extraction_skip_tables=form_extraction_skip_tables,
password=password,
pdfminer_config=pdfminer_config,
ocr_agent=ocr_agent,
table_ocr_agent=table_ocr_agent,
**kwargs,
)
out_elements = _process_uncategorized_text_elements(elements)
@ -609,6 +610,8 @@ def _partition_pdf_or_image_local(
pdf_hi_res_max_pages: Optional[int] = None,
password: Optional[str] = None,
pdfminer_config: Optional[PDFMinerConfig] = None,
ocr_agent: str = OCR_AGENT_TESSERACT,
table_ocr_agent: str = OCR_AGENT_TESSERACT,
**kwargs: Any,
) -> list[Element]:
"""Partition using package installed locally"""
@ -690,11 +693,13 @@ def _partition_pdf_or_image_local(
extracted_layout=extracted_layout,
is_image=is_image,
infer_table_structure=infer_table_structure,
ocr_agent=ocr_agent,
ocr_languages=ocr_languages,
ocr_mode=ocr_mode,
pdf_image_dpi=pdf_image_dpi,
ocr_layout_dumper=ocr_layout_dumper,
password=password,
table_ocr_agent=table_ocr_agent,
)
else:
inferred_document_layout = process_data_with_model(
@ -749,11 +754,13 @@ def _partition_pdf_or_image_local(
extracted_layout=extracted_layout,
is_image=is_image,
infer_table_structure=infer_table_structure,
ocr_agent=ocr_agent,
ocr_languages=ocr_languages,
ocr_mode=ocr_mode,
pdf_image_dpi=pdf_image_dpi,
ocr_layout_dumper=ocr_layout_dumper,
password=password,
table_ocr_agent=table_ocr_agent,
)
# vectorization of the data structure ends here

View File

@ -14,6 +14,7 @@ from PIL import ImageSequence
from unstructured.documents.elements import ElementType
from unstructured.metrics.table.table_formats import SimpleTableCell
from unstructured.partition.common.lang import tesseract_to_paddle_language
from unstructured.partition.pdf_image.analysis.layout_dump import OCRLayoutDumper
from unstructured.partition.pdf_image.pdf_image_utils import valid_text
from unstructured.partition.pdf_image.pdfminer_processing import (
@ -21,7 +22,7 @@ from unstructured.partition.pdf_image.pdfminer_processing import (
bboxes1_is_almost_subregion_of_bboxes2,
)
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import OCRMode
from unstructured.partition.utils.constants import OCR_AGENT_PADDLE, OCR_AGENT_TESSERACT, OCRMode
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies
@ -38,11 +39,13 @@ def process_data_with_ocr(
extracted_layout: List[List["TextRegion"]],
is_image: bool = False,
infer_table_structure: bool = False,
ocr_agent: str = OCR_AGENT_TESSERACT,
ocr_languages: str = "eng",
ocr_mode: str = OCRMode.FULL_PAGE.value,
pdf_image_dpi: int = 200,
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
password: Optional[str] = None,
table_ocr_agent: str = OCR_AGENT_TESSERACT,
) -> "DocumentLayout":
"""
Process OCR data from a given data and supplement the output DocumentLayout
@ -86,11 +89,13 @@ def process_data_with_ocr(
extracted_layout=extracted_layout,
is_image=is_image,
infer_table_structure=infer_table_structure,
ocr_agent=ocr_agent,
ocr_languages=ocr_languages,
ocr_mode=ocr_mode,
pdf_image_dpi=pdf_image_dpi,
ocr_layout_dumper=ocr_layout_dumper,
password=password,
table_ocr_agent=table_ocr_agent,
)
return merged_layouts
@ -103,11 +108,13 @@ def process_file_with_ocr(
extracted_layout: List[TextRegions],
is_image: bool = False,
infer_table_structure: bool = False,
ocr_agent: str = OCR_AGENT_TESSERACT,
ocr_languages: str = "eng",
ocr_mode: str = OCRMode.FULL_PAGE.value,
pdf_image_dpi: int = 200,
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
password: Optional[str] = None,
table_ocr_agent: str = OCR_AGENT_TESSERACT,
) -> "DocumentLayout":
"""
Process OCR data from a given file and supplement the output DocumentLayout
@ -154,10 +161,12 @@ def process_file_with_ocr(
page_layout=out_layout.pages[i],
image=image,
infer_table_structure=infer_table_structure,
ocr_agent=ocr_agent,
ocr_languages=ocr_languages,
ocr_mode=ocr_mode,
extracted_regions=extracted_regions,
ocr_layout_dumper=ocr_layout_dumper,
table_ocr_agent=table_ocr_agent,
)
merged_page_layouts.append(merged_page_layout)
return DocumentLayout.from_pages(merged_page_layouts)
@ -178,10 +187,12 @@ def process_file_with_ocr(
page_layout=out_layout.pages[i],
image=image,
infer_table_structure=infer_table_structure,
ocr_agent=ocr_agent,
ocr_languages=ocr_languages,
ocr_mode=ocr_mode,
extracted_regions=extracted_regions,
ocr_layout_dumper=ocr_layout_dumper,
table_ocr_agent=table_ocr_agent,
)
merged_page_layouts.append(merged_page_layout)
return DocumentLayout.from_pages(merged_page_layouts)
@ -197,10 +208,12 @@ def supplement_page_layout_with_ocr(
page_layout: "PageLayout",
image: PILImage.Image,
infer_table_structure: bool = False,
ocr_agent: str = OCR_AGENT_TESSERACT,
ocr_languages: str = "eng",
ocr_mode: str = OCRMode.FULL_PAGE.value,
extracted_regions: Optional[TextRegions] = None,
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
table_ocr_agent: str = OCR_AGENT_TESSERACT,
) -> "PageLayout":
"""
Supplement an PageLayout with OCR results depending on OCR mode.
@ -210,9 +223,12 @@ def supplement_page_layout_with_ocr(
with no text and add text from OCR to each element.
"""
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
language = ocr_languages
if ocr_agent == OCR_AGENT_PADDLE:
language = tesseract_to_paddle_language(ocr_languages)
_ocr_agent = OCRAgent.get_instance(ocr_agent_module=ocr_agent, language=language)
if ocr_mode == OCRMode.FULL_PAGE.value:
ocr_layout = ocr_agent.get_layout_from_image(image)
ocr_layout = _ocr_agent.get_layout_from_image(image)
if ocr_layout_dumper:
ocr_layout_dumper.add_ocred_page(ocr_layout.as_list())
page_layout.elements_array = merge_out_layout_with_ocr_layout(
@ -236,7 +252,7 @@ def supplement_page_layout_with_ocr(
)
# Note(yuming): instead of getting OCR layout, we just need
# the text extracted from OCR for individual elements
text_from_ocr = ocr_agent.get_text_from_image(cropped_image)
text_from_ocr = _ocr_agent.get_text_from_image(cropped_image)
page_layout.elements_array.texts[i] = text_from_ocr
else:
raise ValueError(
@ -246,6 +262,12 @@ def supplement_page_layout_with_ocr(
# Note(yuming): use the OCR data from entire page OCR for table extraction
if infer_table_structure:
language = ocr_languages
if table_ocr_agent == OCR_AGENT_PADDLE:
language = tesseract_to_paddle_language(ocr_languages)
_table_ocr_agent = OCRAgent.get_instance(
ocr_agent_module=table_ocr_agent, language=language
)
from unstructured_inference.models import tables
tables.load_agent()
@ -256,7 +278,7 @@ def supplement_page_layout_with_ocr(
elements=page_layout.elements_array,
image=image,
tables_agent=tables.tables_agent,
ocr_agent=ocr_agent,
ocr_agent=_table_ocr_agent,
extracted_regions=extracted_regions,
)
page_layout.elements = page_layout.elements_array.as_list()