feat: element types extension (#2700)

This PR adds some new element types that can be used especially by
pdf/image partition.
This commit is contained in:
Pawel Kmiecik 2024-04-04 09:49:55 +02:00 committed by GitHub
parent 1ce60f2bba
commit 63fc2a1061
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 89 additions and 42 deletions

View File

@ -4,6 +4,8 @@
### Features ### Features
* **Add a set of new `ElementType`s to extend future element types**
### Fixes ### Fixes
* **Fix `partition_html()` swallowing some paragraphs**. The `partition_html()` only considers elements with limited depth to avoid becoming the text representation of a giant div. This fix increases the limit value. * **Fix `partition_html()` swallowing some paragraphs**. The `partition_html()` only considers elements with limited depth to avoid becoming the text representation of a giant div. This fix increases the limit value.

View File

@ -18,6 +18,7 @@ from unstructured.documents.coordinates import (
) )
from unstructured.documents.elements import ( from unstructured.documents.elements import (
UUID, UUID,
CheckBox,
ConsolidationStrategy, ConsolidationStrategy,
CoordinatesMetadata, CoordinatesMetadata,
DataSourceMetadata, DataSourceMetadata,
@ -72,6 +73,14 @@ def test_text_element_apply_multiple_cleaners():
assert str(text_element) == "A Textbook on Crocodile Habitats" assert str(text_element) == "A Textbook on Crocodile Habitats"
def test_non_text_elements_are_serializable_to_text():
element = CheckBox()
assert hasattr(element, "text")
assert element.text is not None
assert element.text == ""
assert str(element) == ""
def test_apply_raises_if_func_does_not_produce_string(): def test_apply_raises_if_func_does_not_produce_string():
def bad_cleaner(s: str): def bad_cleaner(s: str):
return 1 return 1

View File

@ -14,6 +14,7 @@ from unstructured_inference.inference.layoutelement import LayoutElement
from unstructured.documents.coordinates import PixelSpace from unstructured.documents.coordinates import PixelSpace
from unstructured.documents.elements import ( from unstructured.documents.elements import (
TYPE_TO_TEXT_ELEMENT_MAP,
CheckBox, CheckBox,
CoordinatesMetadata, CoordinatesMetadata,
ElementMetadata, ElementMetadata,
@ -207,30 +208,48 @@ def test_normalize_layout_element_layout_element_narrative_text():
) )
def test_normalize_layout_element_checked_box(): @pytest.mark.parametrize(
("element_type", "expected_element_class"),
TYPE_TO_TEXT_ELEMENT_MAP.items(),
)
def test_normalize_layout_element_layout_element_maps_to_appropriate_text_element(
element_type: str,
expected_element_class: type[Text],
):
layout_element = LayoutElement.from_coords( layout_element = LayoutElement.from_coords(
type="Checked", type=element_type,
x1=1, x1=1,
y1=2, y1=2,
x2=3, x2=3,
y2=4, y2=4,
text="", text="Some lovely text",
) )
coordinate_system = PixelSpace(width=10, height=20) coordinate_system = PixelSpace(width=10, height=20)
element = common.normalize_layout_element( element = common.normalize_layout_element(
layout_element, layout_element,
coordinate_system=coordinate_system, coordinate_system=coordinate_system,
) )
assert element == CheckBox( assert element == expected_element_class(
checked=True, text="Some lovely text",
coordinates=((1, 2), (1, 4), (3, 4), (3, 2)), coordinates=((1, 2), (1, 4), (3, 4), (3, 2)),
coordinate_system=coordinate_system, coordinate_system=coordinate_system,
) )
def test_normalize_layout_element_unchecked_box(): @pytest.mark.parametrize(
("element_type", "expected_checked"),
[
(ElementType.CHECK_BOX_UNCHECKED, False),
(ElementType.CHECK_BOX_CHECKED, True),
(ElementType.RADIO_BUTTON_UNCHECKED, False),
(ElementType.RADIO_BUTTON_CHECKED, True),
(ElementType.CHECKED, True),
(ElementType.UNCHECKED, False),
],
)
def test_normalize_layout_element_checkable(element_type: str, expected_checked: bool):
layout_element = LayoutElement.from_coords( layout_element = LayoutElement.from_coords(
type="Unchecked", type=element_type,
x1=1, x1=1,
y1=2, y1=2,
x2=3, x2=3,
@ -242,8 +261,9 @@ def test_normalize_layout_element_unchecked_box():
layout_element, layout_element,
coordinate_system=coordinate_system, coordinate_system=coordinate_system,
) )
assert isinstance(element, CheckBox)
assert element == CheckBox( assert element == CheckBox(
checked=False, checked=expected_checked,
coordinates=((1, 2), (1, 4), (3, 4), (3, 2)), coordinates=((1, 2), (1, 4), (3, 4), (3, 2)),
coordinate_system=coordinate_system, coordinate_system=coordinate_system,
) )

View File

@ -597,6 +597,7 @@ class ElementType:
UNCATEGORIZED_TEXT = "UncategorizedText" UNCATEGORIZED_TEXT = "UncategorizedText"
NARRATIVE_TEXT = "NarrativeText" NARRATIVE_TEXT = "NarrativeText"
BULLETED_TEXT = "BulletedText" BULLETED_TEXT = "BulletedText"
PARAGRAPH = "Paragraph"
ABSTRACT = "Abstract" ABSTRACT = "Abstract"
THREADING = "Threading" THREADING = "Threading"
FORM = "Form" FORM = "Form"
@ -614,6 +615,10 @@ class ElementType:
LIST_ITEM_OTHER = "List-item" LIST_ITEM_OTHER = "List-item"
CHECKED = "Checked" CHECKED = "Checked"
UNCHECKED = "Unchecked" UNCHECKED = "Unchecked"
CHECK_BOX_CHECKED = "CheckBoxChecked"
CHECK_BOX_UNCHECKED = "CheckBoxUnchecked"
RADIO_BUTTON_CHECKED = "RadioButtonChecked"
RADIO_BUTTON_UNCHECKED = "RadioButtonUnchecked"
ADDRESS = "Address" ADDRESS = "Address"
EMAIL_ADDRESS = "EmailAddress" EMAIL_ADDRESS = "EmailAddress"
PAGE_BREAK = "PageBreak" PAGE_BREAK = "PageBreak"
@ -627,6 +632,8 @@ class ElementType:
FOOTER = "Footer" FOOTER = "Footer"
FOOTNOTE = "Footnote" FOOTNOTE = "Footnote"
PAGE_FOOTER = "Page-footer" PAGE_FOOTER = "Page-footer"
PAGE_NUMBER = "PageNumber"
CODE_SNIPPET = "CodeSnippet"
@classmethod @classmethod
def to_dict(cls): def to_dict(cls):
@ -707,6 +714,9 @@ class Element(abc.ABC):
return new_coordinates return new_coordinates
def __str__(self):
return self.text
class CheckBox(Element): class CheckBox(Element):
"""A checkbox with an attribute indicating whether it's checked or not. """A checkbox with an attribute indicating whether it's checked or not.
@ -798,9 +808,6 @@ class Text(Element):
), ),
) )
def __str__(self):
return self.text
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Serialize to JSON-compatible (str keys) dict.""" """Serialize to JSON-compatible (str keys) dict."""
out = super().to_dict() out = super().to_dict()
@ -912,6 +919,18 @@ class Footer(Text):
category = "Footer" category = "Footer"
class CodeSnippet(Text):
"""An element for capturing code snippets."""
category = "CodeSnippet"
class PageNumber(Text):
"""An element for capturing page numbers."""
category = "PageNumber"
TYPE_TO_TEXT_ELEMENT_MAP: dict[str, type[Text]] = { TYPE_TO_TEXT_ELEMENT_MAP: dict[str, type[Text]] = {
ElementType.TITLE: Title, ElementType.TITLE: Title,
ElementType.SECTION_HEADER: Title, ElementType.SECTION_HEADER: Title,
@ -922,6 +941,7 @@ TYPE_TO_TEXT_ELEMENT_MAP: dict[str, type[Text]] = {
ElementType.COMPOSITE_ELEMENT: Text, ElementType.COMPOSITE_ELEMENT: Text,
ElementType.TEXT: NarrativeText, ElementType.TEXT: NarrativeText,
ElementType.NARRATIVE_TEXT: NarrativeText, ElementType.NARRATIVE_TEXT: NarrativeText,
ElementType.PARAGRAPH: NarrativeText,
# this mapping ensures yolox produces backward compatible categories # this mapping ensures yolox produces backward compatible categories
ElementType.ABSTRACT: NarrativeText, ElementType.ABSTRACT: NarrativeText,
ElementType.THREADING: NarrativeText, ElementType.THREADING: NarrativeText,
@ -946,4 +966,6 @@ TYPE_TO_TEXT_ELEMENT_MAP: dict[str, type[Text]] = {
ElementType.EMAIL_ADDRESS: EmailAddress, ElementType.EMAIL_ADDRESS: EmailAddress,
ElementType.FORMULA: Formula, ElementType.FORMULA: Formula,
ElementType.PAGE_BREAK: PageBreak, ElementType.PAGE_BREAK: PageBreak,
ElementType.CODE_SNIPPET: CodeSnippet,
ElementType.PAGE_NUMBER: PageNumber,
} }

View File

@ -39,7 +39,6 @@ if TYPE_CHECKING:
from unstructured_inference.inference.layout import DocumentLayout, PageLayout from unstructured_inference.inference.layout import DocumentLayout, PageLayout
from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.inference.layoutelement import LayoutElement
HIERARCHY_RULE_SET = { HIERARCHY_RULE_SET = {
"Title": [ "Title": [
"Text", "Text",
@ -132,22 +131,22 @@ def normalize_layout_element(
class_prob_metadata = ElementMetadata(detection_class_prob=float(prob)) # type: ignore class_prob_metadata = ElementMetadata(detection_class_prob=float(prob)) # type: ignore
else: else:
class_prob_metadata = ElementMetadata() class_prob_metadata = ElementMetadata()
common_kwargs = {
"coordinates": coordinates,
"coordinate_system": coordinate_system,
"metadata": class_prob_metadata,
"detection_origin": origin,
}
if element_type == ElementType.LIST: if element_type == ElementType.LIST:
if infer_list_items: if infer_list_items:
return layout_list_to_list_items( return layout_list_to_list_items(
text, text,
coordinates=coordinates, **common_kwargs,
coordinate_system=coordinate_system,
metadata=class_prob_metadata,
detection_origin=origin,
) )
else: else:
return ListItem( return ListItem(
text=text, text=text,
coordinates=coordinates, **common_kwargs,
coordinate_system=coordinate_system,
metadata=class_prob_metadata,
detection_origin=origin,
) )
elif element_type in TYPE_TO_TEXT_ELEMENT_MAP: elif element_type in TYPE_TO_TEXT_ELEMENT_MAP:
@ -155,39 +154,34 @@ def normalize_layout_element(
_element_class = TYPE_TO_TEXT_ELEMENT_MAP[element_type] _element_class = TYPE_TO_TEXT_ELEMENT_MAP[element_type]
_element_class = _element_class( _element_class = _element_class(
text=text, text=text,
coordinates=coordinates, **common_kwargs,
coordinate_system=coordinate_system,
metadata=class_prob_metadata,
detection_origin=origin,
) )
if element_type == ElementType.HEADLINE: if element_type == ElementType.HEADLINE:
_element_class.metadata.category_depth = 1 _element_class.metadata.category_depth = 1
elif element_type == ElementType.SUB_HEADLINE: elif element_type == ElementType.SUB_HEADLINE:
_element_class.metadata.category_depth = 2 _element_class.metadata.category_depth = 2
return _element_class return _element_class
elif element_type == ElementType.CHECKED: elif element_type in [
ElementType.CHECK_BOX_CHECKED,
ElementType.CHECK_BOX_UNCHECKED,
ElementType.RADIO_BUTTON_CHECKED,
ElementType.RADIO_BUTTON_UNCHECKED,
ElementType.CHECKED,
ElementType.UNCHECKED,
]:
checked = element_type in [
ElementType.CHECK_BOX_CHECKED,
ElementType.RADIO_BUTTON_CHECKED,
ElementType.CHECKED,
]
return CheckBox( return CheckBox(
checked=True, checked=checked,
coordinates=coordinates, **common_kwargs,
coordinate_system=coordinate_system,
metadata=class_prob_metadata,
detection_origin=origin,
)
elif element_type == ElementType.UNCHECKED:
return CheckBox(
checked=False,
coordinates=coordinates,
coordinate_system=coordinate_system,
metadata=class_prob_metadata,
detection_origin=origin,
) )
else: else:
return Text( return Text(
text=text, text=text,
coordinates=coordinates, **common_kwargs,
coordinate_system=coordinate_system,
metadata=class_prob_metadata,
detection_origin=origin,
) )