diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 2112ea6cd..6ac9e8628 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -24,13 +24,11 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") a @component class TransformersZeroShotTextRouter: """ - Routes a text input onto different output connections depending on which label it has been categorized into. + Routes the text strings to different connections based on a category label. - This is useful for routing queries to different models in a pipeline depending on their categorization. - The set of labels to be used for categorization can be specified. + Specify the set of labels for categorization when initializing the component. - Example usage in a retrieval pipeline that passes question-like queries to a text embedder optimized for - query-passage retrieval and passage-like queries to a text embedder optimized for passage-passage retrieval. + ### Usage example ```python from haystack import Document @@ -107,22 +105,23 @@ class TransformersZeroShotTextRouter: huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, ): """ - Initializes the TransformersZeroShotTextRouter. + Initializes the TransformersZeroShotTextRouter component. - :param labels: The set of possible class labels to classify each sequence into. Can be a single label, + :param labels: The set of labels to use for classification. Can be a single label, a string of comma-separated labels, or a list of labels. - :param multi_label: Whether or not multiple candidate labels can be true. - If False, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. - If True, the labels are considered independent and probabilities are normalized for each candidate by + :param multi_label: + Indicates if multiple labels can be true. + If `False`, label scores are normalized so their sum equals 1 for each sequence. + If `True`, the labels are considered independent and probabilities are normalized for each candidate by doing a softmax of the entailment score vs. the contradiction score. :param model: The name or path of a Hugging Face model for zero-shot text classification. - :param device: The device on which the model is loaded. If `None`, the default device is automatically - selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. + :param device: The device for loading the model. If `None`, automatically selects the default device. + If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. :param token: The API token used to download private models from Hugging Face. - If `token` is set to `True`, the token generated when running - `transformers-cli login` (stored in ~/.huggingface) is used. - :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the - Hugging Face pipeline for zero shot text classification. + If `True`, uses either `HF_API_TOKEN` or `HF_TOKEN` environment variables. + To generate these tokens, run `transformers-cli login`. + :param huggingface_pipeline_kwargs: A dictionary of keyword arguments for initializing the Hugging Face + zero shot text classification. """ torch_and_transformers_import.check() @@ -195,11 +194,9 @@ class TransformersZeroShotTextRouter: @component.output_types(documents=Dict[str, str]) def run(self, text: str): """ - Run the TransformersZeroShotTextRouter. + Routes the text strings to different connections based on a category label. - This method routes the text to one of the different edges based on which label it has been categorized into. - - :param text: A str to route to one of the different edges. + :param text: A string of text to route. :returns: A dictionary with the label as key and the text as value.