rfctr: prepare to add orig_elements serde (#2668)

**Summary**
The serialization and deserialization (serde) of
`metadata.orig_elements` will be located in `unstructured.staging.base`
alongside `elements_to_json()` and other existing serde functions.
Improve the typing, readability, and structure of that module before
adding the new serde functions for `metadata.orig_elements`.
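
For orientation, here is a minimal sketch of what the `orig_elements` serde could look like once added, assuming a gzip-plus-base64 encoding of the element-dicts so the value is JSON-safe inside metadata. The function names and the encoding are illustrative assumptions, not part of this commit; whatever the final form, it would build on the `elements_to_dicts()`/`elements_from_dicts()` functions introduced below:

```python
from __future__ import annotations

import base64
import gzip
import json
from typing import Iterable

from unstructured.documents.elements import Element
from unstructured.staging.base import elements_from_dicts, elements_to_dicts


def elements_to_base64_gzipped_json(elements: Iterable[Element]) -> str:
    # Hypothetical serializer: elements -> gzipped JSON -> base64 str.
    json_bytes = json.dumps(elements_to_dicts(elements), sort_keys=True).encode("utf-8")
    return base64.b64encode(gzip.compress(json_bytes)).decode("utf-8")


def elements_from_base64_gzipped_json(b64_text: str) -> list[Element]:
    # Hypothetical deserializer: reverses the serializer above.
    json_bytes = gzip.decompress(base64.b64decode(b64_text))
    return elements_from_dicts(json.loads(json_bytes))
```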

**Reviewers:** The commits are well-groomed and are probably quicker to
review commit-by-commit than as a single all-files-changed diff.
Steve Canny 2024-03-20 14:27:59 -07:00 committed by GitHub
parent 6abfb8b2b3
commit 31bef433ad
18 changed files with 278 additions and 316 deletions


@ -1,4 +1,4 @@
## 0.12.7-dev7
## 0.12.7-dev8
### Enhancements


@ -10,8 +10,7 @@ from functools import partial
import pytest
from unstructured.cleaners.core import clean_prefix
from unstructured.cleaners.translate import translate_text
from unstructured.cleaners.core import clean_bullets, clean_prefix
from unstructured.documents.coordinates import (
CoordinateSystem,
Orientation,
@ -66,13 +65,10 @@ def test_text_element_apply_cleaners():
def test_text_element_apply_multiple_cleaners():
cleaners = [
partial(clean_prefix, pattern=r"\[\d{1,2}\]"),
partial(translate_text, target_lang="ru"),
]
text_element = Text(text="[1] A Textbook on Crocodile Habitats")
cleaners = [partial(clean_prefix, pattern=r"\[\d{1,2}\]"), partial(clean_bullets)]
text_element = Text(text="[1] \u2022 A Textbook on Crocodile Habitats")
text_element.apply(*cleaners)
assert str(text_element) == "Учебник по крокодильным средам обитания"
assert str(text_element) == "A Textbook on Crocodile Habitats"
def test_apply_raises_if_func_does_not_produce_string():
@ -82,7 +78,7 @@ def test_apply_raises_if_func_does_not_produce_string():
text_element = Text(text="[1] A Textbook on Crocodile Habitats")
with pytest.raises(ValueError, match="Cleaner produced a non-string output."):
text_element.apply(bad_cleaner) # pyright: ignore[reportGeneralTypeIssues]
text_element.apply(bad_cleaner) # pyright: ignore[reportArgumentType]
@pytest.mark.parametrize(
@ -241,7 +237,7 @@ class DescribeElementMetadata:
def it_detects_unknown_constructor_args_at_both_development_time_and_runtime(self):
with pytest.raises(TypeError, match="got an unexpected keyword argument 'file_name'"):
ElementMetadata(file_name="memo.docx") # pyright: ignore[reportGeneralTypeIssues]
ElementMetadata(file_name="memo.docx") # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
"file_path",
@ -289,9 +285,9 @@ class DescribeElementMetadata:
def it_knows_the_types_of_its_known_members_so_type_checking_support_is_available(self):
ElementMetadata(
category_depth="2", # pyright: ignore[reportGeneralTypeIssues]
file_directory=True, # pyright: ignore[reportGeneralTypeIssues]
text_as_html=42, # pyright: ignore[reportGeneralTypeIssues]
category_depth="2", # pyright: ignore[reportArgumentType]
file_directory=True, # pyright: ignore[reportArgumentType]
text_as_html=42, # pyright: ignore[reportArgumentType]
)
# -- it does not check types at runtime however (choosing to avoid validation overhead) --
@ -526,7 +522,7 @@ class DescribeElementMetadata:
def but_it_raises_on_attempt_to_update_from_a_non_ElementMetadata_object(self):
meta = ElementMetadata()
with pytest.raises(ValueError, match=r"ate\(\)' must be an instance of 'ElementMetadata'"):
meta.update({"coefficient": "0.56"}) # pyright: ignore[reportGeneralTypeIssues]
meta.update({"coefficient": "0.56"}) # pyright: ignore[reportArgumentType]
# -- It knows when it is equal to another instance -------------------------------------------


@ -31,14 +31,9 @@ from unstructured.partition.text import partition_text
from unstructured.staging import base
@pytest.fixture()
def output_csv_file(tmp_path):
return os.path.join(tmp_path, "isd_data.csv")
def test_convert_to_isd():
def test_elements_to_dicts():
elements = [Title(text="Title 1"), NarrativeText(text="Narrative 1")]
isd = base.convert_to_isd(elements)
isd = base.elements_to_dicts(elements)
assert isd[0]["text"] == "Title 1"
assert isd[0]["type"] == ElementType.TITLE
@ -47,8 +42,8 @@ def test_convert_to_isd():
assert isd[1]["type"] == "NarrativeText"
def test_isd_to_elements():
isd = [
def test_elements_from_dicts():
element_dicts = [
{"text": "Blurb1", "type": "NarrativeText"},
{"text": "Blurb2", "type": "Title"},
{"text": "Blurb3", "type": "ListItem"},
@ -56,7 +51,7 @@ def test_isd_to_elements():
{"text": "No Type"},
]
elements = base.isd_to_elements(isd)
elements = base.elements_from_dicts(element_dicts)
assert elements == [
NarrativeText(text="Blurb1"),
Title(text="Blurb2"),
@ -65,13 +60,14 @@ def test_isd_to_elements():
]
def test_convert_to_csv(output_csv_file):
def test_convert_to_csv(tmp_path: str):
output_csv_path = os.path.join(tmp_path, "isd_data.csv")
elements = [Title(text="Title 1"), NarrativeText(text="Narrative 1")]
with open(output_csv_file, "w+") as csv_file:
with open(output_csv_path, "w+") as csv_file:
isd_csv_string = base.convert_to_csv(elements)
csv_file.write(isd_csv_string)
with open(output_csv_file) as csv_file:
with open(output_csv_path) as csv_file:
csv_rows = csv.DictReader(csv_file)
assert all(set(row.keys()) == set(base.TABLE_FIELDNAMES) for row in csv_rows)
@ -85,15 +81,13 @@ def test_convert_to_dataframe():
"text": ["Title 1", "Narrative 1"],
},
)
assert df.type.equals(expected_df.type) is True
assert df.text.equals(expected_df.text) is True
assert df.type.equals(expected_df.type) is True # type: ignore
assert df.text.equals(expected_df.text) is True # type: ignore
def test_convert_to_dataframe_maintains_fields(
filename="example-docs/eml/fake-email-attachment.eml",
):
def test_convert_to_dataframe_maintains_fields():
elements = partition_email(
filename=filename,
"example-docs/eml/fake-email-attachment.eml",
process_attachements=True,
regex_metadata={"hello": r"Hello", "punc": r"[!]"},
)
@ -109,10 +103,7 @@ def test_convert_to_dataframe_maintains_fields(
def test_default_pandas_dtypes():
"""
Make sure that all the values that can exist on an element have a corresponding dtype
mapped in the dict returned by get_default_pandas_dtypes()
"""
"""Ensure all element fields have a dtype in dict returned by get_default_pandas_dtypes()."""
full_element = Text(
text="some text",
element_id="123",
@ -165,8 +156,7 @@ def test_default_pandas_dtypes():
element_as_dict = full_element.to_dict()
element_as_dict.update(
base.flatten_dict(
element_as_dict.pop("metadata"),
keys_to_omit=["data_source_record_locator"],
element_as_dict.pop("metadata"), keys_to_omit=["data_source_record_locator"]
),
)
flattened_element_keys = element_as_dict.keys()
@ -180,13 +170,13 @@ def test_default_pandas_dtypes():
platform.system() == "Windows",
reason="Posix Paths are not available on Windows",
)
def test_convert_to_isd_serializes_with_posix_paths():
def test_elements_to_dicts_serializes_with_posix_paths():
metadata = ElementMetadata(filename=pathlib.PosixPath("../../fake-file.txt"))
elements = [
Title(text="Title 1", metadata=metadata),
NarrativeText(text="Narrative 1", metadata=metadata),
]
output = base.convert_to_isd(elements)
output = base.elements_to_dicts(elements)
# NOTE(robinson) - json.dumps should run without raising an exception
json.dumps(output)
@ -205,11 +195,11 @@ def test_all_elements_preserved_when_serialized():
PageBreak(text=""),
]
isd = base.convert_to_isd(elements)
assert base.convert_to_isd(base.isd_to_elements(isd)) == isd
element_dicts = base.elements_to_dicts(elements)
assert base.elements_to_dicts(base.elements_from_dicts(element_dicts)) == element_dicts
def test_serialized_deserialize_elements_to_json(tmpdir):
def test_serialized_deserialize_elements_to_json(tmpdir: str):
filename = os.path.join(tmpdir, "fake-elements.json")
metadata = ElementMetadata(filename="fake-file.txt")
elements = [
@ -229,63 +219,38 @@ def test_serialized_deserialize_elements_to_json(tmpdir):
assert elements == new_elements_filename
elements_str = base.elements_to_json(elements)
assert elements_str is not None
new_elements_text = base.elements_from_json(text=elements_str)
assert elements == new_elements_text
def test_read_and_write_json_with_encoding(
filename="example-docs/fake-text-utf-16-be.txt",
):
elements = partition_text(filename=filename)
def test_read_and_write_json_with_encoding():
elements = partition_text("example-docs/fake-text-utf-16-be.txt")
with NamedTemporaryFile() as tempfile:
base.elements_to_json(elements, filename=tempfile.name, encoding="utf-16")
new_elements_filename = base.elements_from_json(
filename=tempfile.name,
encoding="utf-16",
)
new_elements_filename = base.elements_from_json(filename=tempfile.name, encoding="utf-16")
assert elements == new_elements_filename
def test_filter_element_types_with_include_element_type(
filename="example-docs/fake-text.txt",
):
def test_filter_element_types_with_include_element_type():
element_types = [Title]
elements = partition_text(
filename=filename,
include_metadata=False,
)
elements = base.filter_element_types(
elements=elements,
include_element_types=element_types,
)
elements = partition_text("example-docs/fake-text.txt", include_metadata=False)
elements = base.filter_element_types(elements=elements, include_element_types=element_types)
for element in elements:
assert type(element) in element_types
def test_filter_element_types_with_exclude_element_type(
filename="example-docs/fake-text.txt",
):
def test_filter_element_types_with_exclude_element_type():
element_types = [Title]
elements = partition_text(
filename=filename,
include_metadata=False,
)
elements = base.filter_element_types(
elements=elements,
exclude_element_types=element_types,
)
elements = partition_text("example-docs/fake-text.txt", include_metadata=False)
elements = base.filter_element_types(elements=elements, exclude_element_types=element_types)
for element in elements:
assert type(element) not in element_types
def test_filter_element_types_with_exclude_and_include_element_type(
filename="example-docs/fake-text.txt",
):
def test_filter_element_types_with_exclude_and_include_element_type():
element_types = [Title]
elements = partition_text(
filename=filename,
include_metadata=False,
)
elements = partition_text("example-docs/fake-text.txt", include_metadata=False)
with pytest.raises(ValueError):
elements = base.filter_element_types(
elements=elements,
@ -527,13 +492,9 @@ def test_flatten_dict_flatten_list_omit_keys4():
def test_flatten_empty_dict():
"""Flattening an empty dictionary"""
dictionary = {}
expected_result = {}
assert base.flatten_dict(dictionary) == expected_result
assert base.flatten_dict({}) == {}
def test_flatten_dict_empty_lists():
"""Flattening a dictionary with empty lists"""
dictionary = {"a": [], "b": {"c": []}}
expected_result = {"a": [], "b_c": []}
assert base.flatten_dict(dictionary) == expected_result
assert base.flatten_dict({"a": [], "b": {"c": []}}) == {"a": [], "b_c": []}


@ -14,7 +14,7 @@ from unstructured.ingest.interfaces import (
ReadConfig,
)
from unstructured.partition.auto import partition
from unstructured.staging.base import convert_to_dict
from unstructured.staging.base import elements_to_dicts
DIRECTORY = pathlib.Path(__file__).parent.resolve()
EXAMPLE_DOCS_DIRECTORY = os.path.join(DIRECTORY, "../..", "example-docs")
@ -108,7 +108,7 @@ def partition_test_results():
@pytest.fixture()
def partition_file_test_results(partition_test_results):
# Reusable partition_file test results, calculated only once
return convert_to_dict(partition_test_results)
return elements_to_dicts(partition_test_results)
def test_partition_file():
@ -120,9 +120,9 @@ def test_partition_file():
processor_config=ProcessorConfig(output_dir=TEST_OUTPUT_DIR),
)
test_ingest_doc._date_processed = TEST_DATE_PROCESSSED
isd_elems_raw = test_ingest_doc.partition_file(partition_config=PartitionConfig())
isd_elems = convert_to_dict(isd_elems_raw)
assert len(isd_elems)
elements = test_ingest_doc.partition_file(partition_config=PartitionConfig())
element_dicts = elements_to_dicts(elements)
assert len(element_dicts)
expected_keys = {
"element_id",
"text",
@ -139,7 +139,7 @@ def test_partition_file():
"languages",
"last_modified",
}
for elem in isd_elems:
for elem in element_dicts:
# Parent IDs are non-deterministic - remove them from the test
elem["metadata"].pop("parent_id", None)
@ -166,11 +166,11 @@ def test_process_file_fields_include_default(mocker, partition_test_results):
read_config=ReadConfig(download_dir=TEST_DOWNLOAD_DIR),
processor_config=ProcessorConfig(output_dir=TEST_OUTPUT_DIR),
)
isd_elems_raw = test_ingest_doc.partition_file(partition_config=PartitionConfig())
isd_elems = convert_to_dict(isd_elems_raw)
assert len(isd_elems)
elements = test_ingest_doc.partition_file(partition_config=PartitionConfig())
element_dicts = elements_to_dicts(elements)
assert len(element_dicts)
assert mock_partition.call_count == 1
for elem in isd_elems:
for elem in element_dicts:
# Parent IDs are non-deterministic - remove them from the test
elem["metadata"].pop("parent_id", None)


@ -1 +1 @@
__version__ = "0.12.7-dev7" # pragma: no cover
__version__ = "0.12.7-dev8" # pragma: no cover


@ -22,7 +22,7 @@ from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, e
from unstructured.ingest.enhanced_dataclass.core import _asdict
from unstructured.ingest.error import PartitionError, SourceConnectionError
from unstructured.ingest.logger import logger
from unstructured.staging.base import convert_to_dict, flatten_dict
from unstructured.staging.base import elements_to_dicts, flatten_dict
A = t.TypeVar("A", bound="DataClassJsonMixin")
@ -586,12 +586,11 @@ class BaseSingleIngestDoc(BaseIngestDoc, IngestDocJsonMixin, ABC):
return None
logger.info(f"Processing {self.filename}")
isd_elems_raw = self.partition_file(partition_config=partition_config, **partition_kwargs)
isd_elems = convert_to_dict(isd_elems_raw)
elements = self.partition_file(partition_config=partition_config, **partition_kwargs)
element_dicts = elements_to_dicts(elements)
self.isd_elems_no_filename: t.List[t.Dict[str, t.Any]] = []
for elem in isd_elems:
# type: ignore
for elem in element_dicts:
if partition_config.metadata_exclude and partition_config.metadata_include:
raise ValueError(
"Arguments `--metadata-include` and `--metadata-exclude` are "


@ -10,7 +10,7 @@ from unstructured.ingest.interfaces import (
)
from unstructured.ingest.logger import logger
from unstructured.ingest.pipeline.interfaces import ReformatNode
from unstructured.staging.base import convert_to_dict, elements_from_json
from unstructured.staging.base import elements_from_json, elements_to_dicts
@dataclass
@ -49,10 +49,10 @@ class Chunker(ReformatNode):
return str(json_path)
elements = elements_from_json(filename=elements_json)
chunked_elements = self.chunking_config.chunk(elements=elements)
elements_dict = convert_to_dict(chunked_elements)
element_dicts = elements_to_dicts(chunked_elements)
with open(json_path, "w", encoding="utf8") as output_f:
logger.info(f"writing chunking content to {json_path}")
json.dump(elements_dict, output_f, ensure_ascii=False, indent=2)
json.dump(element_dicts, output_f, ensure_ascii=False, indent=2)
return str(json_path)
except Exception as e:
if self.pipeline_context.raise_on_error:


@ -10,7 +10,7 @@ from unstructured.ingest.interfaces import (
)
from unstructured.ingest.logger import logger
from unstructured.ingest.pipeline.interfaces import ReformatNode
from unstructured.staging.base import convert_to_dict, elements_from_json
from unstructured.staging.base import elements_from_json, elements_to_dicts
@dataclass
@ -50,10 +50,10 @@ class Embedder(ReformatNode):
elements = elements_from_json(filename=elements_json)
embedder = self.embedder_config.get_embedder()
embedded_elements = embedder.embed_documents(elements=elements)
elements_dict = convert_to_dict(embedded_elements)
element_dicts = elements_to_dicts(embedded_elements)
with open(json_path, "w", encoding="utf8") as output_f:
logger.info(f"writing embeddings content to {json_path}")
json.dump(elements_dict, output_f, ensure_ascii=False, indent=2)
json.dump(element_dicts, output_f, ensure_ascii=False, indent=2)
return str(json_path)
except Exception as e:
if self.pipeline_context.raise_on_error:


@ -13,7 +13,7 @@ from unstructured_client.models import shared
from unstructured.documents.elements import Element
from unstructured.logger import logger
from unstructured.partition.common import exactly_one
from unstructured.staging.base import dict_to_elements, elements_from_json
from unstructured.staging.base import elements_from_dicts, elements_from_json
def partition_via_api(
@ -214,7 +214,7 @@ def partition_multiple_via_api(
response_list = [response_list]
for document in response_list:
documents.append(dict_to_elements(document))
documents.append(elements_from_dicts(document))
return documents
else:
raise ValueError(


@ -24,7 +24,7 @@ from unstructured.partition.common import (
get_last_modified_date,
get_last_modified_date_from_file,
)
from unstructured.staging.base import dict_to_elements
from unstructured.staging.base import elements_from_dicts
@process_metadata()
@ -86,8 +86,8 @@ def partition_json(
)
try:
dict = json.loads(file_text)
elements = dict_to_elements(dict)
element_dicts = json.loads(file_text)
elements = elements_from_dicts(element_dicts)
except json.JSONDecodeError:
raise ValueError("Not a valid json")


@ -1,9 +1,11 @@
from __future__ import annotations
import csv
import io
import json
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Iterable, Optional, Sequence, cast
from unstructured.documents.coordinates import PixelSpace
from unstructured.documents.elements import (
@ -14,13 +16,128 @@ from unstructured.documents.elements import (
NoID,
)
from unstructured.partition.common import exactly_one
from unstructured.utils import dependency_exists, requires_dependencies
from unstructured.utils import Point, dependency_exists, requires_dependencies
if dependency_exists("pandas"):
import pandas as pd
def _get_metadata_table_fieldnames():
# ================================================================================================
# SERIALIZATION/DESERIALIZATION (SERDE) RELATED FUNCTIONS
# ================================================================================================
# These serde functions will likely relocate to `unstructured.documents.elements` since they are
# so closely related to elements and this staging "brick" is deprecated.
# ================================================================================================
# == DESERIALIZERS ===============================
def elements_from_dicts(element_dicts: Iterable[dict[str, Any]]) -> list[Element]:
"""Convert a list of element-dicts to a list of elements."""
elements: list[Element] = []
for item in element_dicts:
element_id: str = item.get("element_id", NoID())
metadata = (
ElementMetadata()
if item.get("metadata") is None
else ElementMetadata.from_dict(item["metadata"])
)
if item.get("type") in TYPE_TO_TEXT_ELEMENT_MAP:
ElementCls = TYPE_TO_TEXT_ELEMENT_MAP[item["type"]]
elements.append(ElementCls(text=item["text"], element_id=element_id, metadata=metadata))
elif item.get("type") == "CheckBox":
elements.append(
CheckBox(checked=item["checked"], element_id=element_id, metadata=metadata)
)
return elements
# -- legacy aliases for elements_from_dicts() --
isd_to_elements = elements_from_dicts
dict_to_elements = elements_from_dicts
def elements_from_json(
filename: str = "", text: str = "", encoding: str = "utf-8"
) -> list[Element]:
"""Loads a list of elements from a JSON file or a string."""
exactly_one(filename=filename, text=text)
if filename:
with open(filename, encoding=encoding) as f:
element_dicts = json.load(f)
else:
element_dicts = json.loads(text)
return elements_from_dicts(element_dicts)
# == SERIALIZERS =================================
def elements_to_dicts(elements: Iterable[Element]) -> list[dict[str, Any]]:
"""Convert document elements to element-dicts."""
return [e.to_dict() for e in elements]
# -- legacy aliases for elements_to_dicts() --
convert_to_isd = elements_to_dicts
convert_to_dict = elements_to_dicts
def elements_to_json(
elements: Iterable[Element],
filename: Optional[str] = None,
indent: int = 4,
encoding: str = "utf-8",
) -> Optional[str]:
"""Saves a list of elements to a JSON file if filename is specified.
Otherwise, return the list of elements as a string.
"""
# -- serialize `elements` as a JSON array (str) --
precision_adjusted_elements = _fix_metadata_field_precision(elements)
element_dicts = elements_to_dicts(precision_adjusted_elements)
json_str = json.dumps(element_dicts, indent=indent, sort_keys=True)
if filename is not None:
with open(filename, "w", encoding=encoding) as f:
f.write(json_str)
return None
return json_str
def _fix_metadata_field_precision(elements: Iterable[Element]) -> list[Element]:
out_elements: list[Element] = []
for element in elements:
el = deepcopy(element)
if el.metadata.coordinates:
precision = 1 if isinstance(el.metadata.coordinates.system, PixelSpace) else 2
points = el.metadata.coordinates.points
assert points is not None
rounded_points: list[Point] = []
for point in points:
x, y = point
rounded_point = (round(x, precision), round(y, precision))
rounded_points.append(rounded_point)
el.metadata.coordinates.points = tuple(rounded_points)
if el.metadata.detection_class_prob:
el.metadata.detection_class_prob = round(el.metadata.detection_class_prob, 5)
out_elements.append(el)
return out_elements
# ================================================================================================
def _get_metadata_table_fieldnames() -> list[str]:
metadata_fields = list(ElementMetadata.__annotations__.keys())
metadata_fields.remove("coordinates")
metadata_fields.extend(
@ -35,27 +152,25 @@ def _get_metadata_table_fieldnames():
return metadata_fields
TABLE_FIELDNAMES: List[str] = [
TABLE_FIELDNAMES: list[str] = [
"type",
"text",
"element_id",
] + _get_metadata_table_fieldnames()
def convert_to_text(elements: List[Element]) -> str:
"""Converts a list of elements into clean, concatenated text."""
def convert_to_text(elements: Iterable[Element]) -> str:
"""Convert elements into clean, concatenated text."""
return "\n".join([e.text for e in elements if hasattr(e, "text") and e.text])
def elements_to_text(
elements: List[Element],
filename: Optional[str] = None,
encoding: str = "utf-8",
elements: Iterable[Element], filename: Optional[str] = None, encoding: str = "utf-8"
) -> Optional[str]:
"""
Convert the text from the list of elements into clean, concatenated text.
Saves to a txt file if filename is specified.
Otherwise, return the text of the elements as a string.
"""Convert text from each of `elements` into clean, concatenated text.
Saves to a txt file if filename is specified. Otherwise, return the text of the elements as a
string.
"""
element_cct = convert_to_text(elements)
if filename is not None:
@ -66,130 +181,23 @@ def elements_to_text(
return element_cct
def convert_to_isd(elements: List[Element]) -> List[Dict[str, Any]]:
"""Represents the document elements as an Initial Structured Document (ISD)."""
isd: List[Dict[str, Any]] = []
for element in elements:
section = element.to_dict()
isd.append(section)
return isd
def convert_to_dict(elements: List[Element]) -> List[Dict[str, Any]]:
"""Converts a list of elements into a dictionary."""
return convert_to_isd(elements)
def _fix_metadata_field_precision(elements: List[Element]) -> List[Element]:
out_elements = []
for element in elements:
el = deepcopy(element)
if el.metadata.coordinates:
precision = 1 if isinstance(el.metadata.coordinates.system, PixelSpace) else 2
points = el.metadata.coordinates.points
rounded_points = []
for point in points:
x, y = point
rounded_point = (round(x, precision), round(y, precision))
rounded_points.append(rounded_point)
el.metadata.coordinates.points = tuple(rounded_points)
if el.metadata.detection_class_prob:
el.metadata.detection_class_prob = round(el.metadata.detection_class_prob, 5)
out_elements.append(el)
return out_elements
def elements_to_json(
elements: List[Element],
filename: Optional[str] = None,
indent: int = 4,
encoding: str = "utf-8",
) -> Optional[str]:
"""
Saves a list of elements to a JSON file if filename is specified.
Otherwise, return the list of elements as a string.
"""
pre_processed_elements = _fix_metadata_field_precision(elements)
element_dict = convert_to_dict(pre_processed_elements)
if filename is not None:
with open(filename, "w", encoding=encoding) as f:
json.dump(element_dict, f, indent=indent, sort_keys=True)
return None
else:
return json.dumps(element_dict, indent=indent, sort_keys=True)
def isd_to_elements(isd: List[Dict[str, Any]]) -> List[Element]:
"""Converts an Initial Structured Data (ISD) dictionary to a list of elements."""
elements: List[Element] = []
for item in isd:
element_id: str = item.get("element_id", NoID())
metadata = ElementMetadata()
_metadata_dict = item.get("metadata")
if _metadata_dict is not None:
metadata = ElementMetadata.from_dict(_metadata_dict)
if item.get("type") in TYPE_TO_TEXT_ELEMENT_MAP:
_text_class = TYPE_TO_TEXT_ELEMENT_MAP[item["type"]]
elements.append(
_text_class(
text=item["text"],
element_id=element_id,
metadata=metadata,
),
)
elif item.get("type") == "CheckBox":
elements.append(
CheckBox(
checked=item["checked"],
element_id=element_id,
metadata=metadata,
),
)
return elements
def dict_to_elements(element_dict: List[Dict[str, Any]]) -> List[Element]:
"""Converts a dictionary representation of an element list into List[Element]."""
return isd_to_elements(element_dict)
def elements_from_json(
filename: str = "",
text: str = "",
encoding: str = "utf-8",
) -> List[Element]:
"""Loads a list of elements from a JSON file or a string."""
exactly_one(filename=filename, text=text)
if filename:
with open(filename, encoding=encoding) as f:
element_dict = json.load(f)
return dict_to_elements(element_dict)
else:
element_dict = json.loads(text)
return dict_to_elements(element_dict)
def flatten_dict(
dictionary,
parent_key="",
separator="_",
flatten_lists=False,
remove_none=False,
keys_to_omit: List[str] = None,
):
"""Flattens a nested dictionary into a single level dictionary. keys_to_omit is a list of keys
that don't get flattened. If omitting a nested key, format as {parent_key}{separator}{key}.
If flatten_lists is True, then lists and tuples are flattened as well.
If remove_none is True, then None keys/values are removed from the flattened dictionary."""
dictionary: dict[str, Any],
parent_key: str = "",
separator: str = "_",
flatten_lists: bool = False,
remove_none: bool = False,
keys_to_omit: Optional[Sequence[str]] = None,
) -> dict[str, Any]:
"""Flattens a nested dictionary into a single level dictionary.
keys_to_omit is a list of keys that don't get flattened. If omitting a nested key, format as
{parent_key}{separator}{key}. If flatten_lists is True, then lists and tuples are flattened as
well. If remove_none is True, then None keys/values are removed from the flattened
dictionary.
"""
keys_to_omit = keys_to_omit if keys_to_omit else []
flattened_dict = {}
flattened_dict: dict[str, Any] = {}
for key, value in dictionary.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if new_key in keys_to_omit:
@ -197,12 +205,14 @@ def flatten_dict(
elif value is None and remove_none:
continue
elif isinstance(value, dict):
value = cast("dict[str, Any]", value)
flattened_dict.update(
flatten_dict(
value, new_key, separator, flatten_lists, remove_none, keys_to_omit=keys_to_omit
),
)
elif isinstance(value, (list, tuple)) and flatten_lists:
value = cast("list[Any] | tuple[Any]", value)
for index, item in enumerate(value):
flattened_dict.update(
flatten_dict(
@ -216,10 +226,11 @@ def flatten_dict(
)
else:
flattened_dict[new_key] = value
return flattened_dict
def _get_table_fieldnames(rows):
def _get_table_fieldnames(rows: list[dict[str, Any]]):
table_fieldnames = list(TABLE_FIELDNAMES)
for row in rows:
metadata = row["metadata"]
@ -229,12 +240,9 @@ def _get_table_fieldnames(rows):
return table_fieldnames
def convert_to_isd_csv(elements: List[Element]) -> str:
"""
Returns the representation of document elements as an Initial Structured Document (ISD)
in CSV Format.
"""
rows: List[Dict[str, Any]] = convert_to_isd(elements)
def convert_to_csv(elements: Iterable[Element]) -> str:
"""Convert `elements` to CSV format."""
rows: list[dict[str, Any]] = elements_to_dicts(elements)
table_fieldnames = _get_table_fieldnames(rows)
# NOTE(robinson) - flatten metadata and add it to the table
for row in rows:
@ -255,55 +263,54 @@ def convert_to_isd_csv(elements: List[Element]) -> str:
return buffer.getvalue()
def convert_to_csv(elements: List[Element]) -> str:
"""Converts a list of elements to a CSV."""
return convert_to_isd_csv(elements)
# -- legacy alias for convert_to_csv --
convert_to_isd_csv = convert_to_csv
@requires_dependencies(["pandas"])
def get_default_pandas_dtypes() -> dict:
def get_default_pandas_dtypes() -> dict[str, Any]:
return {
"text": pd.StringDtype(),
"type": pd.StringDtype(),
"element_id": pd.StringDtype(),
"filename": pd.StringDtype(), # Optional[str]
"filetype": pd.StringDtype(), # Optional[str]
"file_directory": pd.StringDtype(), # Optional[str]
"last_modified": pd.StringDtype(), # Optional[str]
"attached_to_filename": pd.StringDtype(), # Optional[str]
"parent_id": pd.StringDtype(), # Optional[str],
"text": pd.StringDtype(), # type: ignore
"type": pd.StringDtype(), # type: ignore
"element_id": pd.StringDtype(), # type: ignore
"filename": pd.StringDtype(), # Optional[str] # type: ignore
"filetype": pd.StringDtype(), # Optional[str] # type: ignore
"file_directory": pd.StringDtype(), # Optional[str] # type: ignore
"last_modified": pd.StringDtype(), # Optional[str] # type: ignore
"attached_to_filename": pd.StringDtype(), # Optional[str] # type: ignore
"parent_id": pd.StringDtype(), # Optional[str], # type: ignore
"category_depth": "Int64", # Optional[int]
"image_path": pd.StringDtype(), # Optional[str]
"languages": object, # Optional[List[str]]
"image_path": pd.StringDtype(), # Optional[str] # type: ignore
"languages": object, # Optional[list[str]]
"page_number": "Int64", # Optional[int]
"page_name": pd.StringDtype(), # Optional[str]
"url": pd.StringDtype(), # Optional[str]
"link_urls": pd.StringDtype(), # Optional[str]
"link_texts": object, # Optional[List[str]]
"page_name": pd.StringDtype(), # Optional[str] # type: ignore
"url": pd.StringDtype(), # Optional[str] # type: ignore
"link_urls": pd.StringDtype(), # Optional[str] # type: ignore
"link_texts": object, # Optional[list[str]]
"links": object,
"sent_from": object, # Optional[List[str]],
"sent_to": object, # Optional[List[str]]
"subject": pd.StringDtype(), # Optional[str]
"section": pd.StringDtype(), # Optional[str]
"header_footer_type": pd.StringDtype(), # Optional[str]
"emphasized_text_contents": object, # Optional[List[str]]
"emphasized_text_tags": object, # Optional[List[str]]
"text_as_html": pd.StringDtype(), # Optional[str]
"sent_from": object, # Optional[list[str]],
"sent_to": object, # Optional[list[str]]
"subject": pd.StringDtype(), # Optional[str] # type: ignore
"section": pd.StringDtype(), # Optional[str] # type: ignore
"header_footer_type": pd.StringDtype(), # Optional[str] # type: ignore
"emphasized_text_contents": object, # Optional[list[str]]
"emphasized_text_tags": object, # Optional[list[str]]
"text_as_html": pd.StringDtype(), # Optional[str] # type: ignore
"regex_metadata": object,
"max_characters": "Int64", # Optional[int]
"is_continuation": "boolean", # Optional[bool]
"detection_class_prob": float, # Optional[float],
"sender": pd.StringDtype(),
"sender": pd.StringDtype(), # type: ignore
"coordinates_points": object,
"coordinates_system": pd.StringDtype(),
"coordinates_system": pd.StringDtype(), # type: ignore
"coordinates_layout_width": float,
"coordinates_layout_height": float,
"data_source_url": pd.StringDtype(), # Optional[str]
"data_source_version": pd.StringDtype(), # Optional[str]
"data_source_url": pd.StringDtype(), # Optional[str] # type: ignore
"data_source_version": pd.StringDtype(), # Optional[str] # type: ignore
"data_source_record_locator": object,
"data_source_date_created": pd.StringDtype(), # Optional[str]
"data_source_date_modified": pd.StringDtype(), # Optional[str]
"data_source_date_processed": pd.StringDtype(), # Optional[str]
"data_source_date_created": pd.StringDtype(), # Optional[str] # type: ignore
"data_source_date_modified": pd.StringDtype(), # Optional[str] # type: ignore
"data_source_date_processed": pd.StringDtype(), # Optional[str] # type: ignore
"data_source_permissions_data": object,
"embeddings": object,
"regex_metadata_key": object,
@ -312,44 +319,41 @@ def get_default_pandas_dtypes() -> dict:
@requires_dependencies(["pandas"])
def convert_to_dataframe(
elements: List[Element],
drop_empty_cols: bool = True,
set_dtypes=False,
elements: Iterable[Element], drop_empty_cols: bool = True, set_dtypes: bool = False
) -> "pd.DataFrame":
"""Converts document elements to a pandas DataFrame. The dataframe contains the
following columns:
"""Convert `elements` to a pandas DataFrame.
The dataframe contains the following columns:
text: the element text
type: the text type (NarrativeText, Title, etc)
Output is pd.DataFrame
"""
elements_as_dict = convert_to_dict(elements)
for d in elements_as_dict:
element_dicts = elements_to_dicts(elements)
for d in element_dicts:
if metadata := d.pop("metadata", None):
d.update(flatten_dict(metadata, keys_to_omit=["data_source_record_locator"]))
df = pd.DataFrame.from_dict(
elements_as_dict,
)
df = pd.DataFrame.from_dict(element_dicts) # type: ignore
if set_dtypes:
dt = {k: v for k, v in get_default_pandas_dtypes().items() if k in df.columns}
df = df.astype(dt)
df = df.astype(dt) # type: ignore
if drop_empty_cols:
df.dropna(axis=1, how="all", inplace=True)
df.dropna(axis=1, how="all", inplace=True) # type: ignore
return df
def filter_element_types(
elements: List[Element],
include_element_types: Optional[List[Element]] = None,
exclude_element_types: Optional[List[Element]] = None,
) -> List[Element]:
elements: Iterable[Element],
include_element_types: Optional[Sequence[type[Element]]] = None,
exclude_element_types: Optional[Sequence[type[Element]]] = None,
) -> list[Element]:
"""Filters document elements by element type"""
exactly_one(
include_element_types=include_element_types,
exclude_element_types=exclude_element_types,
)
filtered_elements: List[Element] = []
filtered_elements: list[Element] = []
if include_element_types:
for element in elements:
if type(element) in include_element_types:
@ -364,16 +368,18 @@ def filter_element_types(
return filtered_elements
return elements
return list(elements)
def convert_to_coco(
elements: List[Element],
elements: Iterable[Element],
dataset_description: Optional[str] = None,
dataset_version: str = "1.0",
contributors: Tuple[str] = ("Unstructured Developers",),
) -> List[Dict[str, Any]]:
coco_dataset = {}
contributors: tuple[str] = ("Unstructured Developers",),
) -> dict[str, Any]:
from unstructured.documents.elements import TYPE_TO_TEXT_ELEMENT_MAP
coco_dataset: dict[str, Any] = {}
# Handle Info
coco_dataset["info"] = {
"description": (
@ -386,7 +392,7 @@ def convert_to_coco(
"contributors": ",".join(contributors),
"date_created": datetime.now().date().isoformat(),
}
elements_dict = convert_to_dict(elements)
element_dicts = elements_to_dicts(elements)
# Handle Images
images = [
{
@ -404,7 +410,7 @@ def convert_to_coco(
"file_name": el["metadata"].get("filename", ""),
"page_number": el["metadata"].get("page_number", ""),
}
for el in elements_dict
for el in element_dicts
]
images = list({tuple(sorted(d.items())): d for d in images}.values())
for index, d in enumerate(images):
@ -458,7 +464,7 @@ def convert_to_coco(
else None
),
}
for el in elements_dict
for el in element_dicts
]
coco_dataset["annotations"] = annotations
return coco_dataset
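
Taken together, the renamed public serde surface round-trips exactly as before, and the legacy names remain importable as module-level aliases of the new functions. A quick illustrative check (the element values are made up):

```python
from unstructured.documents.elements import NarrativeText, Title
from unstructured.staging import base

elements = [Title(text="Title 1"), NarrativeText(text="Narrative 1")]

# The renamed serde functions round-trip: elements -> dicts -> elements.
element_dicts = base.elements_to_dicts(elements)
assert base.elements_from_dicts(element_dicts) == elements

# The legacy names survive as aliases of the new functions.
assert base.convert_to_dict is base.elements_to_dicts
assert base.isd_to_elements is base.elements_from_dicts
assert base.convert_to_isd_csv is base.convert_to_csv
```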