mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-08-15 20:27:37 +00:00
feat: Implement stage_csv_for_prodigy
brick (#13)
* Refactor metadata validation and implement stage_csv_for_prodigy brick
* Refactor unit tests for metadata validation and add tests for Prodigy CSV brick
* Add stage_csv_for_prodigy description and example in docs
* Bump version and update changelog
* added _csv_ to function name
* update changelog line to 0.2.1-dev2

Co-authored-by: Matt Robinson <mrobinson@unstructuredai.io>
This commit is contained in:
parent
90d4f40da8
commit
d429e9b305
@ -1,5 +1,6 @@
|
||||
## 0.2.1-dev1
|
||||
## 0.2.1-dev2
|
||||
|
||||
* Added staging brick for CSV format for Prodigy
|
||||
* Added staging brick for Prodigy
|
||||
* Added text_field and id_field to stage_for_label_studio signature
|
||||
|
||||
|
@ -364,7 +364,7 @@ Examples:
|
||||
``stage_for_prodigy``
|
||||
--------------------------
|
||||
|
||||
Formats outputs for use with `Prodigy <https://prodi.gy/docs/api-loaders>`_. After running ``stage_for_prodigy``, you can
|
||||
Formats outputs in JSON format for use with `Prodigy <https://prodi.gy/docs/api-loaders>`_. After running ``stage_for_prodigy``, you can
|
||||
write the results to a JSON file that is ready to be used with Prodigy.
|
||||
|
||||
Examples:
|
||||
@ -383,3 +383,25 @@ Examples:
|
||||
# The resulting JSON file is ready to be used with Prodigy
|
||||
with open("prodigy.json", "w") as f:
|
||||
json.dump(prodigy_data, f, indent=4)
|
||||
|
||||
|
||||
``stage_csv_for_prodigy``
|
||||
--------------------------
|
||||
|
||||
Formats outputs in CSV format for use with `Prodigy <https://prodi.gy/docs/api-loaders>`_. After running ``stage_csv_for_prodigy``, you can
|
||||
write the results to a CSV file that is ready to be used with Prodigy.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code:: python
|
||||
|
||||
from unstructured.documents.elements import Title, NarrativeText
|
||||
from unstructured.staging.prodigy import stage_csv_for_prodigy
|
||||
|
||||
elements = [Title(text="Title"), NarrativeText(text="Narrative")]
|
||||
metadata = [{"type": "title"}, {"source": "news"}]
|
||||
prodigy_csv_data = stage_csv_for_prodigy(elements, metadata)
|
||||
|
||||
# The resulting CSV file is ready to be used with Prodigy
|
||||
with open("prodigy.csv", "w") as csv_file:
|
||||
csv_file.write(prodigy_csv_data)
|
||||
|
@ -1,4 +1,6 @@
|
||||
import pytest
|
||||
import csv
|
||||
import os
|
||||
|
||||
import unstructured.staging.prodigy as prodigy
|
||||
from unstructured.documents.elements import Title, NarrativeText
|
||||
@ -24,6 +26,41 @@ def metadata_with_invalid_length():
|
||||
return [{"score": 0.1}, {"category": "paragraph"}, {"type": "text"}]
|
||||
|
||||
|
||||
@pytest.fixture
def output_csv_file(tmp_path):
    """Return a path (as ``str``) for a throwaway Prodigy CSV file under pytest's tmp dir."""
    csv_filename = "prodigy_data.csv"
    return os.path.join(tmp_path, csv_filename)
|
||||
|
||||
|
||||
def test_validate_prodigy_metadata(elements):
    """With no metadata supplied, validation yields one empty dict per element."""
    result = prodigy._validate_prodigy_metadata(elements, metadata=None)
    assert len(result) == len(elements)
    for metadatum in result:
        assert not metadatum
|
||||
|
||||
|
||||
def test_validate_prodigy_metadata_with_valid_metadata(elements, valid_metadata):
    """Valid metadata passes validation and keeps one entry per element."""
    result = prodigy._validate_prodigy_metadata(elements, metadata=valid_metadata)
    assert len(result) == len(elements)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_metadata_fixture, exception_message",
    [
        ("metadata_with_id", 'The key "id" is not allowed with metadata parameter at index: 1'),
        (
            "metadata_with_invalid_length",
            "The length of metadata parameter does not match with length of elements parameter.",
        ),
    ],
)
def test_validate_prodigy_metadata_with_invalid_metadata(
    elements, invalid_metadata_fixture, exception_message, request
):
    """Each invalid metadata shape is rejected with the expected ValueError message."""
    metadata = request.getfixturevalue(invalid_metadata_fixture)
    with pytest.raises(ValueError) as exc_info:
        prodigy._validate_prodigy_metadata(elements, metadata)
    assert str(exc_info.value) == exception_message
|
||||
|
||||
|
||||
def test_convert_to_prodigy_data(elements):
|
||||
prodigy_data = prodigy.stage_for_prodigy(elements)
|
||||
|
||||
@ -54,20 +91,24 @@ def test_convert_to_prodigy_data_with_valid_metadata(elements, valid_metadata):
|
||||
assert prodigy_data[1]["meta"] == {"id": elements[1].id, **valid_metadata[1]}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_metadata_fixture, exception_message",
    [
        ("metadata_with_id", 'The key "id" is not allowed with metadata parameter at index: 1'),
        (
            "metadata_with_invalid_length",
            "The length of metadata parameter does not match with length of elements parameter.",
        ),
    ],
)
def test_convert_to_prodigy_data_with_invalid_metadata(
    elements, invalid_metadata_fixture, exception_message, request
):
    """stage_for_prodigy surfaces the same ValueError messages as metadata validation."""
    metadata = request.getfixturevalue(invalid_metadata_fixture)
    with pytest.raises(ValueError) as exc_info:
        prodigy.stage_for_prodigy(elements, metadata)
    assert str(exc_info.value) == exception_message
|
||||
def test_stage_csv_for_prodigy(elements, output_csv_file):
    """Staged CSV, written to disk, reads back with exactly the text/id columns."""
    with open(output_csv_file, "w+") as out_file:
        out_file.write(prodigy.stage_csv_for_prodigy(elements))

    expected_fields = {"text", "id"}
    with open(output_csv_file, "r") as in_file:
        for row in csv.DictReader(in_file):
            assert set(row.keys()) == expected_fields
|
||||
|
||||
|
||||
def test_stage_csv_for_prodigy_with_metadata(elements, valid_metadata, output_csv_file):
    """Staged CSV with metadata carries text, id, and lowercased metadata columns."""
    with open(output_csv_file, "w+") as out_file:
        out_file.write(prodigy.stage_csv_for_prodigy(elements, valid_metadata))

    # "text"/"id" are already lowercase; metadata keys are lowercased by the brick.
    expected_fields = {"text", "id"}
    for metadatum in valid_metadata:
        expected_fields.update(key.lower() for key in metadatum)
    with open(output_csv_file, "r") as in_file:
        for row in csv.DictReader(in_file):
            assert set(row.keys()) == expected_fields
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.2.1-dev1" # pragma: no cover
|
||||
__version__ = "0.2.1-dev2" # pragma: no cover
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import Iterable, List, Dict, Optional, Union
|
||||
import io
|
||||
from typing import Generator, Iterable, List, Dict, Optional, Union
|
||||
import csv
|
||||
|
||||
from unstructured.documents.elements import Text
|
||||
|
||||
@ -6,15 +8,14 @@ from unstructured.documents.elements import Text
|
||||
PRODIGY_TYPE = List[Dict[str, Union[str, Dict[str, str]]]]
|
||||
|
||||
|
||||
def stage_for_prodigy(
|
||||
def _validate_prodigy_metadata(
|
||||
elements: List[Text],
|
||||
metadata: Optional[List[Dict[str, str]]] = None,
|
||||
) -> PRODIGY_TYPE:
|
||||
) -> Iterable[Dict[str, str]]:
|
||||
"""
|
||||
Converts the document to the format required for use with Prodigy.
|
||||
ref: https://prodi.gy/docs/api-loaders#input
|
||||
Returns validated metadata list for Prodigy bricks.
|
||||
Raises ValueError with error message if metadata is not valid.
|
||||
"""
|
||||
|
||||
validated_metadata: Iterable[Dict[str, str]]
|
||||
if metadata:
|
||||
if len(metadata) != len(elements):
|
||||
@ -33,6 +34,19 @@ def stage_for_prodigy(
|
||||
validated_metadata = metadata
|
||||
else:
|
||||
validated_metadata = [dict() for _ in elements]
|
||||
return validated_metadata
|
||||
|
||||
|
||||
def stage_for_prodigy(
|
||||
elements: List[Text],
|
||||
metadata: Optional[List[Dict[str, str]]] = None,
|
||||
) -> PRODIGY_TYPE:
|
||||
"""
|
||||
Converts the document to the JSON format required for use with Prodigy.
|
||||
ref: https://prodi.gy/docs/api-loaders#input
|
||||
"""
|
||||
|
||||
validated_metadata: Iterable[Dict[str, str]] = _validate_prodigy_metadata(elements, metadata)
|
||||
|
||||
prodigy_data: PRODIGY_TYPE = list()
|
||||
for element, metadatum in zip(elements, validated_metadata):
|
||||
@ -42,3 +56,36 @@ def stage_for_prodigy(
|
||||
prodigy_data.append(data)
|
||||
|
||||
return prodigy_data
|
||||
|
||||
|
||||
def stage_csv_for_prodigy(
    elements: List[Text],
    metadata: Optional[List[Dict[str, str]]] = None,
) -> str:
    """
    Converts the document to the CSV format required for use with Prodigy.
    ref: https://prodi.gy/docs/api-loaders#input
    """
    validated_metadata: Iterable[Dict[str, str]] = _validate_prodigy_metadata(elements, metadata)

    # "text" and "id" always lead the header; metadata keys (lowercased) follow.
    # NOTE: the extra columns come from a set, so their relative order is not
    # guaranteed across runs — same as the original implementation.
    extra_fields: set = set()
    for metadata_item in validated_metadata:
        extra_fields.update(key.lower() for key in metadata_item.keys())
    csv_fieldnames = ["text", "id"] + list(extra_fields)

    def _iter_rows() -> Generator[Dict[str, str], None, None]:
        # One CSV row per element; metadata keys normalized to lowercase so they
        # line up with the header. "id" is written only when the element has a
        # string id.
        for element, metadatum in zip(elements, validated_metadata):
            lowered = {key.lower(): value for key, value in metadatum.items()}
            row = dict(text=element.text, **lowered)
            if isinstance(element.id, str):
                row["id"] = element.id
            yield row

    with io.StringIO() as buffer:
        writer = csv.DictWriter(buffer, fieldnames=csv_fieldnames)
        writer.writeheader()
        writer.writerows(_iter_rows())
        return buffer.getvalue()
|
||||
|
Loading…
x
Reference in New Issue
Block a user