From 2170a2aae2ba3733e518df85fa0163b11394a91d Mon Sep 17 00:00:00 2001 From: asymness Date: Mon, 28 Nov 2022 14:41:48 +0000 Subject: [PATCH] feat: Implement Argilla staging brick (#81) * Add argilla to dependencies and run pip-compile * Implement Argilla staging brick and add unit tests * Update version and changelog * Update docs with description and usage for Argilla staging brick * Remove unused fixtures and fix typo in Argilla tests * add missing quote in docs * changelog tweak * doc tweaks Co-authored-by: Matt Robinson Co-authored-by: Matt Robinson --- CHANGELOG.md | 3 +- docs/source/bricks.rst | 26 +++++++++++ requirements/base.txt | 52 +++++++++++++++++++++- requirements/huggingface.txt | 53 ++++++++++++++++++++-- setup.py | 1 + test_unstructured/staging/test_argilla.py | 50 +++++++++++++++++++++ unstructured/__version__.py | 2 +- unstructured/staging/argilla.py | 54 +++++++++++++++++++++++ 8 files changed, 234 insertions(+), 7 deletions(-) create mode 100644 test_unstructured/staging/test_argilla.py create mode 100644 unstructured/staging/argilla.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 16ad1a47c..c345e3267 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ -## 0.3.0-dev3 +## 0.3.0-dev4 +* Implement staging brick for Argilla. * Removing the local PDF parsing code and any dependencies and tests. * Reorganizes the staging bricks in the unstructured.partition module * Allow entities to be passed into the Datasaur staging brick diff --git a/docs/source/bricks.rst b/docs/source/bricks.rst index 50ff65ea5..2d9ba3eda 100644 --- a/docs/source/bricks.rst +++ b/docs/source/bricks.rst @@ -868,6 +868,7 @@ Formats a list of ``Text`` elements as input to token based tasks in Datasaur. Example: .. code:: python + from unstructured.documents.elements import Text from unstructured.staging.datasaur import stage_for_datasaur @@ -887,9 +888,34 @@ list for any elements that do not have any entities. Example: .. 
code:: python + from unstructured.documents.elements import Text from unstructured.staging.datasaur import stage_for_datasaur elements = [Text("Hi my name is Matt.")] entities = [[{"text": "Matt", "type": "PER", "start_idx": 11, "end_idx": 15}]] datasaur_data = stage_for_datasaur(elements, entities) + + +``stage_for_argilla`` +-------------------------- + +Convert a list of ``Text`` elements to an `Argilla Dataset `_. +The type of Argilla dataset to be generated can be specified with the ``argilla_task`` parameter. Currently, only the ``text_classification`` +task type is supported. + + +Examples: + +.. code:: python + + import json + + from unstructured.documents.elements import Title, NarrativeText + from unstructured.staging.argilla import stage_for_argilla + + elements = [Title(text="Title"), NarrativeText(text="Narrative")] + metadata = [{"type": "title"}, {"type": "text"}] + + # The resulting Argilla dataset is ready to be used with Argilla for training or labelling, etc. + argilla_dataset = stage_for_argilla(elements, "text_classification", metadata=metadata) diff --git a/requirements/base.txt b/requirements/base.txt index f2ce772d1..1d890aecb 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -4,15 +4,65 @@ # # pip-compile --output-file=requirements/base.txt # +argilla==1.1.0 + # via unstructured (setup.py) +backoff==2.2.1 + # via argilla +certifi==2022.9.24 + # via httpx click==8.1.3 # via nltk +deprecated==1.2.13 + # via argilla +h11==0.9.0 + # via httpcore +httpcore==0.11.1 + # via httpx +httpx==0.15.5 + # via argilla +idna==3.4 + # via rfc3986 joblib==1.2.0 # via nltk lxml==4.9.1 # via unstructured (setup.py) +monotonic==1.6 + # via argilla nltk==3.7 # via unstructured (setup.py) +numpy==1.23.5 + # via + # argilla + # pandas +packaging==21.3 + # via argilla +pandas==1.5.2 + # via argilla +pydantic==1.10.2 + # via argilla +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via pandas +pytz==2022.6 + # via pandas regex==2022.10.31 # via 
nltk +rfc3986[idna2008]==1.5.0 + # via httpx +six==1.16.0 + # via python-dateutil +sniffio==1.3.0 + # via + # httpcore + # httpx tqdm==4.64.1 - # via nltk + # via + # argilla + # nltk +typing-extensions==4.4.0 + # via pydantic +wrapt==1.13.3 + # via + # argilla + # deprecated diff --git a/requirements/huggingface.txt b/requirements/huggingface.txt index eb64851a3..ae79894e8 100644 --- a/requirements/huggingface.txt +++ b/requirements/huggingface.txt @@ -4,34 +4,64 @@ # # pip-compile --extra=huggingface --output-file=requirements/huggingface.txt # +argilla==1.1.0 + # via unstructured (setup.py) +backoff==2.2.1 + # via argilla certifi==2022.9.24 - # via requests + # via + # httpx + # requests charset-normalizer==2.1.1 # via requests click==8.1.3 # via nltk +deprecated==1.2.13 + # via argilla filelock==3.8.0 # via # huggingface-hub # transformers +h11==0.9.0 + # via httpcore +httpcore==0.11.1 + # via httpx +httpx==0.15.5 + # via argilla huggingface-hub==0.10.1 # via transformers idna==3.4 - # via requests + # via + # requests + # rfc3986 joblib==1.2.0 # via nltk lxml==4.9.1 # via unstructured (setup.py) +monotonic==1.6 + # via argilla nltk==3.7 # via unstructured (setup.py) numpy==1.23.4 - # via transformers + # via + # argilla + # pandas + # transformers packaging==21.3 # via + # argilla # huggingface-hub # transformers +pandas==1.5.2 + # via argilla +pydantic==1.10.2 + # via argilla pyparsing==3.0.9 # via packaging +python-dateutil==2.8.2 + # via pandas +pytz==2022.6 + # via pandas pyyaml==6.0 # via # huggingface-hub @@ -44,16 +74,31 @@ requests==2.28.1 # via # huggingface-hub # transformers +rfc3986[idna2008]==1.5.0 + # via httpx +six==1.16.0 + # via python-dateutil +sniffio==1.3.0 + # via + # httpcore + # httpx tokenizers==0.13.2 # via transformers tqdm==4.64.1 # via + # argilla # huggingface-hub # nltk # transformers transformers==4.23.1 # via unstructured (setup.py) typing-extensions==4.4.0 - # via huggingface-hub + # via + # huggingface-hub + # pydantic 
@pytest.mark.parametrize(
    "task_name, dataset_type, extra_kwargs",
    [
        (
            "text_classification",
            rg.DatasetForTextClassification,
            {"metadata": [{"type": "text1"}, {"type": "text2"}]},
        ),
        (
            "text_classification",
            rg.DatasetForTextClassification,
            {},
        ),
    ],
)
def test_stage_for_argilla(elements, task_name, dataset_type, extra_kwargs):
    """Check that staging yields one correctly populated record per element.

    For every parametrized case the returned dataset must be an instance of
    ``dataset_type``, contain exactly as many records as there are elements,
    and each record must carry the text, id and per-element keyword values of
    the element at the same position.
    """
    argilla_dataset = argilla.stage_for_argilla(elements, task_name, **extra_kwargs)
    assert isinstance(argilla_dataset, dataset_type)

    records = list(argilla_dataset)
    # zip() stops at the shorter input, so a dataset with missing records would
    # otherwise pass unnoticed; pin the count before pairing.
    assert len(records) == len(elements)

    for idx, (record, element) in enumerate(zip(records, elements)):
        assert record.text == element.text
        assert record.id == element.id
        for kwarg, values in extra_kwargs.items():
            # Compare against this element's own value rather than mere
            # membership in the list, which would accept swapped entries.
            assert getattr(record, kwarg) == values[idx]
def stage_for_argilla(
    elements: List[Text],
    argilla_task: str,
    **record_kwargs,
) -> Union[
    argilla.DatasetForTextClassification,
    argilla.DatasetForTokenClassification,
    argilla.DatasetForText2Text,
]:
    """Convert a list of ``Text`` elements into an Argilla dataset.

    Args:
        elements: Elements to convert. Each one contributes its ``text`` to the
            generated record and, when its ``id`` is a string, its ``id`` too.
        argilla_task: Name of the Argilla task. Recognized names are
            ``"text_classification"``, ``"token_classification"`` and
            ``"text2text"``; only ``"text_classification"`` is implemented.
        **record_kwargs: Extra per-record fields. Every value must be a list
            with exactly one entry per element; entry ``i`` is forwarded to the
            record built for ``elements[i]`` under the same keyword.

    Returns:
        An instance of the dataset class registered for ``argilla_task``,
        built from one record per element.

    Raises:
        ValueError: If ``argilla_task`` is not a recognized task name, or if a
            ``record_kwargs`` value is not a list whose length matches
            ``elements``.
        NotImplementedError: If ``argilla_task`` is ``"token_classification"``
            or ``"text2text"``.
    """
    ARGILLA_TASKS = {
        "text_classification": (TextClassificationRecord, argilla.DatasetForTextClassification),
        "token_classification": (TokenClassificationRecord, argilla.DatasetForTokenClassification),
        "text2text": (Text2TextRecord, argilla.DatasetForText2Text),
    }

    try:
        argilla_record_class, argilla_dataset_class = ARGILLA_TASKS[argilla_task]
    except KeyError as e:
        # Both halves of the message must be interpolated; previously the
        # second half was a plain string and printed the raw "{...}" template.
        valid_tasks = ", ".join(ARGILLA_TASKS)
        raise ValueError(
            f'Invalid value "{e.args[0]}" specified for argilla_task. '
            f"Must be one of: {valid_tasks}."
        ) from e

    if argilla_task in {"token_classification", "text2text"}:
        raise NotImplementedError()  # TODO: Implement token_classification and text2text tasks

    # Validate every keyword up front so no record is built from misaligned data.
    for record_kwarg_key, record_kwarg_value in record_kwargs.items():
        if not isinstance(record_kwarg_value, list) or len(record_kwarg_value) != len(elements):
            raise ValueError(
                f'Invalid value specified for "{record_kwarg_key}" keyword argument.'
                " Must be of type list and same length as elements list."
            )

    results: List[Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord]] = []

    for idx, element in enumerate(elements):
        # Pick the idx-th entry of each keyword list for this element.
        element_kwargs = {kwarg: values[idx] for kwarg, values in record_kwargs.items()}
        arguments = dict(**element_kwargs, text=element.text)
        # Only forward ids that are strings; otherwise the record is built
        # without an explicit id.
        if isinstance(element.id, str):
            arguments["id"] = element.id

        results.append(argilla_record_class(**arguments))

    return argilla_dataset_class(results)