diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2addd26cd..7fa5c51c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,6 @@
-## 0.2.1-dev1
+## 0.2.1-dev2
 
+* Added staging brick for CSV format for Prodigy
 * Added staging brick for Prodigy
 * Added text_field and id_field to stage_for_label_studio signature
diff --git a/docs/source/bricks.rst b/docs/source/bricks.rst
index bb18a6b38..fe2477299 100644
--- a/docs/source/bricks.rst
+++ b/docs/source/bricks.rst
@@ -364,7 +364,7 @@ Examples:
 ``stage_for_prodigy``
 --------------------------
 
-Formats outputs for use with `Prodigy `_. After running ``stage_for_prodigy``, you can
+Formats outputs in JSON format for use with `Prodigy `_. After running ``stage_for_prodigy``, you can
 write the results to a JSON file that is ready to be used with Prodigy.
 
 Examples:
@@ -383,3 +383,25 @@ Examples:
   # The resulting JSON file is ready to be used with Prodigy
   with open("prodigy.json", "w") as f:
       json.dump(prodigy_data, f, indent=4)
+
+
+``stage_csv_for_prodigy``
+--------------------------
+
+Formats outputs in CSV format for use with `Prodigy `_. After running ``stage_csv_for_prodigy``, you can
+write the results to a CSV file that is ready to be used with Prodigy.
+
+Examples:
+
+.. code:: python
+
+  from unstructured.documents.elements import Title, NarrativeText
+  from unstructured.staging.prodigy import stage_csv_for_prodigy
+
+  elements = [Title(text="Title"), NarrativeText(text="Narrative")]
+  metadata = [{"type": "title"}, {"source": "news"}]
+  prodigy_csv_data = stage_csv_for_prodigy(elements, metadata)
+
+  # The resulting CSV file is ready to be used with Prodigy
+  with open("prodigy.csv", "w") as csv_file:
+      csv_file.write(prodigy_csv_data)
diff --git a/test_unstructured/staging/test_prodigy.py b/test_unstructured/staging/test_prodigy.py
index a7db379c3..ba4d1e155 100644
--- a/test_unstructured/staging/test_prodigy.py
+++ b/test_unstructured/staging/test_prodigy.py
@@ -1,4 +1,6 @@
 import pytest
+import csv
+import os
 
 import unstructured.staging.prodigy as prodigy
 from unstructured.documents.elements import Title, NarrativeText
@@ -24,6 +26,41 @@ def metadata_with_invalid_length():
     return [{"score": 0.1}, {"category": "paragraph"}, {"type": "text"}]
 
 
+@pytest.fixture
+def output_csv_file(tmp_path):
+    return os.path.join(tmp_path, "prodigy_data.csv")
+
+
+def test_validate_prodigy_metadata(elements):
+    validated_metadata = prodigy._validate_prodigy_metadata(elements, metadata=None)
+    assert len(validated_metadata) == len(elements)
+    assert all(not data for data in validated_metadata)
+
+
+def test_validate_prodigy_metadata_with_valid_metadata(elements, valid_metadata):
+    validated_metadata = prodigy._validate_prodigy_metadata(elements, metadata=valid_metadata)
+    assert len(validated_metadata) == len(elements)
+
+
+@pytest.mark.parametrize(
+    "invalid_metadata_fixture, exception_message",
+    [
+        ("metadata_with_id", 'The key "id" is not allowed with metadata parameter at index: 1'),
+        (
+            "metadata_with_invalid_length",
+            "The length of metadata parameter does not match with length of elements parameter.",
+        ),
+    ],
+)
+def test_validate_prodigy_metadata_with_invalid_metadata(
+    elements, invalid_metadata_fixture, exception_message, request
+):
+    invalid_metadata = request.getfixturevalue(invalid_metadata_fixture)
+    with pytest.raises(ValueError) as validation_exception:
+        prodigy._validate_prodigy_metadata(elements, invalid_metadata)
+    assert str(validation_exception.value) == exception_message
+
+
 def test_convert_to_prodigy_data(elements):
     prodigy_data = prodigy.stage_for_prodigy(elements)
@@ -54,20 +91,24 @@ def test_convert_to_prodigy_data_with_valid_metadata(elements, valid_metadata):
     assert prodigy_data[1]["meta"] == {"id": elements[1].id, **valid_metadata[1]}
 
 
-@pytest.mark.parametrize(
-    "invalid_metadata_fixture, exception_message",
-    [
-        ("metadata_with_id", 'The key "id" is not allowed with metadata parameter at index: 1'),
-        (
-            "metadata_with_invalid_length",
-            "The length of metadata parameter does not match with length of elements parameter.",
-        ),
-    ],
-)
-def test_convert_to_prodigy_data_with_invalid_metadata(
-    elements, invalid_metadata_fixture, exception_message, request
-):
-    invalid_metadata = request.getfixturevalue(invalid_metadata_fixture)
-    with pytest.raises(ValueError) as validation_exception:
-        prodigy.stage_for_prodigy(elements, invalid_metadata)
-    assert str(validation_exception.value) == exception_message
+def test_stage_csv_for_prodigy(elements, output_csv_file):
+    with open(output_csv_file, "w+") as csv_file:
+        prodigy_csv_string = prodigy.stage_csv_for_prodigy(elements)
+        csv_file.write(prodigy_csv_string)
+
+    fieldnames = ["text", "id"]
+    with open(output_csv_file, "r") as csv_file:
+        csv_rows = csv.DictReader(csv_file)
+        assert all(set(row.keys()) == set(fieldnames) for row in csv_rows)
+
+
+def test_stage_csv_for_prodigy_with_metadata(elements, valid_metadata, output_csv_file):
+    with open(output_csv_file, "w+") as csv_file:
+        prodigy_csv_string = prodigy.stage_csv_for_prodigy(elements, valid_metadata)
+        csv_file.write(prodigy_csv_string)
+
+    fieldnames = set(["text", "id"]).union(*(data.keys() for data in valid_metadata))
+    fieldnames = [fieldname.lower() for fieldname in fieldnames]
+    with open(output_csv_file, "r") as csv_file:
+        csv_rows = csv.DictReader(csv_file)
+        assert all(set(row.keys()) == set(fieldnames) for row in csv_rows)
diff --git a/unstructured/__version__.py b/unstructured/__version__.py
index 6aa693e8d..20b14eae9 100644
--- a/unstructured/__version__.py
+++ b/unstructured/__version__.py
@@ -1 +1 @@
-__version__ = "0.2.1-dev1" # pragma: no cover
+__version__ = "0.2.1-dev2" # pragma: no cover
diff --git a/unstructured/staging/prodigy.py b/unstructured/staging/prodigy.py
index b7eaf8f02..fe1de1211 100644
--- a/unstructured/staging/prodigy.py
+++ b/unstructured/staging/prodigy.py
@@ -1,4 +1,6 @@
-from typing import Iterable, List, Dict, Optional, Union
+import io
+from typing import Generator, Iterable, List, Dict, Optional, Union
+import csv
 
 from unstructured.documents.elements import Text
@@ -6,15 +8,14 @@ from unstructured.documents.elements import Text
 PRODIGY_TYPE = List[Dict[str, Union[str, Dict[str, str]]]]
 
 
-def stage_for_prodigy(
+def _validate_prodigy_metadata(
     elements: List[Text],
     metadata: Optional[List[Dict[str, str]]] = None,
-) -> PRODIGY_TYPE:
+) -> Iterable[Dict[str, str]]:
     """
-    Converts the document to the format required for use with Prodigy.
-    ref: https://prodi.gy/docs/api-loaders#input
+    Returns validated metadata list for Prodigy bricks.
+    Raises ValueError with error message if metadata is not valid.
     """
-    validated_metadata: Iterable[Dict[str, str]]
 
     if metadata:
         if len(metadata) != len(elements):
@@ -33,6 +34,19 @@ def stage_for_prodigy(
         validated_metadata = metadata
     else:
         validated_metadata = [dict() for _ in elements]
+    return validated_metadata
+
+
+def stage_for_prodigy(
+    elements: List[Text],
+    metadata: Optional[List[Dict[str, str]]] = None,
+) -> PRODIGY_TYPE:
+    """
+    Converts the document to the JSON format required for use with Prodigy.
+    ref: https://prodi.gy/docs/api-loaders#input
+    """
+
+    validated_metadata: Iterable[Dict[str, str]] = _validate_prodigy_metadata(elements, metadata)
 
     prodigy_data: PRODIGY_TYPE = list()
     for element, metadatum in zip(elements, validated_metadata):
@@ -42,3 +56,36 @@ def stage_for_prodigy(
         prodigy_data.append(data)
 
     return prodigy_data
+
+
+def stage_csv_for_prodigy(
+    elements: List[Text],
+    metadata: Optional[List[Dict[str, str]]] = None,
+) -> str:
+    """
+    Converts the document to the CSV format required for use with Prodigy.
+    ref: https://prodi.gy/docs/api-loaders#input
+    """
+    validated_metadata: Iterable[Dict[str, str]] = _validate_prodigy_metadata(elements, metadata)
+
+    csv_fieldnames = ["text", "id"]
+    csv_fieldnames += list(
+        set().union(
+            *((key.lower() for key in metadata_item.keys()) for metadata_item in validated_metadata)
+        )
+    )
+
+    def _get_rows() -> Generator[Dict[str, str], None, None]:
+        for element, metadatum in zip(elements, validated_metadata):
+            metadatum = {key.lower(): value for key, value in metadatum.items()}
+            row_data = dict(text=element.text, **metadatum)
+            if isinstance(element.id, str):
+                row_data["id"] = element.id
+            yield row_data
+
+    with io.StringIO() as buffer:
+        csv_writer = csv.DictWriter(buffer, fieldnames=csv_fieldnames)
+        csv_writer.writeheader()
+        csv_rows = _get_rows()
+        csv_writer.writerows(csv_rows)
+        return buffer.getvalue()
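As a quick end-to-end sketch of the new brick (illustration only, reusing the elements and metadata from the docs example above): stage the elements, then read the returned CSV string back with csv.DictReader. The order of the metadata-derived columns after "text" and "id" is not guaranteed, since it comes from a set union over the metadata keys.

    import csv
    import io

    from unstructured.documents.elements import NarrativeText, Title
    from unstructured.staging.prodigy import stage_csv_for_prodigy

    elements = [Title(text="Title"), NarrativeText(text="Narrative")]
    metadata = [{"type": "title"}, {"source": "news"}]

    # stage_csv_for_prodigy returns the whole CSV document as a single string
    prodigy_csv_data = stage_csv_for_prodigy(elements, metadata)

    # Reading the string back shows the layout: every row has a "text" and "id"
    # column plus one lower-cased column per metadata key. A cell is blank when
    # the row has no value for that key, and "id" is blank when the element has
    # no string id.
    for row in csv.DictReader(io.StringIO(prodigy_csv_data)):
        print(row["text"], row["id"], row["type"], row["source"])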