feat: Adds support for zero-shot document classification (#7669) (#8193)

* feat: adds support for zero-shot document classification (#7669)

Also supports multi-label classification

* pytests for zero-shot document classification

* release note

* added licence info to py scripts

* updated the format of licence info

* Added doc string and example code

* added review points highlighted in the PR

* Applied suggestions from doc string review

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* fixed pytest for init

* added output type

* added test for pipeline (de)serialization

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.f93@gmail.com>
jpatra72 2024-09-10 11:00:05 +02:00 committed by GitHub
parent da49e782e2
commit b126c14e51
4 changed files with 420 additions and 1 deletion

haystack/components/classifiers/__init__.py
@@ -3,5 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
from haystack.components.classifiers.document_language_classifier import DocumentLanguageClassifier
+from haystack.components.classifiers.zero_shot_document_classifier import TransformersZeroShotDocumentClassifier
-__all__ = ["DocumentLanguageClassifier"]
+__all__ = ["DocumentLanguageClassifier", "TransformersZeroShotDocumentClassifier"]

haystack/components/classifiers/zero_shot_document_classifier.py
@@ -0,0 +1,246 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
from transformers import pipeline
@component
class TransformersZeroShotDocumentClassifier:
"""
Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.
The component uses a Hugging Face pipeline for zero-shot classification.
Provide the model and the set of labels to be used for categorization during initialization.
Additionally, you can configure the component to allow multiple labels to be true.
Classification is run on the document's content field by default. If you want it to run on another field, set the
`classification_field` to one of the document's metadata fields.
Available models for the task of zero-shot-classification include:
- `valhalla/distilbart-mnli-12-3`
- `cross-encoder/nli-distilroberta-base`
- `cross-encoder/nli-deberta-v3-xsmall`
### Usage example
The following is a pipeline that classifies documents based on predefined classification labels
retrieved from a search pipeline:
```python
from haystack import Document
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.core.pipeline import Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
documents = [Document(id="0", content="Today was a nice day!"),
Document(id="1", content="Yesterday was a bad day!")]
document_store = InMemoryDocumentStore()
retriever = InMemoryBM25Retriever(document_store=document_store)
document_classifier = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall",
labels=["positive", "negative"],
)
document_store.write_documents(documents)
pipeline = Pipeline()
pipeline.add_component(instance=retriever, name="retriever")
pipeline.add_component(instance=document_classifier, name="document_classifier")
pipeline.connect("retriever", "document_classifier")
queries = ["How was your day today?", "How was your day yesterday?"]
expected_predictions = ["positive", "negative"]
for idx, query in enumerate(queries):
result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
== expected_predictions[idx])
```
"""
def __init__(
self,
model: str,
labels: List[str],
multi_label: bool = False,
classification_field: Optional[str] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initializes the TransformersZeroShotDocumentClassifier.
See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.
:param model:
            The name or path of a Hugging Face model for zero-shot document classification.
:param labels:
The set of possible class labels to classify each document into, for example,
["positive", "negative"]. The labels depend on the selected model.
:param multi_label:
Whether or not multiple candidate labels can be true.
If `False`, the scores are normalized such that
the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
independent and probabilities are normalized for each candidate by doing a softmax of the entailment
score vs. the contradiction score.
:param classification_field:
            Name of the document's meta field to use for classification.
If not set, `Document.content` is used by default.
:param device:
The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token:
The Hugging Face token to use as HTTP bearer authorization.
Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
:param huggingface_pipeline_kwargs:
Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for text classification.
"""
torch_and_transformers_import.check()
self.classification_field = classification_field
self.token = token
self.labels = labels
self.multi_label = multi_label
component.set_output_types(self, **{label: str for label in labels})
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
model=model,
task="zero-shot-classification",
supported_tasks=["zero-shot-classification"],
device=device,
token=token,
)
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.pipeline = None
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
if isinstance(self.huggingface_pipeline_kwargs["model"], str):
return {"model": self.huggingface_pipeline_kwargs["model"]}
return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
def warm_up(self):
"""
Initializes the component.
"""
if self.pipeline is None:
self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
serialization_dict = default_to_dict(
self,
labels=self.labels,
model=self.huggingface_pipeline_kwargs["model"],
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
token=self.token.to_dict() if self.token else None,
)
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
huggingface_pipeline_kwargs.pop("token", None)
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, documents: List[Document], batch_size: int = 1):
"""
Classifies the documents based on the provided labels and adds them to their metadata.
The classification results are stored in the `classification` dict within
each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under
the `details` key within the `classification` dictionary.
:param documents:
Documents to process.
:param batch_size:
Batch size used for processing the content in each document.
:returns:
A dictionary with the following key:
- `documents`: A list of documents with an added metadata field called `classification`.
"""
if self.pipeline is None:
raise RuntimeError(
"The component TransformerZeroShotDocumentClassifier wasn't warmed up. "
"Run 'warm_up()' before calling 'run()'."
)
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
raise TypeError(
"DocumentLanguageClassifier expects a list of documents as input. "
"In case you want to classify a text, please use the TextLanguageClassifier."
)
invalid_doc_ids = []
for doc in documents:
if self.classification_field is not None and self.classification_field not in doc.meta:
invalid_doc_ids.append(doc.id)
if invalid_doc_ids:
raise ValueError(
f"The following documents do not have the classification field '{self.classification_field}': "
f"{', '.join(invalid_doc_ids)}"
)
texts = [
doc.content if self.classification_field is None else doc.meta[self.classification_field]
for doc in documents
]
predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)
for prediction, document in zip(predictions, documents):
formatted_prediction = {
"label": prediction["labels"][0],
"score": prediction["scores"][0],
"details": dict(zip(prediction["labels"], prediction["scores"])),
}
document.meta["classification"] = formatted_prediction
return {"documents": documents}
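For reference, a minimal sketch of calling the component above directly, outside a pipeline. The model, input text, and printed score values are illustrative, not taken from the commit:

```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
)
classifier.warm_up()  # loads the Hugging Face pipeline; run() raises a RuntimeError without it

result = classifier.run(documents=[Document(content="That's good. I like it.")])

# run() writes a "classification" dict into each document's metadata
classification = result["documents"][0].meta["classification"]
print(classification["label"])    # top label, e.g. "positive"
print(classification["score"])    # score of the top label (illustrative: 0.99)
print(classification["details"])  # per-label scores, e.g. {"positive": 0.99, "negative": 0.01}
```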

@@ -0,0 +1,7 @@
---
highlights: >
  Adds support for zero-shot document classification
features:
- |
    Adds support for zero-shot document classification. This allows you to classify documents into user-defined
    classes (binary and multi-label classification) using pre-trained models from Hugging Face.
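To illustrate the multi-label mode mentioned in the release note, a hedged sketch (the labels, text, and scores are illustrative). With `multi_label=True`, each label is scored independently via an entailment-vs-contradiction softmax, so several labels can score high at once and the scores need not sum to 1:

```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["urgent", "billing", "technical"],  # illustrative user-defined classes
    multi_label=True,
)
classifier.warm_up()

docs = [Document(content="My invoice is wrong and I need this fixed today!")]
result = classifier.run(documents=docs)

# With multi_label=True the per-label scores are independent, e.g.
# {"urgent": 0.95, "billing": 0.91, "technical": 0.12} (illustrative values)
print(result["documents"][0].meta["classification"]["details"])
```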

@@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from unittest.mock import patch
from haystack import Document, Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import ComponentDevice, Secret
class TestTransformersZeroShotDocumentClassifier:
def test_init(self):
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
assert component.labels == ["positive", "negative"]
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert component.multi_label is False
assert component.pipeline is None
assert component.classification_field is None
def test_to_dict(self):
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
component_dict = component.to_dict()
assert component_dict == {
"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
"init_parameters": {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"labels": ["positive", "negative"],
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "zero-shot-classification",
},
},
}
def test_from_dict(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)
data = {
"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
"init_parameters": {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"labels": ["positive", "negative"],
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "zero-shot-classification",
},
},
}
component = TransformersZeroShotDocumentClassifier.from_dict(data)
assert component.labels == ["positive", "negative"]
assert component.pipeline is None
assert component.token == Secret.from_dict(
{"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
)
assert component.huggingface_pipeline_kwargs == {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "zero-shot-classification",
"token": None,
}
def test_from_dict_no_default_parameters(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)
data = {
"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
"init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]},
}
component = TransformersZeroShotDocumentClassifier.from_dict(data)
assert component.labels == ["positive", "negative"]
assert component.pipeline is None
assert component.token == Secret.from_dict(
{"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
)
assert component.huggingface_pipeline_kwargs == {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"device": ComponentDevice.resolve_device(None).to_hf(),
"task": "zero-shot-classification",
"token": None,
}
@patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
def test_warm_up(self, hf_pipeline_mock):
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
component.warm_up()
assert component.pipeline is not None
def test_run_fails_without_warm_up(self):
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
positive_documents = [Document(content="That's good. I like it.")]
with pytest.raises(RuntimeError):
component.run(documents=positive_documents)
@patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
def test_run_fails_with_non_document_input(self, hf_pipeline_mock):
hf_pipeline_mock.return_value = " "
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
component.warm_up()
text_list = ["That's good. I like it.", "That's bad. I don't like it."]
with pytest.raises(TypeError):
component.run(documents=text_list)
@patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
def test_run_unit(self, hf_pipeline_mock):
hf_pipeline_mock.return_value = [
{"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]},
{"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]},
]
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
component.pipeline = hf_pipeline_mock
positive_document = Document(content="That's good. I like it.")
negative_document = Document(content="That's bad. I don't like it.")
result = component.run(documents=[positive_document, negative_document])
assert component.pipeline is not None
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
@pytest.mark.integration
def test_run(self):
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
component.warm_up()
positive_document = Document(content="That's good. I like it. " * 1000)
negative_document = Document(content="That's bad. I don't like it.")
result = component.run(documents=[positive_document, negative_document])
assert component.pipeline is not None
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
def test_serialization_and_deserialization_pipeline(self):
pipeline = Pipeline()
document_store = InMemoryDocumentStore()
retriever = InMemoryBM25Retriever(document_store=document_store)
document_classifier = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
)
pipeline.add_component(instance=retriever, name="retriever")
pipeline.add_component(instance=document_classifier, name="document_classifier")
pipeline.connect("retriever", "document_classifier")
pipeline_dump = pipeline.dumps()
new_pipeline = Pipeline.loads(pipeline_dump)
assert new_pipeline == pipeline
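One behavior documented in the component but not exercised by the tests above: classifying on a metadata field instead of `Document.content` via the `classification_field` init parameter. A hedged sketch follows; the field name "summary" and the texts are illustrative, and a document missing the field makes run() raise a ValueError:

```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
    classification_field="summary",  # classify on doc.meta["summary"] instead of doc.content
)
classifier.warm_up()

doc = Document(
    content="(full review text)",
    meta={"summary": "Great product, would buy again."},
)
result = classifier.run(documents=[doc])
print(result["documents"][0].meta["classification"]["label"])  # e.g. "positive"
```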