feat: Implement Argilla staging brick (#81)

* Add argilla to dependencies and run pip-compile

* Implement Argilla staging brick and add unit tests

* Update version and changelog

* Update docs with description and usage for Argilla staging brick

* Remove unused fixtures and fix typo in Argilla tests

* add missing quote in docs

* changelog tweak

* doc tweaks

Co-authored-by: Matt Robinson <mrobinson@unstructuredai.io>
Co-authored-by: Matt Robinson <mrobinson@unstructured.io>
This commit is contained in:
asymness 2022-11-28 14:41:48 +00:00 committed by GitHub
parent d6623883dc
commit 2170a2aae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 234 additions and 7 deletions

View File

@ -1,5 +1,6 @@
## 0.3.0-dev3
## 0.3.0-dev4
* Implement staging brick for Argilla.
* Removing the local PDF parsing code and any dependencies and tests.
* Reorganizes the staging bricks in the unstructured.partition module
* Allow entities to be passed into the Datasaur staging brick

View File

@ -868,6 +868,7 @@ Formats a list of ``Text`` elements as input to token based tasks in Datasaur.
Example:
.. code:: python
from unstructured.documents.elements import Text
from unstructured.staging.datasaur import stage_for_datasaur
@ -887,9 +888,34 @@ list for any elements that do not have any entities.
Example:
.. code:: python
from unstructured.documents.elements import Text
from unstructured.staging.datasaur import stage_for_datasaur
elements = [Text("Hi my name is Matt.")]
entities = [[{"text": "Matt", "type": "PER", "start_idx": 11, "end_idx": 15}]]
datasaur_data = stage_for_datasaur(elements, entities)
``stage_for_argilla``
--------------------------
Convert a list of ``Text`` elements to an `Argilla Dataset <https://docs.argilla.io/en/latest/reference/python/python_client.html#python-ref-datasets>`_.
The type of Argilla dataset to be generated can be specified with `argilla_task` parameter. Currently, only ``text_classification``
task type is supported.
Examples:
.. code:: python
import json
from unstructured.documents.elements import Title, NarrativeText
from unstructured.staging.argilla import stage_for_argilla
elements = [Title(text="Title"), NarrativeText(text="Narrative")]
metadata = [{"type": "title"}, {"type": "text"}]
# The resulting Argilla dataset is ready to be used with Argilla for training or labelling, etc.
argilla_dataset = stage_for_argilla(elements, "text_classification", metadata=metadata)

View File

@ -4,15 +4,65 @@
#
# pip-compile --output-file=requirements/base.txt
#
argilla==1.1.0
# via unstructured (setup.py)
backoff==2.2.1
# via argilla
certifi==2022.9.24
# via httpx
click==8.1.3
# via nltk
deprecated==1.2.13
# via argilla
h11==0.9.0
# via httpcore
httpcore==0.11.1
# via httpx
httpx==0.15.5
# via argilla
idna==3.4
# via rfc3986
joblib==1.2.0
# via nltk
lxml==4.9.1
# via unstructured (setup.py)
monotonic==1.6
# via argilla
nltk==3.7
# via unstructured (setup.py)
numpy==1.23.5
# via
# argilla
# pandas
packaging==21.3
# via argilla
pandas==1.5.2
# via argilla
pydantic==1.10.2
# via argilla
pyparsing==3.0.9
# via packaging
python-dateutil==2.8.2
# via pandas
pytz==2022.6
# via pandas
regex==2022.10.31
# via nltk
rfc3986[idna2008]==1.5.0
# via httpx
six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via
# httpcore
# httpx
tqdm==4.64.1
# via nltk
# via
# argilla
# nltk
typing-extensions==4.4.0
# via pydantic
wrapt==1.13.3
# via
# argilla
# deprecated

View File

@ -4,34 +4,64 @@
#
# pip-compile --extra=huggingface --output-file=requirements/huggingface.txt
#
argilla==1.1.0
# via unstructured (setup.py)
backoff==2.2.1
# via argilla
certifi==2022.9.24
# via requests
# via
# httpx
# requests
charset-normalizer==2.1.1
# via requests
click==8.1.3
# via nltk
deprecated==1.2.13
# via argilla
filelock==3.8.0
# via
# huggingface-hub
# transformers
h11==0.9.0
# via httpcore
httpcore==0.11.1
# via httpx
httpx==0.15.5
# via argilla
huggingface-hub==0.10.1
# via transformers
idna==3.4
# via requests
# via
# requests
# rfc3986
joblib==1.2.0
# via nltk
lxml==4.9.1
# via unstructured (setup.py)
monotonic==1.6
# via argilla
nltk==3.7
# via unstructured (setup.py)
numpy==1.23.4
# via transformers
# via
# argilla
# pandas
# transformers
packaging==21.3
# via
# argilla
# huggingface-hub
# transformers
pandas==1.5.2
# via argilla
pydantic==1.10.2
# via argilla
pyparsing==3.0.9
# via packaging
python-dateutil==2.8.2
# via pandas
pytz==2022.6
# via pandas
pyyaml==6.0
# via
# huggingface-hub
@ -44,16 +74,31 @@ requests==2.28.1
# via
# huggingface-hub
# transformers
rfc3986[idna2008]==1.5.0
# via httpx
six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via
# httpcore
# httpx
tokenizers==0.13.2
# via transformers
tqdm==4.64.1
# via
# argilla
# huggingface-hub
# nltk
# transformers
transformers==4.23.1
# via unstructured (setup.py)
typing-extensions==4.4.0
# via huggingface-hub
# via
# huggingface-hub
# pydantic
urllib3==1.26.12
# via requests
wrapt==1.13.3
# via
# argilla
# deprecated

View File

@ -50,6 +50,7 @@ setup(
install_requires=[
"lxml",
"nltk",
"argilla",
],
extras_require={
"huggingface": ["transformers"],

View File

@ -0,0 +1,50 @@
import pytest
import argilla as rg
import unstructured.staging.argilla as argilla
from unstructured.documents.elements import Title, NarrativeText
@pytest.fixture
def elements():
return [Title(text="example"), NarrativeText(text="another example")]
@pytest.mark.parametrize(
"task_name, dataset_type, extra_kwargs",
[
(
"text_classification",
rg.DatasetForTextClassification,
{"metadata": [{"type": "text1"}, {"type": "text2"}]},
),
(
"text_classification",
rg.DatasetForTextClassification,
{},
),
],
)
def test_stage_for_argilla(elements, task_name, dataset_type, extra_kwargs):
argilla_dataset = argilla.stage_for_argilla(elements, task_name, **extra_kwargs)
assert isinstance(argilla_dataset, dataset_type)
for record, element in zip(argilla_dataset, elements):
assert record.text == element.text
assert record.id == element.id
for kwarg in extra_kwargs:
assert getattr(record, kwarg) in extra_kwargs[kwarg]
@pytest.mark.parametrize(
"task_name, error, error_message, extra_kwargs",
[
("unknown_task", ValueError, "invalid value", {}),
("token_classification", NotImplementedError, None, {}),
("text2text", NotImplementedError, None, {}),
("text_classification", ValueError, "invalid value", {"metadata": "invalid metadata"}),
],
)
def test_invalid_stage_for_argilla(elements, task_name, error, error_message, extra_kwargs):
with pytest.raises(error) as e:
argilla.stage_for_argilla(elements, task_name, **extra_kwargs)
assert error_message in e.args[0].lower() if error_message else True

View File

@ -1 +1 @@
__version__ = "0.3.0-dev3" # pragma: no cover
__version__ = "0.3.0-dev4" # pragma: no cover

View File

@ -0,0 +1,54 @@
from typing import List, Union
from unstructured.documents.elements import Text
import argilla
from argilla.client.models import (
TextClassificationRecord,
TokenClassificationRecord,
Text2TextRecord,
)
def stage_for_argilla(
elements: List[Text],
argilla_task: str,
**record_kwargs,
) -> Union[
argilla.DatasetForTextClassification,
argilla.DatasetForTokenClassification,
argilla.DatasetForText2Text,
]:
ARGILLA_TASKS = {
"text_classification": (TextClassificationRecord, argilla.DatasetForTextClassification),
"token_classification": (TokenClassificationRecord, argilla.DatasetForTokenClassification),
"text2text": (Text2TextRecord, argilla.DatasetForText2Text),
}
try:
argilla_record_class, argilla_dataset_class = ARGILLA_TASKS[argilla_task]
except KeyError as e:
raise ValueError(
f'Invalid value "{e.args[0]}" specified for argilla_task. '
"Must be one of: {', '.join(ARGILLA_TASKS.keys())}."
)
if argilla_task in {"token_classification", "text2text"}:
raise NotImplementedError() # TODO: Implement token_classification and text2text tasks
for record_kwarg_key, record_kwarg_value in record_kwargs.items():
if type(record_kwarg_value) is not list or len(record_kwarg_value) != len(elements):
raise ValueError(
f'Invalid value specified for "{record_kwarg_key}" keyword argument.'
" Must be of type list and same length as elements list."
)
results: List[Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord]] = []
for idx, element in enumerate(elements):
element_kwargs = {kwarg: record_kwargs[kwarg][idx] for kwarg in record_kwargs}
arguments = dict(**element_kwargs, text=element.text)
if isinstance(element.id, str):
arguments["id"] = element.id
results.append(argilla_record_class(**arguments))
return argilla_dataset_class(results)