mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-01 17:47:19 +00:00
100 lines
4.3 KiB
Plaintext
---
title: "TransformersZeroShotDocumentClassifier"
id: transformerszeroshotdocumentclassifier
slug: "/transformerszeroshotdocumentclassifier"
description: "Classifies the documents based on the provided labels and adds them to their metadata."
---

# TransformersZeroShotDocumentClassifier

Classifies documents based on the provided labels and adds the predicted label to each document's metadata.

<div className="key-value-table">

|  |  |
| --- | --- |
| **Most common position in a pipeline** | Before a [MetadataRouter](../routers/metadatarouter.mdx) |
| **Mandatory init variables** | `model`: The name or path of a Hugging Face model for zero-shot document classification <br /> <br />`labels`: The set of possible class labels to classify each document into, for example, `["positive", "negative"]`. The labels depend on the selected model. |
| **Mandatory run variables** | `documents`: A list of documents to classify |
| **Output variables** | `documents`: A list of processed documents with an added `classification` metadata field |
| **API reference** | [Classifiers](/reference/classifiers-api) |
| **GitHub link** | https://github.com/deepset-ai/haystack/blob/main/haystack/components/classifiers/zero_shot_document_classifier.py |

</div>

## Overview
The `TransformersZeroShotDocumentClassifier` component performs zero-shot classification of documents based on the labels that you set and adds the predicted label to their metadata.

The component uses a Hugging Face pipeline for zero-shot classification.
To initialize the component, provide the model and the set of labels to be used for categorization.
To allow more than one label to be true for a document, set `multi_label` to `True`.

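To make the effect of `multi_label` more concrete, here is a minimal sketch of the two scoring modes. It assumes the behavior of the Hugging Face zero-shot pipeline that the component wraps: in single-label mode the entailment logits are normalized across all labels, so the scores sum to 1, while in multi-label mode each label is scored on its own as entailment versus contradiction. The helper names and logits below are made up for illustration and are not part of Haystack.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def single_label_scores(entailment_logits):
    """Labels compete with each other, so the scores sum to 1."""
    return softmax(entailment_logits)


def multi_label_scores(entailment_logits, contradiction_logits):
    """Each label is scored independently: entailment vs. contradiction."""
    return [
        softmax([contra, entail])[1]
        for entail, contra in zip(entailment_logits, contradiction_logits)
    ]


# Made-up logits for the labels ["animals", "food"]
entail = [2.0, 1.5]
contra = [-1.0, -0.5]

single = single_label_scores(entail)
multi = multi_label_scores(entail, contra)
```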
By default, classification runs on the document's `content` field. To classify another field instead, set `classification_field` to the name of one of the document's metadata fields.

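The field selection can be pictured as follows. This is a simplified sketch rather than the component's source code: `SimpleDoc` and `text_to_classify` are hypothetical stand-ins for Haystack's `Document` and for the component's internal logic, and the sample values are invented.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class SimpleDoc:
    """Hypothetical stand-in for a Haystack Document."""

    content: str
    meta: dict[str, Any] = field(default_factory=dict)


def text_to_classify(doc: SimpleDoc, classification_field: Optional[str] = None) -> str:
    """Return the text the classifier would receive for this document."""
    if classification_field is None:
        return doc.content  # default: classify the content field
    return doc.meta[classification_field]  # otherwise: classify a metadata field


doc = SimpleDoc(
    content="Full article text ...",
    meta={"title": "Cats and dental health"},
)

default_text = text_to_classify(doc)
title_text = text_to_classify(doc, classification_field="title")
```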
The classification results are stored in the `classification` dictionary within each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under the `details` key of that dictionary.

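As an illustration of the resulting structure, the sketch below builds a `classification` dictionary from a prediction shaped like the output of a Hugging Face zero-shot pipeline, where labels are sorted by descending score. The `label` and `details` keys come from this page; the `score` key, the helper name `format_prediction`, and the numbers are assumptions made for illustration.

```python
def format_prediction(prediction: dict) -> dict:
    """Shape a raw zero-shot prediction into a `classification` entry."""
    labels, scores = prediction["labels"], prediction["scores"]
    return {
        "label": labels[0],  # top-ranked label
        "score": scores[0],  # its score (assumed key)
        "details": dict(zip(labels, scores)),  # score for every label
    }


# Made-up prediction for the labels ["animals", "food"]
raw = {"labels": ["animals", "food"], "scores": [0.92, 0.08]}

classification = format_prediction(raw)
meta = {"classification": classification}
```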
Models available for zero-shot classification include:
- `valhalla/distilbart-mnli-12-3`
- `cross-encoder/nli-distilroberta-base`
- `cross-encoder/nli-deberta-v3-xsmall`

## Usage

### On its own
```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

documents = [
    Document(id="0", content="Cats don't get teeth cavities."),
    Document(id="1", content="Cucumbers can be grown in water."),
]

document_classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["animals", "food"],
)

# Load the model before the first run
document_classifier.warm_up()
result = document_classifier.run(documents=documents)

# Each returned document carries the prediction in its metadata
for doc in result["documents"]:
    print(doc.id, doc.meta["classification"]["label"])
```

### In a pipeline

The following pipeline retrieves documents with a BM25 retriever and classifies them using predefined labels:

```python
from haystack import Document
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.core.pipeline import Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

documents = [
    Document(id="0", content="Today was a nice day!"),
    Document(id="1", content="Yesterday was a bad day!"),
]

document_store = InMemoryDocumentStore()
retriever = InMemoryBM25Retriever(document_store=document_store)
document_classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
)

document_store.write_documents(documents)

# Retrieved documents flow from the retriever into the classifier
pipeline = Pipeline()
pipeline.add_component(instance=retriever, name="retriever")
pipeline.add_component(instance=document_classifier, name="document_classifier")
pipeline.connect("retriever", "document_classifier")

queries = ["How was your day today?", "How was your day yesterday?"]
expected_predictions = ["positive", "negative"]

# Each query should retrieve the matching document with the expected label
for idx, query in enumerate(queries):
    result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
    assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
    assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
            == expected_predictions[idx])
```