feat: Implement stage_csv_for_prodigy brick (#13)

* Refactor metadata validation and implement stage_csv_for_prodigy brick

* Refactor unit tests for metadata validation and add tests for Prodigy CSV brick

* Add stage_csv_for_prodigy description and example in docs

* Bump version and update changelog

* Add _csv_ to function name

* Update changelog line to 0.2.1-dev2

Co-authored-by: Matt Robinson <mrobinson@unstructuredai.io>
asymness 2022-10-03 18:30:30 +05:00 committed by GitHub
parent 90d4f40da8
commit d429e9b305
5 changed files with 137 additions and 26 deletions

CHANGELOG.md

@@ -1,5 +1,6 @@
## 0.2.1-dev1
## 0.2.1-dev2
* Added staging brick for CSV format for Prodigy
* Added staging brick for Prodigy
* Added text_field and id_field to stage_for_label_studio signature

docs/source/bricks.rst

@@ -364,7 +364,7 @@ Examples:
``stage_for_prodigy``
--------------------------
Formats outputs for use with `Prodigy <https://prodi.gy/docs/api-loaders>`_. After running ``stage_for_prodigy``, you can
Formats outputs in JSON format for use with `Prodigy <https://prodi.gy/docs/api-loaders>`_. After running ``stage_for_prodigy``, you can
write the results to a JSON file that is ready to be used with Prodigy.
Examples:
@@ -383,3 +383,25 @@ Examples:
  # The resulting JSON file is ready to be used with Prodigy
  with open("prodigy.json", "w") as f:
      json.dump(prodigy_data, f, indent=4)

``stage_csv_for_prodigy``
--------------------------
Formats outputs in CSV format for use with `Prodigy <https://prodi.gy/docs/api-loaders>`_. After running ``stage_csv_for_prodigy``, you can
write the results to a CSV file that is ready to be used with Prodigy.

Examples:

.. code:: python

  from unstructured.documents.elements import Title, NarrativeText
  from unstructured.staging.prodigy import stage_csv_for_prodigy

  elements = [Title(text="Title"), NarrativeText(text="Narrative")]
  metadata = [{"type": "title"}, {"source": "news"}]
  prodigy_csv_data = stage_csv_for_prodigy(elements, metadata)

  # The resulting CSV file is ready to be used with Prodigy
  with open("prodigy.csv", "w") as csv_file:
      csv_file.write(prodigy_csv_data)

test_unstructured/staging/test_prodigy.py

@@ -1,4 +1,6 @@
import pytest
import csv
import os
import unstructured.staging.prodigy as prodigy
from unstructured.documents.elements import Title, NarrativeText
@@ -24,6 +26,41 @@ def metadata_with_invalid_length():
    return [{"score": 0.1}, {"category": "paragraph"}, {"type": "text"}]


@pytest.fixture
def output_csv_file(tmp_path):
    return os.path.join(tmp_path, "prodigy_data.csv")


def test_validate_prodigy_metadata(elements):
    validated_metadata = prodigy._validate_prodigy_metadata(elements, metadata=None)
    assert len(validated_metadata) == len(elements)
    assert all(not data for data in validated_metadata)


def test_validate_prodigy_metadata_with_valid_metadata(elements, valid_metadata):
    validated_metadata = prodigy._validate_prodigy_metadata(elements, metadata=valid_metadata)
    assert len(validated_metadata) == len(elements)


@pytest.mark.parametrize(
    "invalid_metadata_fixture, exception_message",
    [
        ("metadata_with_id", 'The key "id" is not allowed with metadata parameter at index: 1'),
        (
            "metadata_with_invalid_length",
            "The length of metadata parameter does not match with length of elements parameter.",
        ),
    ],
)
def test_validate_prodigy_metadata_with_invalid_metadata(
    elements, invalid_metadata_fixture, exception_message, request
):
    invalid_metadata = request.getfixturevalue(invalid_metadata_fixture)
    with pytest.raises(ValueError) as validation_exception:
        prodigy._validate_prodigy_metadata(elements, invalid_metadata)
    assert str(validation_exception.value) == exception_message


def test_convert_to_prodigy_data(elements):
    prodigy_data = prodigy.stage_for_prodigy(elements)
@@ -54,20 +91,24 @@ def test_convert_to_prodigy_data_with_valid_metadata(elements, valid_metadata):
    assert prodigy_data[1]["meta"] == {"id": elements[1].id, **valid_metadata[1]}


@pytest.mark.parametrize(
    "invalid_metadata_fixture, exception_message",
    [
        ("metadata_with_id", 'The key "id" is not allowed with metadata parameter at index: 1'),
        (
            "metadata_with_invalid_length",
            "The length of metadata parameter does not match with length of elements parameter.",
        ),
    ],
)
def test_convert_to_prodigy_data_with_invalid_metadata(
    elements, invalid_metadata_fixture, exception_message, request
):
    invalid_metadata = request.getfixturevalue(invalid_metadata_fixture)
    with pytest.raises(ValueError) as validation_exception:
        prodigy.stage_for_prodigy(elements, invalid_metadata)
    assert str(validation_exception.value) == exception_message


def test_stage_csv_for_prodigy(elements, output_csv_file):
    with open(output_csv_file, "w+") as csv_file:
        prodigy_csv_string = prodigy.stage_csv_for_prodigy(elements)
        csv_file.write(prodigy_csv_string)

    fieldnames = ["text", "id"]
    with open(output_csv_file, "r") as csv_file:
        csv_rows = csv.DictReader(csv_file)
        assert all(set(row.keys()) == set(fieldnames) for row in csv_rows)


def test_stage_csv_for_prodigy_with_metadata(elements, valid_metadata, output_csv_file):
    with open(output_csv_file, "w+") as csv_file:
        prodigy_csv_string = prodigy.stage_csv_for_prodigy(elements, valid_metadata)
        csv_file.write(prodigy_csv_string)

    fieldnames = set(["text", "id"]).union(*(data.keys() for data in valid_metadata))
    fieldnames = [fieldname.lower() for fieldname in fieldnames]
    with open(output_csv_file, "r") as csv_file:
        csv_rows = csv.DictReader(csv_file)
        assert all(set(row.keys()) == set(fieldnames) for row in csv_rows)
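
A small sketch for running just the new CSV tests from Python; ``pytest.main`` is standard pytest API, and the ``-k`` expression simply matches the test names added above (running plain ``pytest`` from the repository root covers them as well).

.. code:: python

  import pytest

  # Run only the stage_csv_for_prodigy tests added in this commit.
  # An exit code of 0 means all selected tests passed.
  exit_code = pytest.main(["-q", "-k", "stage_csv_for_prodigy"])
  assert exit_code == 0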

unstructured/__version__.py

@@ -1 +1 @@
__version__ = "0.2.1-dev1" # pragma: no cover
__version__ = "0.2.1-dev2" # pragma: no cover

unstructured/staging/prodigy.py

@@ -1,4 +1,6 @@
from typing import Iterable, List, Dict, Optional, Union
import io
from typing import Generator, Iterable, List, Dict, Optional, Union
import csv
from unstructured.documents.elements import Text
@@ -6,15 +8,14 @@ from unstructured.documents.elements import Text
PRODIGY_TYPE = List[Dict[str, Union[str, Dict[str, str]]]]


def stage_for_prodigy(
def _validate_prodigy_metadata(
    elements: List[Text],
    metadata: Optional[List[Dict[str, str]]] = None,
) -> PRODIGY_TYPE:
) -> Iterable[Dict[str, str]]:
    """
    Converts the document to the format required for use with Prodigy.
    ref: https://prodi.gy/docs/api-loaders#input
    Returns validated metadata list for Prodigy bricks.
    Raises ValueError with error message if metadata is not valid.
    """
    validated_metadata: Iterable[Dict[str, str]]
    if metadata:
        if len(metadata) != len(elements):
@@ -33,6 +34,19 @@ def stage_for_prodigy(
        validated_metadata = metadata
    else:
        validated_metadata = [dict() for _ in elements]
    return validated_metadata


def stage_for_prodigy(
    elements: List[Text],
    metadata: Optional[List[Dict[str, str]]] = None,
) -> PRODIGY_TYPE:
    """
    Converts the document to the JSON format required for use with Prodigy.
    ref: https://prodi.gy/docs/api-loaders#input
    """
    validated_metadata: Iterable[Dict[str, str]] = _validate_prodigy_metadata(elements, metadata)
    prodigy_data: PRODIGY_TYPE = list()
    for element, metadatum in zip(elements, validated_metadata):
@@ -42,3 +56,36 @@ def stage_for_prodigy(
        prodigy_data.append(data)
    return prodigy_data


def stage_csv_for_prodigy(
    elements: List[Text],
    metadata: Optional[List[Dict[str, str]]] = None,
) -> str:
    """
    Converts the document to the CSV format required for use with Prodigy.
    ref: https://prodi.gy/docs/api-loaders#input
    """
    validated_metadata: Iterable[Dict[str, str]] = _validate_prodigy_metadata(elements, metadata)

    csv_fieldnames = ["text", "id"]
    csv_fieldnames += list(
        set().union(
            *((key.lower() for key in metadata_item.keys()) for metadata_item in validated_metadata)
        )
    )

    def _get_rows() -> Generator[Dict[str, str], None, None]:
        for element, metadatum in zip(elements, validated_metadata):
            metadatum = {key.lower(): value for key, value in metadatum.items()}
            row_data = dict(text=element.text, **metadatum)
            if isinstance(element.id, str):
                row_data["id"] = element.id
            yield row_data

    with io.StringIO() as buffer:
        csv_writer = csv.DictWriter(buffer, fieldnames=csv_fieldnames)
        csv_writer.writeheader()
        csv_rows = _get_rows()
        csv_writer.writerows(csv_rows)
        return buffer.getvalue()
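
A minimal round-trip sketch of the new brick, mirroring the docs example and the ``csv.DictReader`` check in the tests above; it is not part of the commit. The mixed-case metadata key is only there to illustrate the key lowercasing done inside ``stage_csv_for_prodigy``, and the column order after ``text`` and ``id`` is not guaranteed because the fieldnames come from a set.

.. code:: python

  import csv
  import io

  from unstructured.documents.elements import Title, NarrativeText
  from unstructured.staging.prodigy import stage_csv_for_prodigy

  elements = [Title(text="Title"), NarrativeText(text="Narrative")]
  # "Type" is deliberately mixed-case: metadata keys are lowercased,
  # so the resulting CSV column is named "type".
  metadata = [{"Type": "title"}, {"source": "news"}]

  prodigy_csv_data = stage_csv_for_prodigy(elements, metadata)

  # Read the staged string back; the header is "text", "id", plus the
  # lowercased metadata keys.
  reader = csv.DictReader(io.StringIO(prodigy_csv_data))
  print(reader.fieldnames)
  for row in reader:
      print(row["text"], row.get("type", ""), row.get("source", ""))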