chore: make warm_up() usage consistent (#7752)

* make warm_up() usage consistent

* fix error type

* release notes

* pylint fix

* change of plan

* revert

* fix test

* revert

* fix HF tests

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* fix formatting

* reformat

* fix regex match with the new error message

* fix integration test

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Massimiliano Pippi 2024-05-29 10:54:21 +02:00 committed by GitHub
parent 15aa4217bd
commit 0ceeb733ba
14 changed files with 89 additions and 43 deletions
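
The diff converges on one contract for every component: warm_up() is idempotent (a repeated call returns early instead of reloading the model), and run() raises a RuntimeError naming the component if warm_up() was never called. A minimal sketch of that contract, using a hypothetical MyComponent rather than any real Haystack class:

class MyComponent:
    """Hypothetical component illustrating the warm_up() contract in this PR."""

    def __init__(self):
        self._model = None

    def warm_up(self):
        # Idempotent: a second call is a no-op instead of reloading the model.
        if self._model is not None:
            return
        self._model = lambda text: text.upper()  # stand-in for an expensive model load

    def run(self, text: str):
        # Uniform failure mode: a RuntimeError that names the component.
        if self._model is None:
            raise RuntimeError(
                "The component MyComponent was not warmed up. Run 'warm_up()' before calling 'run()'."
            )
        return {"result": self._model(text)}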


@@ -6,7 +6,7 @@ import tempfile
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ByteStream
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice
@@ -113,7 +113,9 @@ class LocalWhisperTranscriber:
             alignment data and the path to the audio file used for the transcription.
         """
         if self._model is None:
-            raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.")
+            raise RuntimeError(
+                "The component LocalWhisperTranscriber was not warmed up. Run 'warm_up()' before calling 'run()'."
+            )
 
         if whisper_params is None:
             whisper_params = self.whisper_params
@@ -156,7 +158,7 @@ class LocalWhisperTranscriber:
             A dictionary mapping 'file_path' to 'transcription'.
         """
         if self._model is None:
-            raise ComponentError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
+            raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
         return_segments = kwargs.pop("return_segments", False)
         transcriptions: Dict[Path, Any] = {}


@@ -116,6 +116,9 @@ class SASEvaluator:
         """
         Initializes the component.
         """
+        if self._similarity_model:
+            return
+
         token = self._token.resolve_value() if self._token else None
         config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
         cross_encoder_used = False


@@ -134,6 +134,7 @@ class NamedEntityExtractor:
         backend = NamedEntityExtractorBackend.from_str(backend)
 
         self._backend: _NerBackend
+        self._warmed_up: bool = False
         device = ComponentDevice.resolve_device(device)
 
         if backend == NamedEntityExtractorBackend.HUGGING_FACE:
@@ -150,8 +151,12 @@
         :raises ComponentError:
             If the backend fails to initialize successfully.
         """
+        if self._warmed_up:
+            return
+
         try:
             self._backend.initialize()
+            self._warmed_up = True
         except Exception as e:
             raise ComponentError(
                 f"Named entity extractor with backend '{self._backend.type} failed to initialize."
@@ -171,6 +176,10 @@
         :raises ComponentError:
             If the backend fails to process a document.
         """
+        if not self._warmed_up:
+            msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component."
+            raise RuntimeError(msg)
+
         texts = [doc.content if doc.content is not None else "" for doc in documents]
         annotations = self._backend.annotate(texts, batch_size=batch_size)
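
NamedEntityExtractor has no attribute that starts out as None and gets filled in by warm_up() (the backend object is built in __init__), so the diff tracks readiness with an explicit boolean that is only set once initialize() succeeds. A minimal sketch of this flag-based variant, with a hypothetical Backend stand-in:

class Backend:
    """Hypothetical stand-in for a NER backend whose initialize() may raise."""

    def initialize(self):
        pass

class Extractor:
    def __init__(self):
        self._backend = Backend()
        self._warmed_up = False

    def warm_up(self):
        if self._warmed_up:
            return
        self._backend.initialize()
        self._warmed_up = True  # set only after initialization succeeded

    def run(self):
        if not self._warmed_up:
            raise RuntimeError("The component Extractor was not warmed up. Call warm_up() before running the component.")
        return {"annotations": []}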


@@ -265,12 +265,12 @@ class HuggingFaceLocalChatGenerator:
         if self.streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
-                logger.warning(
-                    "Streaming is enabled, but the number of responses is set to %d. "
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                     "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1.",
-                    num_responses,
+                    "Setting the number of responses to 1."
                 )
+                logger.warning(msg, num_responses=num_responses)
                 generation_kwargs["num_return_sequences"] = 1
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)


@@ -150,6 +150,7 @@ class HuggingFaceTGIChatGenerator:
         self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.streaming_callback = streaming_callback
         self.tokenizer = None
+        self._warmed_up: bool = False
 
     def warm_up(self) -> None:
         """
@@ -158,6 +159,8 @@
         If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
         Load the tokenizer
         """
+        if self._warmed_up:
+            return
 
         # is this user using HF free tier inference API?
         if self.model and not self.url:
@@ -184,6 +187,8 @@
                 model=self.model,
             )
 
+        self._warmed_up = True
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serialize this component to a dictionary.
@@ -229,6 +234,10 @@
         :param generation_kwargs: Additional keyword arguments for text generation.
         :return: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._warmed_up:
+            raise RuntimeError(
+                "The component HuggingFaceTGIChatGenerator was not warmed up. Please call warm_up() before running."
+            )
 
         # check generation kwargs given as parameters to override the default ones
         additional_params = ["n", "stop_words"]


@@ -140,10 +140,19 @@ class HuggingFaceLocalGenerator:
             return {"model": self.huggingface_pipeline_kwargs["model"]}
         return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
 
+    @property
+    def _warmed_up(self) -> bool:
+        if self.stop_words:
+            return (self.pipeline is not None) and (self.stopping_criteria_list is not None)
+        return self.pipeline is not None
+
     def warm_up(self):
         """
         Initializes the component.
         """
+        if self._warmed_up:
+            return
+
         if self.pipeline is None:
             self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
@@ -209,8 +218,10 @@
             A dictionary containing the generated replies.
             - replies: A list of strings representing the generated replies.
         """
-        if self.pipeline is None:
-            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
+        if not self._warmed_up:
+            raise RuntimeError(
+                "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
+            )
 
         if not prompt:
             return {"replies": []}
@@ -221,19 +232,19 @@
         if self.streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
-                logger.warning(
-                    "Streaming is enabled, but the number of responses is set to %d. "
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                     "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1.",
-                    num_responses,
+                    "Setting the number of responses to 1."
                 )
+                logger.warning(msg, num_responses=num_responses)
                 updated_generation_kwargs["num_return_sequences"] = 1
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
-                self.pipeline.tokenizer, self.streaming_callback, self.stop_words
+                self.pipeline.tokenizer, self.streaming_callback, self.stop_words  # type: ignore
             )
 
-        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
+        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)  # type: ignore
 
         replies = [o["generated_text"] for o in output if "generated_text" in o]
         if self.stop_words:
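
HuggingFaceLocalGenerator derives its warmed-up state instead of storing a flag: when stop words are configured, both the pipeline and the stopping criteria list must exist before run() may proceed, which is also why the mocked test further down now sets stopping_criteria_list as well. A minimal sketch of the property-based variant, with hypothetical attribute types:

class Generator:
    """Hypothetical skeleton showing the derived-state variant."""

    def __init__(self, stop_words=None):
        self.stop_words = stop_words
        self.pipeline = None
        self.stopping_criteria_list = None

    @property
    def _warmed_up(self) -> bool:
        # With stop words, warm_up() must have produced both attributes.
        if self.stop_words:
            return (self.pipeline is not None) and (self.stopping_criteria_list is not None)
        return self.pipeline is not None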


@@ -140,6 +140,8 @@ class HuggingFaceTGIGenerator:
         """
         Initializes the component.
         """
+        if self.tokenizer is not None:
+            return
 
         # is this user using HF free tier inference API?
         if self.model and not self.url:
@@ -205,6 +207,11 @@
             A dictionary containing the generated replies and metadata. Both are lists of length n.
             - replies: A list of strings representing the generated replies.
         """
+        if not self.tokenizer:
+            raise RuntimeError(
+                "The component HuggingFaceTGIGenerator was not warmed up. Please call warm_up() before running LLM inference."
+            )
+
         # check generation kwargs given as parameters to override the default ones
         additional_params = ["n", "stop_words"]
         check_generation_params(generation_kwargs, additional_params)
@@ -214,9 +221,6 @@
         num_responses = generation_kwargs.pop("n", 1)
         generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))
 
-        if self.tokenizer is None:
-            raise RuntimeError("Please call warm_up() before running LLM inference.")
-
         prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))
 
         if self.streaming_callback:


@@ -4,7 +4,7 @@
 
 from typing import Any, Dict, List, Literal, Optional
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
@@ -231,7 +231,15 @@ class SentenceTransformersDiversityRanker:
             - `documents`: List of Document objects that have been selected based on the diversity ranking.
 
         :raises ValueError: If the top_k value is less than or equal to 0.
+        :raises RuntimeError: If the component has not been warmed up.
         """
+        if self.model is None:
+            error_msg = (
+                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
+                "Run 'warm_up()' before calling 'run()'."
+            )
+            raise RuntimeError(error_msg)
+
         if not documents:
             return {"documents": []}
@@ -240,13 +248,6 @@
         elif top_k <= 0:
             raise ValueError(f"top_k must be > 0, but got {top_k}")
 
-        if self.model is None:
-            error_msg = (
-                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
-                "Run 'warm_up()' before calling 'run()'."
-            )
-            raise ComponentError(error_msg)
-
         diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
 
         return {"documents": diversity_sorted[:top_k]}


@@ -5,7 +5,7 @@
 
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
@@ -218,9 +218,15 @@ class TransformersSimilarityRanker:
         :raises ValueError:
             If `top_k` is not > 0.
             If `scale_score` is True and `calibration_factor` is not provided.
-        :raises ComponentError:
+        :raises RuntimeError:
             If the model is not loaded because `warm_up()` was not called before.
         """
+        # If a model path is provided but the model isn't loaded
+        if self.model is None:
+            raise RuntimeError(
+                "The component TransformersSimilarityRanker wasn't warmed up. Run 'warm_up()' before calling 'run()'."
+            )
+
         if not documents:
             return {"documents": []}
@@ -237,12 +243,6 @@
                 f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
             )
 
-        # If a model path is provided but the model isn't loaded
-        if self.model is None:
-            raise ComponentError(
-                f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
-            )
-
         query_doc_pairs = []
         for doc in documents:
             meta_values_to_embed = [


@@ -7,7 +7,7 @@ import warnings
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
+from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
@@ -571,17 +571,19 @@ class ExtractiveReader:
         :returns:
             List of answers sorted by (desc.) answer score.
-        :raises ComponentError:
+        :raises RuntimeError:
             If the component was not warmed up by calling 'warm_up()' before.
         """
+        if self.model is None:
+            raise RuntimeError(
+                "The component ExtractiveReader was not warmed up. Run 'warm_up()' before calling 'run()'."
+            )
+
         if not documents:
             return {"answers": []}
 
         queries = [query]  # Temporary solution until we have decided what batching should look like in v2
         nested_documents = [documents]
 
-        if self.model is None:
-            raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.")
-
         top_k = top_k or self.top_k
         score_threshold = score_threshold or self.score_threshold
         max_seq_length = max_seq_length or self.max_seq_length


@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Make `warm_up()` usage consistent across the codebase.
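
From the caller's side nothing changes except the failure mode, which is now uniform across components. A usage sketch against one of the touched components, assuming haystack-ai as of this commit:

from haystack import Document
from haystack.components.readers import ExtractiveReader

reader = ExtractiveReader()
try:
    reader.run(query="Who?", documents=[Document(content="My name is Mara.")])
except RuntimeError as e:
    print(e)  # The component ExtractiveReader was not warmed up. Run 'warm_up()' before calling 'run()'.

reader.warm_up()  # loads the model
reader.warm_up()  # intended to be a no-op under this PR's contract
result = reader.run(query="Who?", documents=[Document(content="My name is Mara.")])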


@@ -385,7 +385,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
         )
 
-        with pytest.raises(RuntimeError, match="The generation model has not been loaded."):
+        with pytest.raises(RuntimeError, match="The component HuggingFaceLocalGenerator was not warmed up"):
             generator.run(prompt="irrelevant")
 
     def test_stop_words_criteria_with_a_mocked_tokenizer(self):
@@ -424,6 +424,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-small", task="text2text-generation", stop_words=["world"]
         )
         generator.pipeline = Mock(return_value=[{"generated_text": "Hello world"}])
+        generator.stopping_criteria_list = Mock()
         results = generator.run(prompt="irrelevant")
         assert results == {"replies": ["Hello"]}


@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, call, patch
 import pytest
 import torch
 
-from haystack import ComponentError, Document
+from haystack import Document
 from haystack.components.rankers import SentenceTransformersDiversityRanker
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
@@ -228,7 +228,7 @@ class TestSentenceTransformersDiversityRanker:
         documents = [Document(content="doc1"), Document(content="doc2")]
         error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up."
 
-        with pytest.raises(ComponentError, match=error_msg):
+        with pytest.raises(RuntimeError, match=error_msg):
             ranker.run(query="test query", documents=documents)
 
     @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])


@@ -343,7 +343,7 @@ class TestSimilarityRanker:
     @pytest.mark.integration
     def test_raises_component_error_if_model_not_warmed_up(self):
         sampler = TransformersSimilarityRanker()
-        with pytest.raises(ComponentError):
+        with pytest.raises(RuntimeError):
             sampler.run(query="query", documents=[Document(content="document")])
 
     @pytest.mark.integration