mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-07 13:24:16 +00:00
Update zero_shot_text_router.py (#8231)
This commit is contained in:
parent
b5d0bfa9df
commit
b51bb6e5a9
@ -24,13 +24,11 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") a
|
||||
@component
|
||||
class TransformersZeroShotTextRouter:
|
||||
"""
|
||||
Routes a text input onto different output connections depending on which label it has been categorized into.
|
||||
Routes the text strings to different connections based on a category label.
|
||||
|
||||
This is useful for routing queries to different models in a pipeline depending on their categorization.
|
||||
The set of labels to be used for categorization can be specified.
|
||||
Specify the set of labels for categorization when initializing the component.
|
||||
|
||||
Example usage in a retrieval pipeline that passes question-like queries to a text embedder optimized for
|
||||
query-passage retrieval and passage-like queries to a text embedder optimized for passage-passage retrieval.
|
||||
### Usage example
|
||||
|
||||
```python
|
||||
from haystack import Document
|
||||
@ -107,22 +105,23 @@ class TransformersZeroShotTextRouter:
|
||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the TransformersZeroShotTextRouter.
|
||||
Initializes the TransformersZeroShotTextRouter component.
|
||||
|
||||
:param labels: The set of possible class labels to classify each sequence into. Can be a single label,
|
||||
:param labels: The set of labels to use for classification. Can be a single label,
|
||||
a string of comma-separated labels, or a list of labels.
|
||||
:param multi_label: Whether or not multiple candidate labels can be true.
|
||||
If False, the scores are normalized such that the sum of the label likelihoods for each sequence is 1.
|
||||
If True, the labels are considered independent and probabilities are normalized for each candidate by
|
||||
:param multi_label:
|
||||
Indicates if multiple labels can be true.
|
||||
If `False`, label scores are normalized so their sum equals 1 for each sequence.
|
||||
If `True`, the labels are considered independent and probabilities are normalized for each candidate by
|
||||
doing a softmax of the entailment score vs. the contradiction score.
|
||||
:param model: The name or path of a Hugging Face model for zero-shot text classification.
|
||||
:param device: The device on which the model is loaded. If `None`, the default device is automatically
|
||||
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
|
||||
:param device: The device for loading the model. If `None`, automatically selects the default device.
|
||||
If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
|
||||
:param token: The API token used to download private models from Hugging Face.
|
||||
If `token` is set to `True`, the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) is used.
|
||||
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
|
||||
Hugging Face pipeline for zero shot text classification.
|
||||
If `True`, uses either `HF_API_TOKEN` or `HF_TOKEN` environment variables.
|
||||
To generate these tokens, run `transformers-cli login`.
|
||||
:param huggingface_pipeline_kwargs: A dictionary of keyword arguments for initializing the Hugging Face
|
||||
zero shot text classification.
|
||||
"""
|
||||
torch_and_transformers_import.check()
|
||||
|
||||
@ -195,11 +194,9 @@ class TransformersZeroShotTextRouter:
|
||||
@component.output_types(documents=Dict[str, str])
|
||||
def run(self, text: str):
|
||||
"""
|
||||
Run the TransformersZeroShotTextRouter.
|
||||
Routes the text strings to different connections based on a category label.
|
||||
|
||||
This method routes the text to one of the different edges based on which label it has been categorized into.
|
||||
|
||||
:param text: A str to route to one of the different edges.
|
||||
:param text: A string of text to route.
|
||||
:returns:
|
||||
A dictionary with the label as key and the text as value.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user