mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-06-27 02:30:08 +00:00
feat: allow passing down of ocr agent and table agent (#3954)
This PR allows passing down both `ocr_agent` and `table_ocr_agent` as parameters to specify the `OCRAgent` class for the page and tables, if any, respectively. Both are default to using `tesseract`, consistent with the present default behavior. We used to rely on env variables to specify the agents but os env can be changed during runtime outside of the caller's control. This method of passing down the variables ensures that specification is independent of env changes. ## testing Using `example-docs/img/layout-parser-paper-with-table.jpg` and run partition with two different settings. Note that this test requires `paddleocr` extra. ```python from unstructured.partition.auto import partition from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCR_AGENT_PADDLE elements = partition(f, strategy="hi_res", skip_infer_table_types=[], ocr_agent=OCR_AGENT_TESSERACT, table_ocr_agent=OCR_AGENT_PADDLE) elements_alt = partition(f, strategy="hi_res", skip_infer_table_types=[], ocr_agent=OCR_AGENT_PADDLE, table_ocr_agent=OCR_AGENT_TESSERACT) ``` we should see both finish and slight differences in the table element's text attribute.
This commit is contained in:
parent
0001a33dba
commit
8759b0aac9
1
.github/workflows/ci.yml
vendored
1
.github/workflows/ci.yml
vendored
@ -218,6 +218,7 @@ jobs:
|
||||
sudo apt-get install -y tesseract-ocr tesseract-ocr-kor
|
||||
tesseract --version
|
||||
make install-${{ matrix.extra }}
|
||||
[[ ${{ matrix.extra }} == "pdf-image" ]] && make install-paddleocr
|
||||
make test-extra-${{ matrix.extra }} CI=true
|
||||
|
||||
setup_ingest:
|
||||
|
@ -1,8 +1,9 @@
|
||||
## 0.16.26-dev2
|
||||
## 0.16.26-dev3
|
||||
|
||||
### Enhancements
|
||||
|
||||
- **Add support for images in html partitioner** `<img>` tags will now be parsed as `Image` elements. When `extract_image_block_types` includes `Image` and `extract_image_block_to_payload`=True then the `image_base64` will be included for images that specify the base64 data (rather than url) as the source.
|
||||
- **Use kwargs instead of env to specify `ocr_agent` and `table_ocr_agent`** for `hi_res` strategy.
|
||||
|
||||
### Features
|
||||
|
||||
|
6
Makefile
6
Makefile
@ -22,7 +22,7 @@ install-base: install-base-pip-packages install-nltk-models
|
||||
install: install-base-pip-packages install-dev install-nltk-models install-test install-huggingface install-all-docs
|
||||
|
||||
.PHONY: install-ci
|
||||
install-ci: install-base-pip-packages install-nltk-models install-huggingface install-all-docs install-test install-pandoc
|
||||
install-ci: install-base-pip-packages install-nltk-models install-huggingface install-all-docs install-test install-pandoc install-paddleocr
|
||||
|
||||
.PHONY: install-base-ci
|
||||
install-base-ci: install-base-pip-packages install-nltk-models install-test install-pandoc
|
||||
@ -80,6 +80,10 @@ install-odt:
|
||||
install-pypandoc:
|
||||
${PYTHON} -m pip install -r requirements/extra-pandoc.txt
|
||||
|
||||
.PHONY: install-paddleocr
|
||||
install-paddleocr:
|
||||
${PYTHON} -m pip install -r requirements/extra-paddleocr.txt
|
||||
|
||||
.PHONY: install-markdown
|
||||
install-markdown:
|
||||
${PYTHON} -m pip install -r requirements/extra-markdown.txt
|
||||
|
@ -1,6 +1,6 @@
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -10,7 +10,7 @@ from bs4 import BeautifulSoup, Tag
|
||||
from pdf2image.exceptions import PDFPageCountError
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from unstructured_inference.inference.elements import EmbeddedTextRegion, TextRegion, TextRegions
|
||||
from unstructured_inference.inference.layout import DocumentLayout
|
||||
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
|
||||
from unstructured_inference.inference.layoutelement import (
|
||||
LayoutElement,
|
||||
LayoutElements,
|
||||
@ -25,6 +25,8 @@ from unstructured.partition.pdf_image.pdf_image_utils import (
|
||||
)
|
||||
from unstructured.partition.utils.config import env_config
|
||||
from unstructured.partition.utils.constants import (
|
||||
OCR_AGENT_PADDLE,
|
||||
OCR_AGENT_TESSERACT,
|
||||
Source,
|
||||
)
|
||||
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
|
||||
@ -66,12 +68,10 @@ def test_process_file_with_ocr_invalid_filename(is_image):
|
||||
)
|
||||
|
||||
|
||||
def test_supplement_page_layout_with_ocr_invalid_ocr(monkeypatch):
|
||||
monkeypatch.setenv("OCR_AGENT", "invalid_ocr")
|
||||
def test_supplement_page_layout_with_ocr_invalid_ocr():
|
||||
with pytest.raises(ValueError):
|
||||
_ = ocr.supplement_page_layout_with_ocr(
|
||||
page_layout=None,
|
||||
image=None,
|
||||
page_layout=None, image=None, ocr_agent="invliad_ocr"
|
||||
)
|
||||
|
||||
|
||||
@ -610,3 +610,53 @@ def test_hocr_to_dataframe_when_no_prediction_empty_df():
|
||||
assert "width" in df.columns
|
||||
assert "text" in df.columns
|
||||
assert "text" in df.columns
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_page(mock_ocr_layout, mock_layout):
|
||||
mock_page = MagicMock(PageLayout)
|
||||
mock_page.elements_array = mock_layout
|
||||
return mock_page
|
||||
|
||||
|
||||
def test_supplement_layout_with_ocr(mocker, mock_page):
|
||||
from unstructured.partition.pdf_image.ocr import OCRAgent
|
||||
|
||||
mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
|
||||
spy = mocker.spy(OCRAgent, "get_instance")
|
||||
|
||||
ocr.supplement_page_layout_with_ocr(
|
||||
mock_page,
|
||||
Image.new("RGB", (100, 100)),
|
||||
infer_table_structure=True,
|
||||
ocr_agent=OCR_AGENT_TESSERACT,
|
||||
ocr_languages="eng",
|
||||
table_ocr_agent=OCR_AGENT_PADDLE,
|
||||
)
|
||||
|
||||
assert spy.call_args_list[0][1] == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
|
||||
assert spy.call_args_list[1][1] == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
|
||||
|
||||
|
||||
def test_pass_down_agents(mocker, mock_page):
|
||||
from unstructured.partition.pdf_image.ocr import OCRAgent, PILImage
|
||||
|
||||
mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
|
||||
mocker.patch.object(PILImage, "open", return_value=Image.new("RGB", (100, 100)))
|
||||
spy = mocker.spy(OCRAgent, "get_instance")
|
||||
doc = MagicMock(DocumentLayout)
|
||||
doc.pages = [mock_page]
|
||||
|
||||
ocr.process_file_with_ocr(
|
||||
"foo",
|
||||
doc,
|
||||
[],
|
||||
infer_table_structure=True,
|
||||
is_image=True,
|
||||
ocr_agent=OCR_AGENT_PADDLE,
|
||||
ocr_languages="eng",
|
||||
table_ocr_agent=OCR_AGENT_TESSERACT,
|
||||
)
|
||||
|
||||
assert spy.call_args_list[0][1] == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
|
||||
assert spy.call_args_list[1][1] == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
|
||||
|
@ -39,6 +39,8 @@ from unstructured.partition import pdf, strategies
|
||||
from unstructured.partition.pdf_image import ocr, pdfminer_processing
|
||||
from unstructured.partition.pdf_image.pdfminer_processing import get_uris_from_annots
|
||||
from unstructured.partition.utils.constants import (
|
||||
OCR_AGENT_PADDLE,
|
||||
OCR_AGENT_TESSERACT,
|
||||
SORT_MODE_BASIC,
|
||||
SORT_MODE_DONT,
|
||||
SORT_MODE_XY_CUT,
|
||||
@ -1585,3 +1587,20 @@ def test_partition_pdf_with_password(
|
||||
file=spooled_temp_file, strategy=strategy, password="password"
|
||||
)
|
||||
_test(result)
|
||||
|
||||
|
||||
def test_partition_pdf_with_specified_ocr_agents(mocker):
|
||||
from unstructured.partition.pdf_image.ocr import OCRAgent
|
||||
|
||||
spy = mocker.spy(OCRAgent, "get_instance")
|
||||
|
||||
pdf.partition_pdf(
|
||||
filename=example_doc_path("pdf/layout-parser-paper-with-table.pdf"),
|
||||
strategy=PartitionStrategy.HI_RES,
|
||||
infer_table_structure=True,
|
||||
ocr_agent=OCR_AGENT_TESSERACT,
|
||||
table_ocr_agent=OCR_AGENT_PADDLE,
|
||||
)
|
||||
|
||||
assert spy.call_args_list[0][1] == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
|
||||
assert spy.call_args_list[1][1] == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.16.26-dev2" # pragma: no cover
|
||||
__version__ = "0.16.26-dev3" # pragma: no cover
|
||||
|
@ -54,7 +54,6 @@ from unstructured.partition.common.common import (
|
||||
from unstructured.partition.common.lang import (
|
||||
check_language_args,
|
||||
prepare_languages_for_tesseract,
|
||||
tesseract_to_paddle_language,
|
||||
)
|
||||
from unstructured.partition.common.metadata import get_last_modified_date
|
||||
from unstructured.partition.pdf_image.analysis.layout_dump import (
|
||||
@ -88,7 +87,7 @@ from unstructured.partition.strategies import determine_pdf_or_image_strategy, v
|
||||
from unstructured.partition.text import element_from_text
|
||||
from unstructured.partition.utils.config import env_config
|
||||
from unstructured.partition.utils.constants import (
|
||||
OCR_AGENT_PADDLE,
|
||||
OCR_AGENT_TESSERACT,
|
||||
SORT_MODE_BASIC,
|
||||
SORT_MODE_DONT,
|
||||
SORT_MODE_XY_CUT,
|
||||
@ -273,6 +272,8 @@ def partition_pdf_or_image(
|
||||
pdfminer_char_margin: Optional[float] = None,
|
||||
pdfminer_line_overlap: Optional[float] = None,
|
||||
pdfminer_word_margin: Optional[float] = 0.185,
|
||||
ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
table_ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
**kwargs: Any,
|
||||
) -> list[Element]:
|
||||
"""Parses a pdf or image document into a list of interpreted elements."""
|
||||
@ -332,8 +333,6 @@ def partition_pdf_or_image(
|
||||
file.seek(0)
|
||||
|
||||
ocr_languages = prepare_languages_for_tesseract(languages)
|
||||
if env_config.OCR_AGENT == OCR_AGENT_PADDLE:
|
||||
ocr_languages = tesseract_to_paddle_language(ocr_languages)
|
||||
|
||||
if strategy == PartitionStrategy.HI_RES:
|
||||
# NOTE(robinson): Catches a UserWarning that occurs when detection is called
|
||||
@ -359,6 +358,8 @@ def partition_pdf_or_image(
|
||||
form_extraction_skip_tables=form_extraction_skip_tables,
|
||||
password=password,
|
||||
pdfminer_config=pdfminer_config,
|
||||
ocr_agent=ocr_agent,
|
||||
table_ocr_agent=table_ocr_agent,
|
||||
**kwargs,
|
||||
)
|
||||
out_elements = _process_uncategorized_text_elements(elements)
|
||||
@ -609,6 +610,8 @@ def _partition_pdf_or_image_local(
|
||||
pdf_hi_res_max_pages: Optional[int] = None,
|
||||
password: Optional[str] = None,
|
||||
pdfminer_config: Optional[PDFMinerConfig] = None,
|
||||
ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
table_ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
**kwargs: Any,
|
||||
) -> list[Element]:
|
||||
"""Partition using package installed locally"""
|
||||
@ -690,11 +693,13 @@ def _partition_pdf_or_image_local(
|
||||
extracted_layout=extracted_layout,
|
||||
is_image=is_image,
|
||||
infer_table_structure=infer_table_structure,
|
||||
ocr_agent=ocr_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_mode=ocr_mode,
|
||||
pdf_image_dpi=pdf_image_dpi,
|
||||
ocr_layout_dumper=ocr_layout_dumper,
|
||||
password=password,
|
||||
table_ocr_agent=table_ocr_agent,
|
||||
)
|
||||
else:
|
||||
inferred_document_layout = process_data_with_model(
|
||||
@ -749,11 +754,13 @@ def _partition_pdf_or_image_local(
|
||||
extracted_layout=extracted_layout,
|
||||
is_image=is_image,
|
||||
infer_table_structure=infer_table_structure,
|
||||
ocr_agent=ocr_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_mode=ocr_mode,
|
||||
pdf_image_dpi=pdf_image_dpi,
|
||||
ocr_layout_dumper=ocr_layout_dumper,
|
||||
password=password,
|
||||
table_ocr_agent=table_ocr_agent,
|
||||
)
|
||||
|
||||
# vectorization of the data structure ends here
|
||||
|
@ -14,6 +14,7 @@ from PIL import ImageSequence
|
||||
|
||||
from unstructured.documents.elements import ElementType
|
||||
from unstructured.metrics.table.table_formats import SimpleTableCell
|
||||
from unstructured.partition.common.lang import tesseract_to_paddle_language
|
||||
from unstructured.partition.pdf_image.analysis.layout_dump import OCRLayoutDumper
|
||||
from unstructured.partition.pdf_image.pdf_image_utils import valid_text
|
||||
from unstructured.partition.pdf_image.pdfminer_processing import (
|
||||
@ -21,7 +22,7 @@ from unstructured.partition.pdf_image.pdfminer_processing import (
|
||||
bboxes1_is_almost_subregion_of_bboxes2,
|
||||
)
|
||||
from unstructured.partition.utils.config import env_config
|
||||
from unstructured.partition.utils.constants import OCRMode
|
||||
from unstructured.partition.utils.constants import OCR_AGENT_PADDLE, OCR_AGENT_TESSERACT, OCRMode
|
||||
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
@ -38,11 +39,13 @@ def process_data_with_ocr(
|
||||
extracted_layout: List[List["TextRegion"]],
|
||||
is_image: bool = False,
|
||||
infer_table_structure: bool = False,
|
||||
ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
ocr_languages: str = "eng",
|
||||
ocr_mode: str = OCRMode.FULL_PAGE.value,
|
||||
pdf_image_dpi: int = 200,
|
||||
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
|
||||
password: Optional[str] = None,
|
||||
table_ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
) -> "DocumentLayout":
|
||||
"""
|
||||
Process OCR data from a given data and supplement the output DocumentLayout
|
||||
@ -86,11 +89,13 @@ def process_data_with_ocr(
|
||||
extracted_layout=extracted_layout,
|
||||
is_image=is_image,
|
||||
infer_table_structure=infer_table_structure,
|
||||
ocr_agent=ocr_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_mode=ocr_mode,
|
||||
pdf_image_dpi=pdf_image_dpi,
|
||||
ocr_layout_dumper=ocr_layout_dumper,
|
||||
password=password,
|
||||
table_ocr_agent=table_ocr_agent,
|
||||
)
|
||||
|
||||
return merged_layouts
|
||||
@ -103,11 +108,13 @@ def process_file_with_ocr(
|
||||
extracted_layout: List[TextRegions],
|
||||
is_image: bool = False,
|
||||
infer_table_structure: bool = False,
|
||||
ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
ocr_languages: str = "eng",
|
||||
ocr_mode: str = OCRMode.FULL_PAGE.value,
|
||||
pdf_image_dpi: int = 200,
|
||||
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
|
||||
password: Optional[str] = None,
|
||||
table_ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
) -> "DocumentLayout":
|
||||
"""
|
||||
Process OCR data from a given file and supplement the output DocumentLayout
|
||||
@ -154,10 +161,12 @@ def process_file_with_ocr(
|
||||
page_layout=out_layout.pages[i],
|
||||
image=image,
|
||||
infer_table_structure=infer_table_structure,
|
||||
ocr_agent=ocr_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_mode=ocr_mode,
|
||||
extracted_regions=extracted_regions,
|
||||
ocr_layout_dumper=ocr_layout_dumper,
|
||||
table_ocr_agent=table_ocr_agent,
|
||||
)
|
||||
merged_page_layouts.append(merged_page_layout)
|
||||
return DocumentLayout.from_pages(merged_page_layouts)
|
||||
@ -178,10 +187,12 @@ def process_file_with_ocr(
|
||||
page_layout=out_layout.pages[i],
|
||||
image=image,
|
||||
infer_table_structure=infer_table_structure,
|
||||
ocr_agent=ocr_agent,
|
||||
ocr_languages=ocr_languages,
|
||||
ocr_mode=ocr_mode,
|
||||
extracted_regions=extracted_regions,
|
||||
ocr_layout_dumper=ocr_layout_dumper,
|
||||
table_ocr_agent=table_ocr_agent,
|
||||
)
|
||||
merged_page_layouts.append(merged_page_layout)
|
||||
return DocumentLayout.from_pages(merged_page_layouts)
|
||||
@ -197,10 +208,12 @@ def supplement_page_layout_with_ocr(
|
||||
page_layout: "PageLayout",
|
||||
image: PILImage.Image,
|
||||
infer_table_structure: bool = False,
|
||||
ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
ocr_languages: str = "eng",
|
||||
ocr_mode: str = OCRMode.FULL_PAGE.value,
|
||||
extracted_regions: Optional[TextRegions] = None,
|
||||
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
|
||||
table_ocr_agent: str = OCR_AGENT_TESSERACT,
|
||||
) -> "PageLayout":
|
||||
"""
|
||||
Supplement an PageLayout with OCR results depending on OCR mode.
|
||||
@ -210,9 +223,12 @@ def supplement_page_layout_with_ocr(
|
||||
with no text and add text from OCR to each element.
|
||||
"""
|
||||
|
||||
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
|
||||
language = ocr_languages
|
||||
if ocr_agent == OCR_AGENT_PADDLE:
|
||||
language = tesseract_to_paddle_language(ocr_languages)
|
||||
_ocr_agent = OCRAgent.get_instance(ocr_agent_module=ocr_agent, language=language)
|
||||
if ocr_mode == OCRMode.FULL_PAGE.value:
|
||||
ocr_layout = ocr_agent.get_layout_from_image(image)
|
||||
ocr_layout = _ocr_agent.get_layout_from_image(image)
|
||||
if ocr_layout_dumper:
|
||||
ocr_layout_dumper.add_ocred_page(ocr_layout.as_list())
|
||||
page_layout.elements_array = merge_out_layout_with_ocr_layout(
|
||||
@ -236,7 +252,7 @@ def supplement_page_layout_with_ocr(
|
||||
)
|
||||
# Note(yuming): instead of getting OCR layout, we just need
|
||||
# the text extraced from OCR for individual elements
|
||||
text_from_ocr = ocr_agent.get_text_from_image(cropped_image)
|
||||
text_from_ocr = _ocr_agent.get_text_from_image(cropped_image)
|
||||
page_layout.elements_array.texts[i] = text_from_ocr
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -246,6 +262,12 @@ def supplement_page_layout_with_ocr(
|
||||
|
||||
# Note(yuming): use the OCR data from entire page OCR for table extraction
|
||||
if infer_table_structure:
|
||||
language = ocr_languages
|
||||
if table_ocr_agent == OCR_AGENT_PADDLE:
|
||||
language = tesseract_to_paddle_language(ocr_languages)
|
||||
_table_ocr_agent = OCRAgent.get_instance(
|
||||
ocr_agent_module=table_ocr_agent, language=language
|
||||
)
|
||||
from unstructured_inference.models import tables
|
||||
|
||||
tables.load_agent()
|
||||
@ -256,7 +278,7 @@ def supplement_page_layout_with_ocr(
|
||||
elements=page_layout.elements_array,
|
||||
image=image,
|
||||
tables_agent=tables.tables_agent,
|
||||
ocr_agent=ocr_agent,
|
||||
ocr_agent=_table_ocr_agent,
|
||||
extracted_regions=extracted_regions,
|
||||
)
|
||||
page_layout.elements = page_layout.elements_array.as_list()
|
||||
|
Loading…
x
Reference in New Issue
Block a user