fix: grab all metadata fields in convert_to_dataframe (#893)

* add all fieldnames to dataframe

* drop empty columns in convert_to_dataframe

* test for maintaining metadata

* version and changelog
This commit is contained in:
Matt Robinson 2023-07-07 16:04:35 -04:00 committed by GitHub
parent c8e6f0e141
commit f51ae45050
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 18 deletions

View File

@ -7,12 +7,14 @@
### Features
* Add metadata_filename parameter across all partition functions
* Add `metadata_filename` parameter across all partition functions
### Fixes
* Update to ensure `convert_to_dataframe` grabs all of the metadata fields.
* Adjust encoding recognition threshold value in `detect_file_encoding`
* Fix KeyError when `isd_to_elements` doesn't find a type
* Fix _output_filename for local connector, allowing single files to be written correctly to the disk
* Fix `_output_filename` for local connector, allowing single files to be written correctly to the disk
* Fix for cases where an invalid encoding is extracted from an email header.

View File

@ -20,6 +20,7 @@ from unstructured.documents.elements import (
Text,
Title,
)
from unstructured.partition.email import partition_email
from unstructured.partition.text import partition_text
from unstructured.staging import base
@ -82,6 +83,25 @@ def test_convert_to_dataframe():
assert df.text.equals(expected_df.text) is True
def test_convert_to_dataframe_maintains_fields(
    filename="example-docs/eml/fake-email-attachment.eml",
):
    """convert_to_dataframe keeps every metadata column, including dynamic regex ones."""
    elements = partition_email(
        filename=filename,
        # NOTE(review): was `process_attachements` (misspelled) — a **kwargs-accepting
        # partition_email would silently swallow the typo'd kwarg, so attachments were
        # never actually processed. Confirm spelling against partition_email's signature.
        process_attachments=True,
        regex_metadata={"hello": r"Hello", "punc": r"[!]"},
    )
    df = base.convert_to_dataframe(elements)
    for element in elements:
        metadata = element.metadata.to_dict()
        for key in metadata:
            # Non-regex metadata fields must all survive the round trip to a DataFrame.
            if not key.startswith("regex_metadata"):
                assert key in df.columns
    # Regex metadata keys are flattened into one column per pattern name.
    assert "regex_metadata_hello" in df.columns
    assert "regex_metadata_punc" in df.columns
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Posix Paths are not available on Windows",

View File

@ -14,22 +14,27 @@ from unstructured.documents.elements import (
)
from unstructured.partition.common import exactly_one
def _get_metadata_table_fieldnames():
    """Return the metadata field names to use as table columns.

    Starts from the annotated fields on ``ElementMetadata``, drops the nested
    ``coordinates`` field (it is represented by its flattened sub-fields
    instead), and appends the derived ``sender`` column plus the four
    flattened coordinate columns.
    """
    fields = [
        name
        for name in ElementMetadata.__annotations__
        if name != "coordinates"
    ]
    fields += [
        "sender",
        "coordinates_points",
        "coordinates_system",
        "coordinates_layout_width",
        "coordinates_layout_height",
    ]
    return fields
TABLE_FIELDNAMES: List[str] = [
"type",
"text",
"element_id",
"coordinates_points",
"coordinates_system",
"coordinates_layout_width",
"coordinates_layout_height",
"filename",
"page_number",
"url",
"sent_from",
"sent_to",
"subject",
"sender",
]
] + _get_metadata_table_fieldnames()
def convert_to_isd(elements: List[Element]) -> List[Dict[str, Any]]:
@ -130,17 +135,28 @@ def flatten_dict(dictionary, parent_key="", separator="_"):
return flattened_dict
def _get_table_fieldnames(rows):
    """Build the CSV column list: the standard fields plus any regex-metadata keys.

    Regex-metadata columns are dynamic (one per user-supplied pattern name), so
    they are discovered by scanning each row's flattened metadata; first-seen
    order is preserved and duplicates are skipped.
    """
    fieldnames = list(TABLE_FIELDNAMES)
    seen = set(fieldnames)
    for row in rows:
        for key in flatten_dict(row["metadata"]):
            if key.startswith("regex_metadata") and key not in seen:
                seen.add(key)
                fieldnames.append(key)
    return fieldnames
def convert_to_isd_csv(elements: List[Element]) -> str:
"""
Returns the representation of document elements as an Initial Structured Document (ISD)
in CSV Format.
"""
rows: List[Dict[str, Any]] = convert_to_isd(elements)
table_fieldnames = _get_table_fieldnames(rows)
# NOTE(robinson) - flatten metadata and add it to the table
for row in rows:
metadata = row.pop("metadata")
for key, value in flatten_dict(metadata).items():
if key in TABLE_FIELDNAMES:
if key in table_fieldnames:
row[key] = value
if row.get("sent_from"):
@ -149,7 +165,7 @@ def convert_to_isd_csv(elements: List[Element]) -> str:
row["sender"] = row["sender"][0]
with io.StringIO() as buffer:
csv_writer = csv.DictWriter(buffer, fieldnames=TABLE_FIELDNAMES)
csv_writer = csv.DictWriter(buffer, fieldnames=table_fieldnames)
csv_writer.writeheader()
csv_writer.writerows(rows)
return buffer.getvalue()
@ -160,7 +176,7 @@ def convert_to_csv(elements: List[Element]) -> str:
return convert_to_isd_csv(elements)
def convert_to_dataframe(elements: List[Element]) -> pd.DataFrame:
def convert_to_dataframe(elements: List[Element], drop_empty_cols: bool = True) -> pd.DataFrame:
"""Converts document elements to a pandas DataFrame. The dataframe contains the
following columns:
text: the element text
@ -168,4 +184,7 @@ def convert_to_dataframe(elements: List[Element]) -> pd.DataFrame:
"""
csv_string = convert_to_isd_csv(elements)
csv_string_io = io.StringIO(csv_string)
return pd.read_csv(csv_string_io, sep=",")
df = pd.read_csv(csv_string_io, sep=",")
if drop_empty_cols:
df.dropna(axis=1, how="all", inplace=True)
return df