chore: add hi_res_model_name kwarg (#2289)

Closes #2160 

Explicitly adds `hi_res_model_name` as a kwarg to the relevant functions and
notes that `model_name` is to be deprecated in its favor.

Testing:
```
from unstructured.partition.auto import partition
filename = "example-docs/DA-1p.pdf"
elements = partition(filename, strategy="hi_res", hi_res_model_name="yolox")
```

---------

Co-authored-by: cragwolfe <crag@unstructured.io>
Co-authored-by: Steve Canny <stcanny@gmail.com>
Co-authored-by: Christine Straub <christinemstraub@gmail.com>
Co-authored-by: Yao You <yao@unstructured.io>
Co-authored-by: Yao You <theyaoyou@gmail.com>
This commit is contained in:
John 2023-12-22 09:06:54 -06:00 committed by GitHub
parent 093a11d058
commit 5c0043aa7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 109 additions and 10 deletions

View File

@ -24,6 +24,8 @@
### Fixes ### Fixes
* **Enable --fields argument omission for elasticsearch connector** Solves two bugs where removing the optional parameter --fields broke the connector due to an integer processing error and using an elasticsearch config for a destination connector resulted in a serialization issue when optional parameter --fields was not provided. * **Enable --fields argument omission for elasticsearch connector** Solves two bugs where removing the optional parameter --fields broke the connector due to an integer processing error and using an elasticsearch config for a destination connector resulted in a serialization issue when optional parameter --fields was not provided.
* **Add hi_res_model_name** Adds the `hi_res_model_name` kwarg to relevant functions and adds comments noting that `model_name` is to be deprecated.
## 0.11.5 ## 0.11.5
### Enhancements ### Enhancements

View File

@ -536,6 +536,18 @@ def test_partition_image_uses_model_name():
assert mockpartition.call_args.kwargs["model_name"] assert mockpartition.call_args.kwargs["model_name"]
def test_partition_image_uses_hi_res_model_name():
    """`partition_image` forwards `hi_res_model_name` under its own keyword.

    `_partition_pdf_or_image_local` is patched on the `pdf` module so no model is run; the
    test only inspects the keyword arguments it was called with. The to-be-deprecated
    `model_name` keyword must not appear when the caller supplied only `hi_res_model_name`.
    """
    with mock.patch.object(
        pdf,
        "_partition_pdf_or_image_local",
    ) as mockpartition:
        image.partition_image("example-docs/layout-parser-paper-fast.jpg", hi_res_model_name="test")

    forwarded_kwargs = mockpartition.call_args.kwargs
    assert "model_name" not in forwarded_kwargs
    assert "hi_res_model_name" in forwarded_kwargs
    assert forwarded_kwargs["hi_res_model_name"] == "test"
@pytest.mark.parametrize( @pytest.mark.parametrize(
("ocr_mode", "idx_title_element"), ("ocr_mode", "idx_title_element"),
[ [

View File

@ -215,6 +215,40 @@ def test_partition_pdf_with_model_name(
assert mock_process.call_args[1]["model_name"] == "checkbox" assert mock_process.call_args[1]["model_name"] == "checkbox"
def test_partition_pdf_with_hi_res_model_name(
    monkeypatch,
    filename=example_doc_path("layout-parser-paper-fast.pdf"),
):
    """`partition_pdf(hi_res_model_name=...)` reaches the layout model as `model_name`."""
    # Stub out text extraction so only the layout-model call is exercised.
    monkeypatch.setattr(pdf, "extractable_elements", lambda *args, **kwargs: [])
    with mock.patch.object(
        layout,
        "process_file_with_model",
        mock.MagicMock(),
    ) as mock_process:
        pdf.partition_pdf(
            filename=filename, strategy=PartitionStrategy.HI_RES, hi_res_model_name="checkbox"
        )
        # The partitioner calls `process_file_with_model` with the value under the
        # `model_name` keyword, not `hi_res_model_name`, so that is the key checked here.
        assert mock_process.call_args[1]["model_name"] == "checkbox"
def test_partition_pdf_or_image_with_hi_res_model_name(
    monkeypatch,
    filename=example_doc_path("layout-parser-paper-fast.pdf"),
):
    """`partition_pdf_or_image(hi_res_model_name=...)` reaches the layout model as `model_name`."""
    # Stub out text extraction so only the layout-model call is exercised.
    monkeypatch.setattr(pdf, "extractable_elements", lambda *args, **kwargs: [])
    with mock.patch.object(
        layout,
        "process_file_with_model",
        mock.MagicMock(),
    ) as mock_process:
        pdf.partition_pdf_or_image(
            filename=filename, strategy=PartitionStrategy.HI_RES, hi_res_model_name="checkbox"
        )
        # The partitioner calls `process_file_with_model` with the value under the
        # `model_name` keyword, not `hi_res_model_name`, so that is the key checked here.
        assert mock_process.call_args[1]["model_name"] == "checkbox"
def test_partition_pdf_with_auto_strategy( def test_partition_pdf_with_auto_strategy(
filename=example_doc_path("layout-parser-paper-fast.pdf"), filename=example_doc_path("layout-parser-paper-fast.pdf"),
): ):
@ -798,6 +832,22 @@ def test_partition_pdf_uses_model_name():
assert mockpartition.call_args.kwargs["model_name"] assert mockpartition.call_args.kwargs["model_name"]
def test_partition_pdf_uses_hi_res_model_name():
    """`partition_pdf` forwards `hi_res_model_name` unchanged to the local partitioner.

    `_partition_pdf_or_image_local` is patched on the `pdf` module, so the test inspects only
    the keyword arguments of the single expected call.
    """
    with mock.patch.object(
        pdf,
        "_partition_pdf_or_image_local",
    ) as mockpartition:
        pdf.partition_pdf(
            example_doc_path("layout-parser-paper-fast.pdf"),
            hi_res_model_name="test",
            strategy=PartitionStrategy.HI_RES,
        )

    mockpartition.assert_called_once()
    forwarded_kwargs = mockpartition.call_args.kwargs
    assert "hi_res_model_name" in forwarded_kwargs
    # Pin the exact value rather than mere truthiness so a mangled name would be caught.
    assert forwarded_kwargs["hi_res_model_name"] == "test"
def test_partition_pdf_word_bbox_not_char( def test_partition_pdf_word_bbox_not_char(
filename=example_doc_path("interface-config-guide-p93.pdf"), filename=example_doc_path("interface-config-guide-p93.pdf"),
): ):
@ -863,6 +913,18 @@ def test_partition_model_name_default_to_None():
pytest.fail("partition_pdf() raised AttributeError unexpectedly!") pytest.fail("partition_pdf() raised AttributeError unexpectedly!")
def test_partition_hi_res_model_name_default_to_None():
    """Passing `hi_res_model_name=None` explicitly must not raise `AttributeError`."""
    partition_kwargs = {
        "filename": example_doc_path("DA-1p.pdf"),
        "strategy": PartitionStrategy.HI_RES,
        "hi_res_model_name": None,
    }
    try:
        pdf.partition_pdf(**partition_kwargs)
    except AttributeError:
        pytest.fail("partition_pdf() raised AttributeError unexpectedly!")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("strategy", "ocr_func"), ("strategy", "ocr_func"),
[ [

View File

@ -356,6 +356,7 @@ def test_auto_partition_pdf_with_fast_strategy(monkeypatch):
image_output_dir_path=ANY, image_output_dir_path=ANY,
strategy=PartitionStrategy.FAST, strategy=PartitionStrategy.FAST,
languages=None, languages=None,
hi_res_model_name=None,
) )

View File

@ -142,6 +142,8 @@ def partition(
data_source_metadata: Optional[DataSourceMetadata] = None, data_source_metadata: Optional[DataSourceMetadata] = None,
metadata_filename: Optional[str] = None, metadata_filename: Optional[str] = None,
request_timeout: Optional[int] = None, request_timeout: Optional[int] = None,
hi_res_model_name: Optional[str] = None,
model_name: Optional[str] = None, # to be deprecated
**kwargs, **kwargs,
): ):
"""Partitions a document into its constituent elements. Will use libmagic to determine """Partitions a document into its constituent elements. Will use libmagic to determine
@ -202,6 +204,11 @@ def partition(
request_timeout request_timeout
The timeout for the HTTP request if URL is set. Defaults to None meaning no timeout and The timeout for the HTTP request if URL is set. Defaults to None meaning no timeout and
requests will block indefinitely. requests will block indefinitely.
hi_res_model_name
The layout detection model used when partitioning strategy is set to `hi_res`.
model_name
The layout detection model used when partitioning strategy is set to `hi_res`. To be
deprecated in favor of `hi_res_model_name`.
""" """
exactly_one(file=file, filename=filename, url=url) exactly_one(file=file, filename=filename, url=url)
@ -391,6 +398,7 @@ def partition(
languages=languages, languages=languages,
extract_images_in_pdf=pdf_extract_images, extract_images_in_pdf=pdf_extract_images,
image_output_dir_path=pdf_image_output_dir_path, image_output_dir_path=pdf_image_output_dir_path,
hi_res_model_name=hi_res_model_name or model_name,
**kwargs, **kwargs,
) )
elif (filetype == FileType.PNG) or (filetype == FileType.JPG) or (filetype == FileType.TIFF): elif (filetype == FileType.PNG) or (filetype == FileType.JPG) or (filetype == FileType.TIFF):
@ -402,6 +410,7 @@ def partition(
infer_table_structure=infer_table_structure, infer_table_structure=infer_table_structure,
strategy=strategy, strategy=strategy,
languages=languages, languages=languages,
hi_res_model_name=hi_res_model_name or model_name,
**kwargs, **kwargs,
) )
elif filetype == FileType.TXT: elif filetype == FileType.TXT:

View File

@ -25,6 +25,7 @@ def partition_image(
strategy: str = PartitionStrategy.HI_RES, strategy: str = PartitionStrategy.HI_RES,
metadata_last_modified: Optional[str] = None, metadata_last_modified: Optional[str] = None,
chunking_strategy: Optional[str] = None, chunking_strategy: Optional[str] = None,
hi_res_model_name: Optional[str] = None,
**kwargs, **kwargs,
) -> List[Element]: ) -> List[Element]:
"""Parses an image into a list of interpreted elements. """Parses an image into a list of interpreted elements.
@ -55,6 +56,8 @@ def partition_image(
The default strategy is `hi_res`. The default strategy is `hi_res`.
metadata_last_modified metadata_last_modified
The last modified date for the document. The last modified date for the document.
hi_res_model_name
The layout detection model used when partitioning strategy is set to `hi_res`.
""" """
exactly_one(filename=filename, file=file) exactly_one(filename=filename, file=file)
@ -89,5 +92,6 @@ def partition_image(
languages=languages, languages=languages,
strategy=strategy, strategy=strategy,
metadata_last_modified=metadata_last_modified, metadata_last_modified=metadata_last_modified,
hi_res_model_name=hi_res_model_name,
**kwargs, **kwargs,
) )

View File

@ -143,6 +143,7 @@ def partition_pdf(
extract_images_in_pdf: bool = False, extract_images_in_pdf: bool = False,
extract_element_types: Optional[List[str]] = None, extract_element_types: Optional[List[str]] = None,
image_output_dir_path: Optional[str] = None, image_output_dir_path: Optional[str] = None,
hi_res_model_name: Optional[str] = None,
**kwargs, **kwargs,
) -> List[Element]: ) -> List[Element]:
"""Parses a pdf document into a list of interpreted elements. """Parses a pdf document into a list of interpreted elements.
@ -182,6 +183,8 @@ def partition_pdf(
image_output_dir_path image_output_dir_path
Only applicable if `strategy=hi_res`. Only applicable if `strategy=hi_res`.
The path for saving images when using `extract_images_in_pdf` or `extract_element_types`. The path for saving images when using `extract_images_in_pdf` or `extract_element_types`.
hi_res_model_name
The layout detection model used when partitioning strategy is set to `hi_res`.
""" """
exactly_one(filename=filename, file=file) exactly_one(filename=filename, file=file)
@ -199,6 +202,7 @@ def partition_pdf(
extract_images_in_pdf=extract_images_in_pdf, extract_images_in_pdf=extract_images_in_pdf,
extract_element_types=extract_element_types, extract_element_types=extract_element_types,
image_output_dir_path=image_output_dir_path, image_output_dir_path=image_output_dir_path,
hi_res_model_name=hi_res_model_name,
**kwargs, **kwargs,
) )
@ -244,13 +248,14 @@ def _partition_pdf_or_image_local(
include_page_breaks: bool = False, include_page_breaks: bool = False,
languages: Optional[List[str]] = None, languages: Optional[List[str]] = None,
ocr_mode: str = OCRMode.FULL_PAGE.value, ocr_mode: str = OCRMode.FULL_PAGE.value,
model_name: Optional[str] = None, model_name: Optional[str] = None, # to be deprecated in favor of `hi_res_model_name`
metadata_last_modified: Optional[str] = None, metadata_last_modified: Optional[str] = None,
pdf_text_extractable: bool = False, pdf_text_extractable: bool = False,
extract_images_in_pdf: bool = False, extract_images_in_pdf: bool = False,
extract_element_types: Optional[List[str]] = None, extract_element_types: Optional[List[str]] = None,
image_output_dir_path: Optional[str] = None, image_output_dir_path: Optional[str] = None,
pdf_image_dpi: Optional[int] = None, pdf_image_dpi: Optional[int] = None,
hi_res_model_name: Optional[str] = None,
analysis: bool = False, analysis: bool = False,
analyzed_image_output_dir_path: Optional[str] = None, analyzed_image_output_dir_path: Optional[str] = None,
**kwargs, **kwargs,
@ -275,10 +280,12 @@ def _partition_pdf_or_image_local(
ocr_languages = prepare_languages_for_tesseract(languages) ocr_languages = prepare_languages_for_tesseract(languages)
model_name = model_name or default_hi_res_model(infer_table_structure) hi_res_model_name = (
hi_res_model_name or model_name or default_hi_res_model(infer_table_structure)
)
if pdf_image_dpi is None: if pdf_image_dpi is None:
pdf_image_dpi = 300 if model_name == "chipper" else 200 pdf_image_dpi = 300 if hi_res_model_name == "chipper" else 200
if (pdf_image_dpi < 300) and (model_name == "chipper"): if (pdf_image_dpi < 300) and (hi_res_model_name == "chipper"):
logger.warning( logger.warning(
"The Chipper model performs better when images are rendered with DPI >= 300 " "The Chipper model performs better when images are rendered with DPI >= 300 "
f"(currently {pdf_image_dpi}).", f"(currently {pdf_image_dpi}).",
@ -288,7 +295,7 @@ def _partition_pdf_or_image_local(
inferred_document_layout = process_file_with_model( inferred_document_layout = process_file_with_model(
filename, filename,
is_image=is_image, is_image=is_image,
model_name=model_name, model_name=hi_res_model_name,
pdf_image_dpi=pdf_image_dpi, pdf_image_dpi=pdf_image_dpi,
) )
@ -314,7 +321,7 @@ def _partition_pdf_or_image_local(
extracted_layout=extracted_layout, extracted_layout=extracted_layout,
) )
if model_name.startswith("chipper"): if hi_res_model_name.startswith("chipper"):
# NOTE(alan): We shouldn't do OCR with chipper # NOTE(alan): We shouldn't do OCR with chipper
final_document_layout = merged_document_layout final_document_layout = merged_document_layout
else: else:
@ -331,7 +338,7 @@ def _partition_pdf_or_image_local(
inferred_document_layout = process_data_with_model( inferred_document_layout = process_data_with_model(
file, file,
is_image=is_image, is_image=is_image,
model_name=model_name, model_name=hi_res_model_name,
pdf_image_dpi=pdf_image_dpi, pdf_image_dpi=pdf_image_dpi,
) )
if hasattr(file, "seek"): if hasattr(file, "seek"):
@ -347,7 +354,7 @@ def _partition_pdf_or_image_local(
extracted_layout=extracted_layout, extracted_layout=extracted_layout,
) )
if model_name.startswith("chipper"): if hi_res_model_name.startswith("chipper"):
# NOTE(alan): We shouldn't do OCR with chipper # NOTE(alan): We shouldn't do OCR with chipper
final_document_layout = merged_document_layout final_document_layout = merged_document_layout
else: else:
@ -364,7 +371,7 @@ def _partition_pdf_or_image_local(
) )
# NOTE(alan): starting with v2, chipper sorts the elements itself. # NOTE(alan): starting with v2, chipper sorts the elements itself.
if model_name == "chipper": if hi_res_model_name == "chipper":
kwargs["sort_mode"] = SORT_MODE_DONT kwargs["sort_mode"] = SORT_MODE_DONT
final_document_layout = clean_pdfminer_inner_elements(final_document_layout) final_document_layout = clean_pdfminer_inner_elements(final_document_layout)
@ -434,7 +441,7 @@ def _partition_pdf_or_image_local(
).strip() ).strip()
# NOTE(alan): with chipper there are parent elements with no text we don't want to # NOTE(alan): with chipper there are parent elements with no text we don't want to
# filter those out and leave the children orphaned. # filter those out and leave the children orphaned.
if el.text or isinstance(el, PageBreak) or model_name.startswith("chipper"): if el.text or isinstance(el, PageBreak) or hi_res_model_name.startswith("chipper"):
out_elements.append(cast(Element, el)) out_elements.append(cast(Element, el))
return out_elements return out_elements
@ -453,6 +460,7 @@ def partition_pdf_or_image(
extract_images_in_pdf: bool = False, extract_images_in_pdf: bool = False,
extract_element_types: Optional[List[str]] = None, extract_element_types: Optional[List[str]] = None,
image_output_dir_path: Optional[str] = None, image_output_dir_path: Optional[str] = None,
hi_res_model_name: Optional[str] = None,
**kwargs, **kwargs,
) -> List[Element]: ) -> List[Element]:
"""Parses a pdf or image document into a list of interpreted elements.""" """Parses a pdf or image document into a list of interpreted elements."""
@ -514,6 +522,7 @@ def partition_pdf_or_image(
extract_images_in_pdf=extract_images_in_pdf, extract_images_in_pdf=extract_images_in_pdf,
extract_element_types=extract_element_types, extract_element_types=extract_element_types,
image_output_dir_path=image_output_dir_path, image_output_dir_path=image_output_dir_path,
hi_res_model_name=hi_res_model_name,
**kwargs, **kwargs,
) )
out_elements = _process_uncategorized_text_elements(elements) out_elements = _process_uncategorized_text_elements(elements)