Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-27 15:08:43 +00:00
feat: adds support for zero-shot document classification (#7669)

Also supports multi-label classification.

* pytests for zero-shot document classification
* release note
* added license info to py scripts
* updated the format of license info
* added docstring and example code
* addressed review points highlighted in the PR
* applied suggestions from docstring review
* fixed pytest for init
* added output type
* added test for pipeline (de-)serialization

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.f93@gmail.com>
parent da49e782e2
commit b126c14e51
haystack/components/classifiers/__init__.py
@@ -3,5 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from haystack.components.classifiers.document_language_classifier import DocumentLanguageClassifier
+from haystack.components.classifiers.zero_shot_document_classifier import TransformersZeroShotDocumentClassifier
 
-__all__ = ["DocumentLanguageClassifier"]
+__all__ = ["DocumentLanguageClassifier", "TransformersZeroShotDocumentClassifier"]
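With this change the new classifier is re-exported from the classifiers subpackage, so it can be imported directly. A quick sanity check, assuming an environment with this commit of Haystack installed:

```python
# Assumes a Haystack install that includes this commit.
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

print(TransformersZeroShotDocumentClassifier.__name__)
```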
haystack/components/classifiers/zero_shot_document_classifier.py (new file, 246 additions)
@@ -0,0 +1,246 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import pipeline


@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
    - `valhalla/distilbart-mnli-12-3`
    - `cross-encoder/nli-distilroberta-base`
    - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
    """

    def __init__(
        self,
        model: str,
        labels: List[str],
        multi_label: bool = False,
        classification_field: Optional[str] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of the document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this
            parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for zero-shot classification.
        """

        torch_and_transformers_import.check()

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label
        component.set_output_types(self, **{label: str for label in labels})

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.pipeline = None

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], batch_size: int = 1):
        """
        Classifies the documents based on the provided labels and adds the predictions to their metadata.

        The classification results are stored in the `classification` dict within
        each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under
        the `details` key within the `classification` dictionary.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        """

        if self.pipeline is None:
            raise RuntimeError(
                "The component TransformersZeroShotDocumentClassifier wasn't warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )

        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify a text, please wrap it in a Document."
            )

        invalid_doc_ids = []

        for doc in documents:
            if self.classification_field is not None and self.classification_field not in doc.meta:
                invalid_doc_ids.append(doc.id)

        if invalid_doc_ids:
            raise ValueError(
                f"The following documents do not have the classification field '{self.classification_field}': "
                f"{', '.join(invalid_doc_ids)}"
            )

        texts = [
            doc.content if self.classification_field is None else doc.meta[self.classification_field]
            for doc in documents
        ]

        predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)

        for prediction, document in zip(predictions, documents):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"])),
            }
            document.meta["classification"] = formatted_prediction

        return {"documents": documents}
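For reference, a minimal standalone sketch of reading the results that `run()` attaches to each document. The model name and labels are illustrative, and `multi_label=True` is used here to surface the per-label scores under the `details` key described in the docstring above:

```python
# Sketch only: model and labels are illustrative, not mandated by this commit.
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
    multi_label=True,  # treat labels as independent; scores no longer sum to 1
)
classifier.warm_up()  # required: run() raises RuntimeError on a cold component

result = classifier.run(documents=[Document(content="Today was a nice day!")])
classification = result["documents"][0].meta["classification"]
print(classification["label"], classification["score"])  # top label and its score
print(classification["details"])  # score for every candidate label
```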
(new release note, 7 additions)
@@ -0,0 +1,7 @@
---
highlights: >
  Adds support for zero-shot document classification
features:
  - |
    Adds support for zero-shot document classification. This allows you to classify documents into user-defined
    classes (binary and multi-label classification) using pre-trained models from Hugging Face.
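The "pre-trained models from Hugging Face" in the release note are NLI checkpoints driven through the transformers zero-shot pipeline, which is what the new component wraps (it calls `self.pipeline(texts, self.labels, multi_label=...)` internally). A rough sketch of that underlying call, with an illustrative model and labels:

```python
# Illustrative: the raw transformers call that TransformersZeroShotDocumentClassifier wraps.
from transformers import pipeline

clf = pipeline("zero-shot-classification", model="cross-encoder/nli-deberta-v3-xsmall")
out = clf("Today was a nice day!", ["positive", "negative"], multi_label=False)
# out["labels"] is sorted by score; with multi_label=False the scores sum to 1.
print(out["labels"][0], out["scores"][0])
```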
(new test file, 165 additions)
@@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest

from unittest.mock import patch

from haystack import Document, Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import ComponentDevice, Secret


class TestTransformersZeroShotDocumentClassifier:
    def test_init(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        assert component.labels == ["positive", "negative"]
        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert component.multi_label is False
        assert component.pipeline is None
        assert component.classification_field is None

    def test_to_dict(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component_dict = component.to_dict()
        assert component_dict == {
            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
            "init_parameters": {
                "model": "cross-encoder/nli-deberta-v3-xsmall",
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "cross-encoder/nli-deberta-v3-xsmall",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

    def test_from_dict(self, monkeypatch):
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        data = {
            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
            "init_parameters": {
                "model": "cross-encoder/nli-deberta-v3-xsmall",
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": "cross-encoder/nli-deberta-v3-xsmall",
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }
        component = TransformersZeroShotDocumentClassifier.from_dict(data)
        assert component.labels == ["positive", "negative"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert component.huggingface_pipeline_kwargs == {
            "model": "cross-encoder/nli-deberta-v3-xsmall",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    def test_from_dict_no_default_parameters(self, monkeypatch):
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        data = {
            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
            "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]},
        }
        component = TransformersZeroShotDocumentClassifier.from_dict(data)
        assert component.labels == ["positive", "negative"]
        assert component.pipeline is None
        assert component.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert component.huggingface_pipeline_kwargs == {
            "model": "cross-encoder/nli-deberta-v3-xsmall",
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_warm_up(self, hf_pipeline_mock):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.warm_up()
        assert component.pipeline is not None

    def test_run_fails_without_warm_up(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        positive_documents = [Document(content="That's good. I like it.")]
        with pytest.raises(RuntimeError):
            component.run(documents=positive_documents)

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_run_fails_with_non_document_input(self, hf_pipeline_mock):
        hf_pipeline_mock.return_value = " "
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.warm_up()
        text_list = ["That's good. I like it.", "That's bad. I don't like it."]
        with pytest.raises(TypeError):
            component.run(documents=text_list)

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_run_unit(self, hf_pipeline_mock):
        hf_pipeline_mock.return_value = [
            {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]},
            {"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]},
        ]
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.pipeline = hf_pipeline_mock
        positive_document = Document(content="That's good. I like it.")
        negative_document = Document(content="That's bad. I don't like it.")
        result = component.run(documents=[positive_document, negative_document])
        assert component.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"

    @pytest.mark.integration
    def test_run(self):
        component = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )
        component.warm_up()
        positive_document = Document(content="That's good. I like it. " * 1000)
        negative_document = Document(content="That's bad. I don't like it.")
        result = component.run(documents=[positive_document, negative_document])
        assert component.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"

    def test_serialization_and_deserialization_pipeline(self):
        pipeline = Pipeline()
        document_store = InMemoryDocumentStore()
        retriever = InMemoryBM25Retriever(document_store=document_store)
        document_classifier = TransformersZeroShotDocumentClassifier(
            model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
        )

        pipeline.add_component(instance=retriever, name="retriever")
        pipeline.add_component(instance=document_classifier, name="document_classifier")
        pipeline.connect("retriever", "document_classifier")
        pipeline_dump = pipeline.dumps()

        new_pipeline = Pipeline.loads(pipeline_dump)

        assert new_pipeline == pipeline