mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-12-24 13:44:05 +00:00
feat: Implement Argilla staging brick (#81)
* Add argilla to dependencies and run pip-compile * Implement Argilla staging brick and add unit tests * Update version and changelog * Update docs with description and usage for Argilla staging brick * Remove unused fixtures and fix typo in Argilla tests * add missing quote in docs * changelog tweak * doc tweaks Co-authored-by: Matt Robinson <mrobinson@unstructuredai.io> Co-authored-by: Matt Robinson <mrobinson@unstructured.io>
This commit is contained in:
parent
d6623883dc
commit
2170a2aae2
@ -1,5 +1,6 @@
|
||||
## 0.3.0-dev3
|
||||
## 0.3.0-dev4
|
||||
|
||||
* Implement staging brick for Argilla.
|
||||
* Removing the local PDF parsing code and any dependencies and tests.
|
||||
* Reorganizes the staging bricks in the unstructured.partition module
|
||||
* Allow entities to be passed into the Datasaur staging brick
|
||||
|
||||
@ -868,6 +868,7 @@ Formats a list of ``Text`` elements as input to token based tasks in Datasaur.
|
||||
Example:
|
||||
|
||||
.. code:: python
|
||||
|
||||
from unstructured.documents.elements import Text
|
||||
from unstructured.staging.datasaur import stage_for_datasaur
|
||||
|
||||
@ -887,9 +888,34 @@ list for any elements that do not have any entities.
|
||||
Example:
|
||||
|
||||
.. code:: python
|
||||
|
||||
from unstructured.documents.elements import Text
|
||||
from unstructured.staging.datasaur import stage_for_datasaur
|
||||
|
||||
elements = [Text("Hi my name is Matt.")]
|
||||
entities = [[{"text": "Matt", "type": "PER", "start_idx": 11, "end_idx": 15}]]
|
||||
datasaur_data = stage_for_datasaur(elements, entities)
|
||||
|
||||
|
||||
``stage_for_argilla``
|
||||
--------------------------
|
||||
|
||||
Convert a list of ``Text`` elements to an `Argilla Dataset <https://docs.argilla.io/en/latest/reference/python/python_client.html#python-ref-datasets>`_.
|
||||
The type of Argilla dataset to be generated can be specified with the ``argilla_task`` parameter. Currently, only the ``text_classification``
|
||||
task type is supported.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import json
|
||||
|
||||
from unstructured.documents.elements import Title, NarrativeText
|
||||
from unstructured.staging.argilla import stage_for_argilla
|
||||
|
||||
elements = [Title(text="Title"), NarrativeText(text="Narrative")]
|
||||
metadata = [{"type": "title"}, {"type": "text"}]
|
||||
|
||||
# The resulting Argilla dataset is ready to be used with Argilla for training or labelling, etc.
|
||||
argilla_dataset = stage_for_argilla(elements, "text_classification", metadata=metadata)
|
||||
|
||||
@ -4,15 +4,65 @@
|
||||
#
|
||||
# pip-compile --output-file=requirements/base.txt
|
||||
#
|
||||
argilla==1.1.0
|
||||
# via unstructured (setup.py)
|
||||
backoff==2.2.1
|
||||
# via argilla
|
||||
certifi==2022.9.24
|
||||
# via httpx
|
||||
click==8.1.3
|
||||
# via nltk
|
||||
deprecated==1.2.13
|
||||
# via argilla
|
||||
h11==0.9.0
|
||||
# via httpcore
|
||||
httpcore==0.11.1
|
||||
# via httpx
|
||||
httpx==0.15.5
|
||||
# via argilla
|
||||
idna==3.4
|
||||
# via rfc3986
|
||||
joblib==1.2.0
|
||||
# via nltk
|
||||
lxml==4.9.1
|
||||
# via unstructured (setup.py)
|
||||
monotonic==1.6
|
||||
# via argilla
|
||||
nltk==3.7
|
||||
# via unstructured (setup.py)
|
||||
numpy==1.23.5
|
||||
# via
|
||||
# argilla
|
||||
# pandas
|
||||
packaging==21.3
|
||||
# via argilla
|
||||
pandas==1.5.2
|
||||
# via argilla
|
||||
pydantic==1.10.2
|
||||
# via argilla
|
||||
pyparsing==3.0.9
|
||||
# via packaging
|
||||
python-dateutil==2.8.2
|
||||
# via pandas
|
||||
pytz==2022.6
|
||||
# via pandas
|
||||
regex==2022.10.31
|
||||
# via nltk
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.0
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
tqdm==4.64.1
|
||||
# via nltk
|
||||
# via
|
||||
# argilla
|
||||
# nltk
|
||||
typing-extensions==4.4.0
|
||||
# via pydantic
|
||||
wrapt==1.13.3
|
||||
# via
|
||||
# argilla
|
||||
# deprecated
|
||||
|
||||
@ -4,34 +4,64 @@
|
||||
#
|
||||
# pip-compile --extra=huggingface --output-file=requirements/huggingface.txt
|
||||
#
|
||||
argilla==1.1.0
|
||||
# via unstructured (setup.py)
|
||||
backoff==2.2.1
|
||||
# via argilla
|
||||
certifi==2022.9.24
|
||||
# via requests
|
||||
# via
|
||||
# httpx
|
||||
# requests
|
||||
charset-normalizer==2.1.1
|
||||
# via requests
|
||||
click==8.1.3
|
||||
# via nltk
|
||||
deprecated==1.2.13
|
||||
# via argilla
|
||||
filelock==3.8.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
h11==0.9.0
|
||||
# via httpcore
|
||||
httpcore==0.11.1
|
||||
# via httpx
|
||||
httpx==0.15.5
|
||||
# via argilla
|
||||
huggingface-hub==0.10.1
|
||||
# via transformers
|
||||
idna==3.4
|
||||
# via requests
|
||||
# via
|
||||
# requests
|
||||
# rfc3986
|
||||
joblib==1.2.0
|
||||
# via nltk
|
||||
lxml==4.9.1
|
||||
# via unstructured (setup.py)
|
||||
monotonic==1.6
|
||||
# via argilla
|
||||
nltk==3.7
|
||||
# via unstructured (setup.py)
|
||||
numpy==1.23.4
|
||||
# via transformers
|
||||
# via
|
||||
# argilla
|
||||
# pandas
|
||||
# transformers
|
||||
packaging==21.3
|
||||
# via
|
||||
# argilla
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
pandas==1.5.2
|
||||
# via argilla
|
||||
pydantic==1.10.2
|
||||
# via argilla
|
||||
pyparsing==3.0.9
|
||||
# via packaging
|
||||
python-dateutil==2.8.2
|
||||
# via pandas
|
||||
pytz==2022.6
|
||||
# via pandas
|
||||
pyyaml==6.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
@ -44,16 +74,31 @@ requests==2.28.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.0
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
tokenizers==0.13.2
|
||||
# via transformers
|
||||
tqdm==4.64.1
|
||||
# via
|
||||
# argilla
|
||||
# huggingface-hub
|
||||
# nltk
|
||||
# transformers
|
||||
transformers==4.23.1
|
||||
# via unstructured (setup.py)
|
||||
typing-extensions==4.4.0
|
||||
# via huggingface-hub
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pydantic
|
||||
urllib3==1.26.12
|
||||
# via requests
|
||||
wrapt==1.13.3
|
||||
# via
|
||||
# argilla
|
||||
# deprecated
|
||||
|
||||
1
setup.py
1
setup.py
@ -50,6 +50,7 @@ setup(
|
||||
install_requires=[
|
||||
"lxml",
|
||||
"nltk",
|
||||
"argilla",
|
||||
],
|
||||
extras_require={
|
||||
"huggingface": ["transformers"],
|
||||
|
||||
50
test_unstructured/staging/test_argilla.py
Normal file
50
test_unstructured/staging/test_argilla.py
Normal file
@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
|
||||
import argilla as rg
|
||||
import unstructured.staging.argilla as argilla
|
||||
from unstructured.documents.elements import Title, NarrativeText
|
||||
|
||||
|
||||
@pytest.fixture
def elements():
    """Provide one Title and one NarrativeText element shared by the tests below."""
    title = Title(text="example")
    narrative = NarrativeText(text="another example")
    return [title, narrative]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "task_name, dataset_type, extra_kwargs",
    [
        (
            "text_classification",
            rg.DatasetForTextClassification,
            {"metadata": [{"type": "text1"}, {"type": "text2"}]},
        ),
        (
            "text_classification",
            rg.DatasetForTextClassification,
            {},
        ),
    ],
)
def test_stage_for_argilla(elements, task_name, dataset_type, extra_kwargs):
    """Check that staging yields the expected dataset type and per-record fields.

    Each record is compared with the element at the same position: text and id
    must match, and every extra keyword argument must show up on the record as
    one of the values that were supplied for it.
    """
    staged = argilla.stage_for_argilla(elements, task_name, **extra_kwargs)
    assert isinstance(staged, dataset_type)

    for record, element in zip(staged, elements):
        assert record.text == element.text
        assert record.id == element.id
        for field_name, supplied_values in extra_kwargs.items():
            assert getattr(record, field_name) in supplied_values
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "task_name, error, error_message, extra_kwargs",
    [
        ("unknown_task", ValueError, "invalid value", {}),
        ("token_classification", NotImplementedError, None, {}),
        ("text2text", NotImplementedError, None, {}),
        ("text_classification", ValueError, "invalid value", {"metadata": "invalid metadata"}),
    ],
)
def test_invalid_stage_for_argilla(elements, task_name, error, error_message, extra_kwargs):
    """Check that invalid task names or keyword arguments raise the expected error.

    When ``error_message`` is given, it must appear (case-insensitively) in the
    text of the raised exception.
    """
    with pytest.raises(error) as e:
        argilla.stage_for_argilla(elements, task_name, **extra_kwargs)

    # The message check must run after the ``with`` block: statements placed
    # after the raising call inside the block are never executed. ``e`` is a
    # pytest ExceptionInfo, so the raised exception itself lives on ``e.value``.
    if error_message:
        assert error_message in str(e.value).lower()
|
||||
@ -1 +1 @@
|
||||
__version__ = "0.3.0-dev3" # pragma: no cover
|
||||
__version__ = "0.3.0-dev4" # pragma: no cover
|
||||
|
||||
54
unstructured/staging/argilla.py
Normal file
54
unstructured/staging/argilla.py
Normal file
@ -0,0 +1,54 @@
|
||||
from typing import List, Union
|
||||
from unstructured.documents.elements import Text
|
||||
import argilla
|
||||
from argilla.client.models import (
|
||||
TextClassificationRecord,
|
||||
TokenClassificationRecord,
|
||||
Text2TextRecord,
|
||||
)
|
||||
|
||||
|
||||
def stage_for_argilla(
    elements: List[Text],
    argilla_task: str,
    **record_kwargs,
) -> Union[
    argilla.DatasetForTextClassification,
    argilla.DatasetForTokenClassification,
    argilla.DatasetForText2Text,
]:
    """Convert a list of Text elements into an Argilla dataset.

    Parameters
    ----------
    elements
        The elements to stage. One Argilla record is created per element,
        using the element's ``text`` and, when it is a string, its ``id``.
    argilla_task
        Name of the Argilla task. Recognized values are ``"text_classification"``,
        ``"token_classification"`` and ``"text2text"``; only
        ``"text_classification"`` is currently implemented.
    **record_kwargs
        Extra per-record keyword arguments. Each value must be a list with the
        same length as ``elements``; the item at position ``i`` is passed to the
        record built for ``elements[i]``.

    Returns
    -------
    The Argilla dataset class matching ``argilla_task``, populated with one
    record per element.

    Raises
    ------
    ValueError
        If ``argilla_task`` is not a recognized task name, or if any value in
        ``record_kwargs`` is not a list of the same length as ``elements``.
    NotImplementedError
        If ``argilla_task`` is ``"token_classification"`` or ``"text2text"``.
    """
    ARGILLA_TASKS = {
        "text_classification": (TextClassificationRecord, argilla.DatasetForTextClassification),
        "token_classification": (TokenClassificationRecord, argilla.DatasetForTokenClassification),
        "text2text": (Text2TextRecord, argilla.DatasetForText2Text),
    }

    try:
        argilla_record_class, argilla_dataset_class = ARGILLA_TASKS[argilla_task]
    except KeyError as e:
        # Both fragments must be f-strings; otherwise the list of valid task
        # names is emitted as a literal, uninterpolated placeholder.
        raise ValueError(
            f'Invalid value "{e.args[0]}" specified for argilla_task. '
            f"Must be one of: {', '.join(ARGILLA_TASKS)}."
        ) from e

    if argilla_task in {"token_classification", "text2text"}:
        raise NotImplementedError()  # TODO: Implement token_classification and text2text tasks

    # Validate up front so a malformed keyword argument fails before any record is built.
    for record_kwarg_key, record_kwarg_value in record_kwargs.items():
        if not isinstance(record_kwarg_value, list) or len(record_kwarg_value) != len(elements):
            raise ValueError(
                f'Invalid value specified for "{record_kwarg_key}" keyword argument.'
                " Must be of type list and same length as elements list."
            )

    results: List[Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord]] = []

    for idx, element in enumerate(elements):
        # Pick the idx-th entry of every keyword list for this element's record.
        element_kwargs = {kwarg: values[idx] for kwarg, values in record_kwargs.items()}
        arguments = dict(**element_kwargs, text=element.text)
        # Only forward the id when it is a string; otherwise the record is
        # created without an explicit id.
        if isinstance(element.id, str):
            arguments["id"] = element.id

        results.append(argilla_record_class(**arguments))

    return argilla_dataset_class(results)
|
||||
Loading…
x
Reference in New Issue
Block a user