Update zero_shot_text_router.py (#8231)

This commit is contained in:
Daria Fokina 2024-08-16 12:43:13 +02:00 committed by GitHub
parent b5d0bfa9df
commit b51bb6e5a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,13 +24,11 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") a
@component
class TransformersZeroShotTextRouter:
"""
Routes a text input onto different output connections depending on which label it has been categorized into.
Routes the text strings to different connections based on a category label.
This is useful for routing queries to different models in a pipeline depending on their categorization.
The set of labels to be used for categorization can be specified.
Specify the set of labels for categorization when initializing the component.
Example usage in a retrieval pipeline that passes question-like queries to a text embedder optimized for
query-passage retrieval and passage-like queries to a text embedder optimized for passage-passage retrieval.
### Usage example
```python
from haystack import Document
@ -107,22 +105,23 @@ class TransformersZeroShotTextRouter:
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initializes the TransformersZeroShotTextRouter.
Initializes the TransformersZeroShotTextRouter component.
:param labels: The set of possible class labels to classify each sequence into. Can be a single label,
:param labels: The set of labels to use for classification. Can be a single label,
a string of comma-separated labels, or a list of labels.
:param multi_label: Whether or not multiple candidate labels can be true.
If False, the scores are normalized such that the sum of the label likelihoods for each sequence is 1.
If True, the labels are considered independent and probabilities are normalized for each candidate by
:param multi_label:
Indicates if multiple labels can be true.
If `False`, label scores are normalized so their sum equals 1 for each sequence.
If `True`, the labels are considered independent and probabilities are normalized for each candidate by
doing a softmax of the entailment score vs. the contradiction score.
:param model: The name or path of a Hugging Face model for zero-shot text classification.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param device: The device for loading the model. If `None`, automatically selects the default device.
If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The API token used to download private models from Hugging Face.
If `token` is set to `True`, the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for zero shot text classification.
If `True`, uses either the `HF_API_TOKEN` or the `HF_TOKEN` environment variable.
To generate these tokens, run `transformers-cli login`.
:param huggingface_pipeline_kwargs: A dictionary of keyword arguments for initializing the Hugging Face
zero-shot text classification pipeline.
"""
torch_and_transformers_import.check()
@ -195,11 +194,9 @@ class TransformersZeroShotTextRouter:
@component.output_types(documents=Dict[str, str])
def run(self, text: str):
"""
Run the TransformersZeroShotTextRouter.
Routes the text strings to different connections based on a category label.
This method routes the text to one of the different edges based on which label it has been categorized into.
:param text: A str to route to one of the different edges.
:param text: A string of text to route.
:returns:
A dictionary with the label as key and the text as value.