diff --git a/haystack/components/classifiers/__init__.py b/haystack/components/classifiers/__init__.py index 77090c79b..662df14d8 100644 --- a/haystack/components/classifiers/__init__.py +++ b/haystack/components/classifiers/__init__.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 from haystack.components.classifiers.document_language_classifier import DocumentLanguageClassifier +from haystack.components.classifiers.zero_shot_document_classifier import TransformersZeroShotDocumentClassifier -__all__ = ["DocumentLanguageClassifier"] +__all__ = ["DocumentLanguageClassifier", "TransformersZeroShotDocumentClassifier"] diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py new file mode 100644 index 000000000..cff245b35 --- /dev/null +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace +from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs + +logger = logging.getLogger(__name__) + + +with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: + from transformers import pipeline + + +@component +class TransformersZeroShotDocumentClassifier: + """ + Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata. + + The component uses a Hugging Face pipeline for zero-shot classification. + Provide the model and the set of labels to be used for categorization during initialization. + Additionally, you can configure the component to allow multiple labels to be true. + + Classification is run on the document's content field by default. If you want it to run on another field, set the + `classification_field` to one of the document's metadata fields. 
+
+    Available models for the task of zero-shot classification include:
+    - `valhalla/distilbart-mnli-12-3`
+    - `cross-encoder/nli-distilroberta-base`
+    - `cross-encoder/nli-deberta-v3-xsmall`
+
+    ### Usage example
+
+    The following is a pipeline that classifies documents based on predefined classification labels
+    retrieved from a search pipeline:
+
+    ```python
+    from haystack import Document
+    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
+    from haystack.document_stores.in_memory import InMemoryDocumentStore
+    from haystack.core.pipeline import Pipeline
+    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
+
+    documents = [Document(id="0", content="Today was a nice day!"),
+                 Document(id="1", content="Yesterday was a bad day!")]
+
+    document_store = InMemoryDocumentStore()
+    retriever = InMemoryBM25Retriever(document_store=document_store)
+    document_classifier = TransformersZeroShotDocumentClassifier(
+        model="cross-encoder/nli-deberta-v3-xsmall",
+        labels=["positive", "negative"],
+    )
+
+    document_store.write_documents(documents)
+
+    pipeline = Pipeline()
+    pipeline.add_component(instance=retriever, name="retriever")
+    pipeline.add_component(instance=document_classifier, name="document_classifier")
+    pipeline.connect("retriever", "document_classifier")
+
+    queries = ["How was your day today?", "How was your day yesterday?"]
+    expected_predictions = ["positive", "negative"]
+
+    for idx, query in enumerate(queries):
+        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
+        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
+        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
+                == expected_predictions[idx])
+    ```
+    """
+
+    def __init__(
+        self,
+        model: str,
+        labels: List[str],
+        multi_label: bool = False,
+        classification_field: Optional[str] = None,
+        device: Optional[ComponentDevice] = None,
+        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
+        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Initializes the TransformersZeroShotDocumentClassifier.
+
+        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
+        for the full list of zero-shot classification (NLI) models.
+
+        :param model:
+            The name or path of a Hugging Face model for zero-shot document classification.
+        :param labels:
+            The set of possible class labels to classify each document into, for example,
+            ["positive", "negative"]. The labels depend on the selected model.
+        :param multi_label:
+            Whether multiple candidate labels can be true.
+            If `False`, the scores are normalized such that
+            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
+            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
+            score vs. the contradiction score.
+        :param classification_field:
+            The name of the document's meta field to use for classification.
+            If not set, `Document.content` is used by default.
+        :param device:
+            The device on which the model is loaded. If `None`, the default device is automatically
+            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
+        :param token:
+            The Hugging Face token to use as HTTP bearer authorization.
+            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
+        :param huggingface_pipeline_kwargs:
+            Dictionary containing keyword arguments used to initialize the
+            Hugging Face pipeline for zero-shot classification.
+        """
+
+        torch_and_transformers_import.check()
+
+        self.classification_field = classification_field
+
+        self.token = token
+        self.labels = labels
+        self.multi_label = multi_label
+        component.set_output_types(self, **{label: str for label in labels})
+
+        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
+            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
+            model=model,
+            task="zero-shot-classification",
+            supported_tasks=["zero-shot-classification"],
+            device=device,
+            token=token,
+        )
+
+        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
+        self.pipeline = None
+
+    def _get_telemetry_data(self) -> Dict[str, Any]:
+        """
+        Data that is sent to Posthog for usage analytics.
+        """
+        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
+            return {"model": self.huggingface_pipeline_kwargs["model"]}
+        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
+
+    def warm_up(self):
+        """
+        Initializes the component by loading the Hugging Face pipeline.
+        """
+        if self.pipeline is None:
+            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+        serialization_dict = default_to_dict(
+            self,
+            labels=self.labels,
+            model=self.huggingface_pipeline_kwargs["model"],
+            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
+            token=self.token.to_dict() if self.token else None,
+        )
+
+        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
+        huggingface_pipeline_kwargs.pop("token", None)
+
+        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
+        return serialization_dict
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+        """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
+            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, documents: List[Document], batch_size: int = 1):
+        """
+        Classifies the documents based on the provided labels and adds the results to their metadata.
+
+        The classification results are stored in the `classification` dict within
+        each document's metadata. The scores for all candidate labels are available under
+        the `details` key within the `classification` dictionary.
+
+        :param documents:
+            Documents to process.
+        :param batch_size:
+            Batch size used by the underlying pipeline when processing the documents' content.
+        :returns:
+            A dictionary with the following key:
+            - `documents`: A list of documents with an added metadata field called `classification`.
+        """
+
+        if self.pipeline is None:
+            raise RuntimeError(
+                "The component TransformersZeroShotDocumentClassifier wasn't warmed up. "
+                "Run 'warm_up()' before calling 'run()'."
+ ) + + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + raise TypeError( + "DocumentLanguageClassifier expects a list of documents as input. " + "In case you want to classify a text, please use the TextLanguageClassifier." + ) + + invalid_doc_ids = [] + + for doc in documents: + if self.classification_field is not None and self.classification_field not in doc.meta: + invalid_doc_ids.append(doc.id) + + if invalid_doc_ids: + raise ValueError( + f"The following documents do not have the classification field '{self.classification_field}': " + f"{', '.join(invalid_doc_ids)}" + ) + + texts = [ + doc.content if self.classification_field is None else doc.meta[self.classification_field] + for doc in documents + ] + + predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size) + + for prediction, document in zip(predictions, documents): + formatted_prediction = { + "label": prediction["labels"][0], + "score": prediction["scores"][0], + "details": dict(zip(prediction["labels"], prediction["scores"])), + } + document.meta["classification"] = formatted_prediction + + return {"documents": documents} diff --git a/releasenotes/notes/add-zero-shot-document-classifier-3ab1d7bbdc04db05.yaml b/releasenotes/notes/add-zero-shot-document-classifier-3ab1d7bbdc04db05.yaml new file mode 100644 index 000000000..82f35c467 --- /dev/null +++ b/releasenotes/notes/add-zero-shot-document-classifier-3ab1d7bbdc04db05.yaml @@ -0,0 +1,7 @@ +--- +highlights: > + Adds support for zero shot document classification +features: + - | + Adds support for zero shot document classification. This allows you to classify documents into user-defined + classes (binary and multi-label classification) using pre-trained models from huggingface. 
diff --git a/test/components/classifiers/test_zero_shot_document_classifier.py b/test/components/classifiers/test_zero_shot_document_classifier.py new file mode 100644 index 000000000..7d679e3d2 --- /dev/null +++ b/test/components/classifiers/test_zero_shot_document_classifier.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from unittest.mock import patch + +from haystack import Document, Pipeline +from haystack.components.classifiers import TransformersZeroShotDocumentClassifier +from haystack.components.retrievers import InMemoryBM25Retriever +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack.utils import ComponentDevice, Secret + + +class TestTransformersZeroShotDocumentClassifier: + def test_init(self): + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + assert component.labels == ["positive", "negative"] + assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) + assert component.multi_label is False + assert component.pipeline is None + assert component.classification_field is None + + def test_to_dict(self): + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + component_dict = component.to_dict() + assert component_dict == { + "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", + "init_parameters": { + "model": "cross-encoder/nli-deberta-v3-xsmall", + "labels": ["positive", "negative"], + "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}, + "huggingface_pipeline_kwargs": { + "model": "cross-encoder/nli-deberta-v3-xsmall", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + }, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) + data = { + "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", + "init_parameters": { + "model": "cross-encoder/nli-deberta-v3-xsmall", + "labels": ["positive", "negative"], + "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}, + "huggingface_pipeline_kwargs": { + "model": "cross-encoder/nli-deberta-v3-xsmall", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + }, + }, + } + component = TransformersZeroShotDocumentClassifier.from_dict(data) + assert component.labels == ["positive", "negative"] + assert component.pipeline is None + assert component.token == Secret.from_dict( + {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} + ) + assert component.huggingface_pipeline_kwargs == { + "model": "cross-encoder/nli-deberta-v3-xsmall", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + "token": None, + } + + def test_from_dict_no_default_parameters(self, monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) + data = { + "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", + "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]}, + } + component = TransformersZeroShotDocumentClassifier.from_dict(data) + assert 
component.labels == ["positive", "negative"] + assert component.pipeline is None + assert component.token == Secret.from_dict( + {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} + ) + assert component.huggingface_pipeline_kwargs == { + "model": "cross-encoder/nli-deberta-v3-xsmall", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + "token": None, + } + + @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline") + def test_warm_up(self, hf_pipeline_mock): + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + component.warm_up() + assert component.pipeline is not None + + def test_run_fails_without_warm_up(self): + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + positive_documents = [Document(content="That's good. I like it.")] + with pytest.raises(RuntimeError): + component.run(documents=positive_documents) + + @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline") + def test_run_fails_with_non_document_input(self, hf_pipeline_mock): + hf_pipeline_mock.return_value = " " + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + component.warm_up() + text_list = ["That's good. I like it.", "That's bad. I don't like it."] + with pytest.raises(TypeError): + component.run(documents=text_list) + + @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline") + def test_run_unit(self, hf_pipeline_mock): + hf_pipeline_mock.return_value = [ + {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]}, + {"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]}, + ] + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + component.pipeline = hf_pipeline_mock + positive_document = Document(content="That's good. I like it.") + negative_document = Document(content="That's bad. I don't like it.") + result = component.run(documents=[positive_document, negative_document]) + assert component.pipeline is not None + assert result["documents"][0].to_dict()["classification"]["label"] == "positive" + assert result["documents"][1].to_dict()["classification"]["label"] == "negative" + + @pytest.mark.integration + def test_run(self): + component = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + component.warm_up() + positive_document = Document(content="That's good. I like it. " * 1000) + negative_document = Document(content="That's bad. 
I don't like it.") + result = component.run(documents=[positive_document, negative_document]) + assert component.pipeline is not None + assert result["documents"][0].to_dict()["classification"]["label"] == "positive" + assert result["documents"][1].to_dict()["classification"]["label"] == "negative" + + def test_serialization_and_deserialization_pipeline(self): + pipeline = Pipeline() + document_store = InMemoryDocumentStore() + retriever = InMemoryBM25Retriever(document_store=document_store) + document_classifier = TransformersZeroShotDocumentClassifier( + model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"] + ) + + pipeline.add_component(instance=retriever, name="retriever") + pipeline.add_component(instance=document_classifier, name="document_classifier") + pipeline.connect("retriever", "document_classifier") + pipeline_dump = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_dump) + + assert new_pipeline == pipeline