From 0ceeb733baabe2b3658ee7065c4441a632ef465d Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 29 May 2024 10:54:21 +0200 Subject: [PATCH] chore: make `warm_up()` usage consistent (#7752) * make usage consistent * fix error type * release notes * pylint fix * change of plan * revert * fix test * revert * fix HF tests * Apply suggestions from code review Co-authored-by: Madeesh Kannan * fix formatting * reformat * fix regex match with the new error message * fix integration test --------- Co-authored-by: Madeesh Kannan --- haystack/components/audio/whisper_local.py | 8 +++--- .../components/evaluators/sas_evaluator.py | 3 +++ .../extractors/named_entity_extractor.py | 9 +++++++ .../generators/chat/hugging_face_local.py | 8 +++--- .../generators/chat/hugging_face_tgi.py | 9 +++++++ .../generators/hugging_face_local.py | 27 +++++++++++++------ .../components/generators/hugging_face_tgi.py | 10 ++++--- .../sentence_transformers_diversity.py | 17 ++++++------ .../rankers/transformers_similarity.py | 16 +++++------ haystack/components/readers/extractive.py | 12 +++++---- ...e-warm-up-consistent-0247da81b155b136.yaml | 4 +++ .../test_hugging_face_local_generator.py | 3 ++- .../test_sentence_transformers_diversity.py | 4 +-- .../rankers/test_transformers_similarity.py | 2 +- 14 files changed, 89 insertions(+), 43 deletions(-) create mode 100644 releasenotes/notes/make-warm-up-consistent-0247da81b155b136.yaml diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py index 5a96f40e4..4909a7f9d 100644 --- a/haystack/components/audio/whisper_local.py +++ b/haystack/components/audio/whisper_local.py @@ -6,7 +6,7 @@ import tempfile from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Union, get_args -from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.dataclasses import ByteStream from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice @@ -113,7 +113,9 @@ class LocalWhisperTranscriber: alignment data and the path to the audio file used for the transcription. """ if self._model is None: - raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.") + raise RuntimeError( + "The component LocalWhisperTranscriber was not warmed up. Run 'warm_up()' before calling 'run()'." + ) if whisper_params is None: whisper_params = self.whisper_params @@ -156,7 +158,7 @@ class LocalWhisperTranscriber: A dictionary mapping 'file_path' to 'transcription'. """ if self._model is None: - raise ComponentError("Model is not loaded, please run 'warm_up()' before calling 'run()'") + raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'") return_segments = kwargs.pop("return_segments", False) transcriptions: Dict[Path, Any] = {} diff --git a/haystack/components/evaluators/sas_evaluator.py b/haystack/components/evaluators/sas_evaluator.py index 6733b45e0..b67688756 100644 --- a/haystack/components/evaluators/sas_evaluator.py +++ b/haystack/components/evaluators/sas_evaluator.py @@ -116,6 +116,9 @@ class SASEvaluator: """ Initializes the component. 
""" + if self._similarity_model: + return + token = self._token.resolve_value() if self._token else None config = AutoConfig.from_pretrained(self._model, use_auth_token=token) cross_encoder_used = False diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index cbb96ab46..b8083742f 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -134,6 +134,7 @@ class NamedEntityExtractor: backend = NamedEntityExtractorBackend.from_str(backend) self._backend: _NerBackend + self._warmed_up: bool = False device = ComponentDevice.resolve_device(device) if backend == NamedEntityExtractorBackend.HUGGING_FACE: @@ -150,8 +151,12 @@ class NamedEntityExtractor: :raises ComponentError: If the backend fails to initialize successfully. """ + if self._warmed_up: + return + try: self._backend.initialize() + self._warmed_up = True except Exception as e: raise ComponentError( f"Named entity extractor with backend '{self._backend.type} failed to initialize." @@ -171,6 +176,10 @@ class NamedEntityExtractor: :raises ComponentError: If the backend fails to process a document. """ + if not self._warmed_up: + msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component." + raise RuntimeError(msg) + texts = [doc.content if doc.content is not None else "" for doc in documents] annotations = self._backend.annotate(texts, batch_size=batch_size) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index cfeba779d..4e2bee87c 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -265,12 +265,12 @@ class HuggingFaceLocalChatGenerator: if self.streaming_callback: num_responses = generation_kwargs.get("num_return_sequences", 1) if num_responses > 1: - logger.warning( - "Streaming is enabled, but the number of responses is set to %d. " + msg = ( + "Streaming is enabled, but the number of responses is set to {num_responses}. " "Streaming is only supported for single response generation. " - "Setting the number of responses to 1.", - num_responses, + "Setting the number of responses to 1." ) + logger.warning(msg, num_responses=num_responses) generation_kwargs["num_return_sequences"] = 1 # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words) diff --git a/haystack/components/generators/chat/hugging_face_tgi.py b/haystack/components/generators/chat/hugging_face_tgi.py index 046e793f8..645298178 100644 --- a/haystack/components/generators/chat/hugging_face_tgi.py +++ b/haystack/components/generators/chat/hugging_face_tgi.py @@ -150,6 +150,7 @@ class HuggingFaceTGIChatGenerator: self.client = InferenceClient(url or model, token=token.resolve_value() if token else None) self.streaming_callback = streaming_callback self.tokenizer = None + self._warmed_up: bool = False def warm_up(self) -> None: """ @@ -158,6 +159,8 @@ class HuggingFaceTGIChatGenerator: If the url is not provided, check if the model is deployed on the free tier of the HF inference API. Load the tokenizer """ + if self._warmed_up: + return # is this user using HF free tier inference API? 
if self.model and not self.url: @@ -184,6 +187,8 @@ class HuggingFaceTGIChatGenerator: model=self.model, ) + self._warmed_up = True + def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. @@ -229,6 +234,10 @@ class HuggingFaceTGIChatGenerator: :param generation_kwargs: Additional keyword arguments for text generation. :return: A list containing the generated responses as ChatMessage instances. """ + if not self._warmed_up: + raise RuntimeError( + "The component HuggingFaceTGIChatGenerator was not warmed up. Please call warm_up() before running." + ) # check generation kwargs given as parameters to override the default ones additional_params = ["n", "stop_words"] diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py index f97951f46..0a3c6df34 100644 --- a/haystack/components/generators/hugging_face_local.py +++ b/haystack/components/generators/hugging_face_local.py @@ -140,10 +140,19 @@ class HuggingFaceLocalGenerator: return {"model": self.huggingface_pipeline_kwargs["model"]} return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"} + @property + def _warmed_up(self) -> bool: + if self.stop_words: + return (self.pipeline is not None) and (self.stopping_criteria_list is not None) + return self.pipeline is not None + def warm_up(self): """ Initializes the component. """ + if self._warmed_up: + return + if self.pipeline is None: self.pipeline = pipeline(**self.huggingface_pipeline_kwargs) @@ -209,8 +218,10 @@ class HuggingFaceLocalGenerator: A dictionary containing the generated replies. - replies: A list of strings representing the generated replies. """ - if self.pipeline is None: - raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.") + if not self._warmed_up: + raise RuntimeError( + "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running." + ) if not prompt: return {"replies": []} @@ -221,19 +232,19 @@ class HuggingFaceLocalGenerator: if self.streaming_callback: num_responses = updated_generation_kwargs.get("num_return_sequences", 1) if num_responses > 1: - logger.warning( - "Streaming is enabled, but the number of responses is set to %d. " + msg = ( + "Streaming is enabled, but the number of responses is set to {num_responses}. " "Streaming is only supported for single response generation. " - "Setting the number of responses to 1.", - num_responses, + "Setting the number of responses to 1." 
) + logger.warning(msg, num_responses=num_responses) updated_generation_kwargs["num_return_sequences"] = 1 # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming updated_generation_kwargs["streamer"] = HFTokenStreamingHandler( - self.pipeline.tokenizer, self.streaming_callback, self.stop_words + self.pipeline.tokenizer, self.streaming_callback, self.stop_words # type: ignore ) - output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) + output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) # type: ignore replies = [o["generated_text"] for o in output if "generated_text" in o] if self.stop_words: diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index f8fd23f94..fa3eeb855 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -140,6 +140,8 @@ class HuggingFaceTGIGenerator: """ Initializes the component. """ + if self.tokenizer is not None: + return # is this user using HF free tier inference API? if self.model and not self.url: @@ -205,6 +207,11 @@ class HuggingFaceTGIGenerator: A dictionary containing the generated replies and metadata. Both are lists of length n. - replies: A list of strings representing the generated replies. """ + if not self.tokenizer: + raise RuntimeError( + "The component HuggingFaceTGIGenerator was not warmed up. Please call warm_up() before running LLM inference." + ) + # check generation kwargs given as parameters to override the default ones additional_params = ["n", "stop_words"] check_generation_params(generation_kwargs, additional_params) @@ -214,9 +221,6 @@ class HuggingFaceTGIGenerator: num_responses = generation_kwargs.pop("n", 1) generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", [])) - if self.tokenizer is None: - raise RuntimeError("Please call warm_up() before running LLM inference.") - prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False)) if self.streaming_callback: diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index c1a216533..9ce0c0efb 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Literal, Optional -from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace @@ -231,7 +231,15 @@ class SentenceTransformersDiversityRanker: - `documents`: List of Document objects that have been selected based on the diversity ranking. :raises ValueError: If the top_k value is less than or equal to 0. + :raises RuntimeError: If the component has not been warmed up. """ + if self.model is None: + error_msg = ( + "The component SentenceTransformersDiversityRanker wasn't warmed up. " + "Run 'warm_up()' before calling 'run()'." 
+ ) + raise RuntimeError(error_msg) + if not documents: return {"documents": []} @@ -240,13 +248,6 @@ class SentenceTransformersDiversityRanker: elif top_k <= 0: raise ValueError(f"top_k must be > 0, but got {top_k}") - if self.model is None: - error_msg = ( - "The component SentenceTransformersDiversityRanker wasn't warmed up. " - "Run 'warm_up()' before calling 'run()'." - ) - raise ComponentError(error_msg) - diversity_sorted = self._greedy_diversity_order(query=query, documents=documents) return {"documents": diversity_sorted[:top_k]} diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py index 4d0a892bb..3b7db8ca6 100644 --- a/haystack/components/rankers/transformers_similarity.py +++ b/haystack/components/rankers/transformers_similarity.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs @@ -218,9 +218,15 @@ class TransformersSimilarityRanker: :raises ValueError: If `top_k` is not > 0. If `scale_score` is True and `calibration_factor` is not provided. - :raises ComponentError: + :raises RuntimeError: If the model is not loaded because `warm_up()` was not called before. """ + # If a model path is provided but the model isn't loaded + if self.model is None: + raise RuntimeError( + "The component TransformersSimilarityRanker wasn't warmed up. Run 'warm_up()' before calling 'run()'." + ) + if not documents: return {"documents": []} @@ -237,12 +243,6 @@ class TransformersSimilarityRanker: f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}" ) - # If a model path is provided but the model isn't loaded - if self.model is None: - raise ComponentError( - f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'." - ) - query_doc_pairs = [] for doc in documents: meta_values_to_embed = [ diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 29bee01d3..edf7864e6 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -7,7 +7,7 @@ import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging +from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs @@ -571,17 +571,19 @@ class ExtractiveReader: :returns: List of answers sorted by (desc.) answer score. - :raises ComponentError: + :raises RuntimeError: If the component was not warmed up by calling 'warm_up()' before. """ + if self.model is None: + raise RuntimeError( + "The component ExtractiveReader was not warmed up. Run 'warm_up()' before calling 'run()'." 
+ ) + if not documents: return {"answers": []} queries = [query] # Temporary solution until we have decided what batching should look like in v2 nested_documents = [documents] - if self.model is None: - raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.") - top_k = top_k or self.top_k score_threshold = score_threshold or self.score_threshold max_seq_length = max_seq_length or self.max_seq_length diff --git a/releasenotes/notes/make-warm-up-consistent-0247da81b155b136.yaml b/releasenotes/notes/make-warm-up-consistent-0247da81b155b136.yaml new file mode 100644 index 000000000..49cc2a91a --- /dev/null +++ b/releasenotes/notes/make-warm-up-consistent-0247da81b155b136.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Make `warm_up()` usage consistent across the codebase. diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index ff9569ebf..e5eebbd73 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -385,7 +385,7 @@ class TestHuggingFaceLocalGenerator: model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100} ) - with pytest.raises(RuntimeError, match="The generation model has not been loaded."): + with pytest.raises(RuntimeError, match="The component HuggingFaceLocalGenerator was not warmed up"): generator.run(prompt="irrelevant") def test_stop_words_criteria_with_a_mocked_tokenizer(self): @@ -424,6 +424,7 @@ class TestHuggingFaceLocalGenerator: model="google/flan-t5-small", task="text2text-generation", stop_words=["world"] ) generator.pipeline = Mock(return_value=[{"generated_text": "Hello world"}]) + generator.stopping_criteria_list = Mock() results = generator.run(prompt="irrelevant") assert results == {"replies": ["Hello"]} diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index b4885d327..ab794ffd3 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, call, patch import pytest import torch -from haystack import ComponentError, Document +from haystack import Document from haystack.components.rankers import SentenceTransformersDiversityRanker from haystack.utils import ComponentDevice from haystack.utils.auth import Secret @@ -228,7 +228,7 @@ class TestSentenceTransformersDiversityRanker: documents = [Document(content="doc1"), Document(content="doc2")] error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up." 
- with pytest.raises(ComponentError, match=error_msg): + with pytest.raises(RuntimeError, match=error_msg): ranker.run(query="test query", documents=documents) @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index 1d7315139..cec755ee2 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -343,7 +343,7 @@ class TestSimilarityRanker: @pytest.mark.integration def test_raises_component_error_if_model_not_warmed_up(self): sampler = TransformersSimilarityRanker() - with pytest.raises(ComponentError): + with pytest.raises(RuntimeError): sampler.run(query="query", documents=[Document(content="document")]) @pytest.mark.integration
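
Notes
-----

The changes above all converge on a single contract for lazily-initialized components: `warm_up()` is idempotent (it returns early when the resource is already loaded), and `run()` fails with a `RuntimeError` that names the component when `warm_up()` was never called. Below is a minimal sketch distilling that contract, assuming Haystack 2.x's `@component` decorator; `MyWarmUpComponent` and its lambda stand-in for a model are illustrative, not code from this patch:

```python
from typing import List

from haystack import component


@component
class MyWarmUpComponent:
    """Illustrative component following the warm-up contract this patch converges on."""

    def __init__(self):
        self._model = None  # heavy resource, created lazily in warm_up()

    def warm_up(self) -> None:
        # Idempotent: a second call (e.g. from a Pipeline that warms up
        # every component before running) is a cheap no-op.
        if self._model is not None:
            return
        # Stand-in for expensive model loading.
        self._model = lambda prompt: prompt.upper()

    @component.output_types(replies=List[str])
    def run(self, prompt: str):
        # Consistent failure mode across components: RuntimeError, not ComponentError.
        if self._model is None:
            raise RuntimeError(
                "The component MyWarmUpComponent was not warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )
        return {"replies": [self._model(prompt)]}
```

Raising `RuntimeError` instead of `ComponentError` in these guard clauses is also what lets the patch drop the now-unused `ComponentError` import from `whisper_local.py`, both rankers, and the extractive reader.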
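A smaller consistency fix rides along in both Hugging Face local generators: the warning about streaming with multiple responses moves from printf-style `%d` interpolation to a named `{num_responses}` placeholder passed as a keyword argument, matching the structured-logging style of `haystack.logging`. A hedged before/after sketch, assuming (as the files above do) that the module logger comes from `haystack.logging.getLogger`:

```python
from haystack import logging

logger = logging.getLogger(__name__)

num_responses = 3

# Before: stdlib printf-style formatting.
# logger.warning("...the number of responses is set to %d...", num_responses)

# After: named placeholder plus keyword argument, so a structured backend
# can record num_responses as a separate field instead of baking it into
# the rendered message.
logger.warning(
    "Streaming is enabled, but the number of responses is set to {num_responses}. "
    "Streaming is only supported for single response generation. "
    "Setting the number of responses to 1.",
    num_responses=num_responses,
)
```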
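The "fix regex match with the new error message" commit deserves a note: `pytest.raises(match=...)` runs `re.search` against the string of the raised exception, so patterns must stay regex-safe against the new messages, which is why the updated tests match on plain prefixes without metacharacters. A sketch of the safe idiom; the test name is illustrative, while the constructor arguments mirror the existing test above:

```python
import pytest

from haystack.components.generators import HuggingFaceLocalGenerator


def test_run_fails_without_warm_up():
    generator = HuggingFaceLocalGenerator(
        model="google/flan-t5-base", task="text2text-generation"
    )
    # `match` is a regex searched with re.search(); a plain substring with
    # no metacharacters is the safest pattern. Messages containing '(' or
    # '.' would need re.escape() before being passed here.
    with pytest.raises(RuntimeError, match="was not warmed up"):
        generator.run(prompt="irrelevant")
```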