Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-26 16:46:58 +00:00)
chore: make warm_up() usage consistent (#7752)

* make usage consistent
* fix error type
* release notes
* pylint fix
* change of plan
* revert
* fix test
* revert
* fix HF tests
* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* fix formatting
* reformat
* fix regex match with the new error message
* fix integration test

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Parent: 15aa4217bd
Commit: 0ceeb733ba
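The diff below converges every component on the same contract: warm_up() is idempotent, and run() raises a RuntimeError that names the component if warm_up() was skipped. A minimal sketch of that convention (the component name and stand-in "model" are hypothetical, not taken from the diff):

```python
from typing import Optional

from haystack import component


@component
class EchoComponent:  # hypothetical component illustrating the convention
    def __init__(self):
        self._model: Optional[object] = None  # populated lazily by warm_up()

    def warm_up(self) -> None:
        if self._model is not None:  # idempotent: a second call is a no-op
            return
        self._model = object()  # stand-in for expensive model loading

    @component.output_types(result=str)
    def run(self, text: str):
        if self._model is None:
            # the error type and message shape this commit standardizes on
            raise RuntimeError(
                "The component EchoComponent was not warmed up. Run 'warm_up()' before calling 'run()'."
            )
        return {"result": text}
```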
@@ -6,7 +6,7 @@ import tempfile
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ByteStream
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice
@@ -113,7 +113,9 @@ class LocalWhisperTranscriber:
             alignment data and the path to the audio file used for the transcription.
         """
         if self._model is None:
-            raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.")
+            raise RuntimeError(
+                "The component LocalWhisperTranscriber was not warmed up. Run 'warm_up()' before calling 'run()'."
+            )
 
         if whisper_params is None:
             whisper_params = self.whisper_params
@@ -156,7 +158,7 @@ class LocalWhisperTranscriber:
             A dictionary mapping 'file_path' to 'transcription'.
         """
         if self._model is None:
-            raise ComponentError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
+            raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
 
         return_segments = kwargs.pop("return_segments", False)
         transcriptions: Dict[Path, Any] = {}
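For standalone use (outside a Pipeline), the transcriber must now be warmed up explicitly before run(). Roughly, assuming the 2.x API where run() accepts a list of audio sources ("sample.wav" and the "tiny" model size are placeholders):

```python
from haystack.components.audio import LocalWhisperTranscriber

transcriber = LocalWhisperTranscriber(model="tiny")
transcriber.warm_up()  # loads the whisper model; skipping this now raises RuntimeError in run()
result = transcriber.run(sources=["sample.wav"])
print(result["documents"][0].content)  # the transcription text
```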
@@ -116,6 +116,9 @@ class SASEvaluator:
         """
         Initializes the component.
         """
+        if self._similarity_model:
+            return
+
         token = self._token.resolve_value() if self._token else None
         config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
         cross_encoder_used = False
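SASEvaluator gains the same early return, making repeated warm-ups cheap no-ops. A sketch of what that buys callers (the evaluator's default model is loaded on the first call only):

```python
from haystack.components.evaluators import SASEvaluator

evaluator = SASEvaluator()  # defaults to a sentence-transformers similarity model
evaluator.warm_up()
evaluator.warm_up()  # second call returns immediately instead of reloading the model
```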
@@ -134,6 +134,7 @@ class NamedEntityExtractor:
             backend = NamedEntityExtractorBackend.from_str(backend)
 
         self._backend: _NerBackend
+        self._warmed_up: bool = False
         device = ComponentDevice.resolve_device(device)
 
         if backend == NamedEntityExtractorBackend.HUGGING_FACE:
@@ -150,8 +151,12 @@ class NamedEntityExtractor:
         :raises ComponentError:
             If the backend fails to initialize successfully.
         """
+        if self._warmed_up:
+            return
+
         try:
             self._backend.initialize()
+            self._warmed_up = True
         except Exception as e:
             raise ComponentError(
                 f"Named entity extractor with backend '{self._backend.type} failed to initialize."
@@ -171,6 +176,10 @@ class NamedEntityExtractor:
         :raises ComponentError:
             If the backend fails to process a document.
         """
+        if not self._warmed_up:
+            msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component."
+            raise RuntimeError(msg)
+
         texts = [doc.content if doc.content is not None else "" for doc in documents]
         annotations = self._backend.annotate(texts, batch_size=batch_size)
 
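NamedEntityExtractor tracks readiness with an explicit flag rather than a loaded-model attribute. Note that the flag is only set after the backend initializes successfully, so a failed warm_up() can simply be retried. The shape of that pattern, reduced to its essentials (names are illustrative, not the haystack class):

```python
class WarmableBackend:  # illustrative sketch of the flag-after-success pattern
    def __init__(self):
        self._warmed_up = False

    def warm_up(self):
        if self._warmed_up:
            return
        self._initialize()      # may raise; the flag stays False on failure
        self._warmed_up = True  # only reached after successful initialization

    def _initialize(self):
        pass  # stand-in for expensive backend setup
```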
@@ -265,12 +265,12 @@ class HuggingFaceLocalChatGenerator:
         if self.streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
-                logger.warning(
-                    "Streaming is enabled, but the number of responses is set to %d. "
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                     "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1.",
-                    num_responses,
+                    "Setting the number of responses to 1."
                 )
+                logger.warning(msg, num_responses=num_responses)
                 generation_kwargs["num_return_sequences"] = 1
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
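The logging call also changes shape: Haystack's logging wrapper expects structlog-style keyword arguments with `{name}` placeholders rather than printf-style `%d` with positional args. A minimal sketch of the corrected call, assuming the haystack.logging wrapper behaves like structlog here:

```python
from haystack import logging

logger = logging.getLogger(__name__)

num_responses = 3
msg = (
    "Streaming is enabled, but the number of responses is set to {num_responses}. "
    "Streaming is only supported for single response generation. "
    "Setting the number of responses to 1."
)
logger.warning(msg, num_responses=num_responses)  # placeholder filled from the keyword arg
```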
@@ -150,6 +150,7 @@ class HuggingFaceTGIChatGenerator:
         self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.streaming_callback = streaming_callback
         self.tokenizer = None
+        self._warmed_up: bool = False
 
     def warm_up(self) -> None:
         """
@@ -158,6 +159,8 @@ class HuggingFaceTGIChatGenerator:
         If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
         Load the tokenizer
         """
+        if self._warmed_up:
+            return
 
         # is this user using HF free tier inference API?
         if self.model and not self.url:
@@ -184,6 +187,8 @@ class HuggingFaceTGIChatGenerator:
                 model=self.model,
             )
 
+        self._warmed_up = True
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serialize this component to a dictionary.
@@ -229,6 +234,10 @@ class HuggingFaceTGIChatGenerator:
         :param generation_kwargs: Additional keyword arguments for text generation.
         :return: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._warmed_up:
+            raise RuntimeError(
+                "The component HuggingFaceTGIChatGenerator was not warmed up. Please call warm_up() before running."
+            )
 
         # check generation kwargs given as parameters to override the default ones
         additional_params = ["n", "stop_words"]
@@ -140,10 +140,19 @@ class HuggingFaceLocalGenerator:
             return {"model": self.huggingface_pipeline_kwargs["model"]}
         return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
 
+    @property
+    def _warmed_up(self) -> bool:
+        if self.stop_words:
+            return (self.pipeline is not None) and (self.stopping_criteria_list is not None)
+        return self.pipeline is not None
+
     def warm_up(self):
         """
         Initializes the component.
         """
+        if self._warmed_up:
+            return
+
         if self.pipeline is None:
             self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
 
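Unlike the components that store a boolean, HuggingFaceLocalGenerator derives `_warmed_up` from the state warm_up() actually produces: the pipeline, plus the stopping-criteria list when stop_words are configured. A derived property cannot drift out of sync with the resources it summarizes. In miniature (names are illustrative):

```python
class LazyResource:  # illustrative sketch of a derived warmed-up property
    def __init__(self, needs_extra: bool = False):
        self.needs_extra = needs_extra
        self.main_resource = None
        self.extra_resource = None

    @property
    def _warmed_up(self) -> bool:
        # readiness is computed from real state, so it can never go stale
        if self.needs_extra:
            return self.main_resource is not None and self.extra_resource is not None
        return self.main_resource is not None
```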
@@ -209,8 +218,10 @@ class HuggingFaceLocalGenerator:
             A dictionary containing the generated replies.
             - replies: A list of strings representing the generated replies.
         """
-        if self.pipeline is None:
-            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
+        if not self._warmed_up:
+            raise RuntimeError(
+                "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
+            )
 
         if not prompt:
             return {"replies": []}
@@ -221,19 +232,19 @@ class HuggingFaceLocalGenerator:
         if self.streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
-                logger.warning(
-                    "Streaming is enabled, but the number of responses is set to %d. "
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                     "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1.",
-                    num_responses,
+                    "Setting the number of responses to 1."
                 )
+                logger.warning(msg, num_responses=num_responses)
                 updated_generation_kwargs["num_return_sequences"] = 1
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
-                self.pipeline.tokenizer, self.streaming_callback, self.stop_words
+                self.pipeline.tokenizer, self.streaming_callback, self.stop_words  # type: ignore
             )
 
-        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
+        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)  # type: ignore
         replies = [o["generated_text"] for o in output if "generated_text" in o]
 
         if self.stop_words:
@@ -140,6 +140,8 @@ class HuggingFaceTGIGenerator:
         """
         Initializes the component.
         """
+        if self.tokenizer is not None:
+            return
 
         # is this user using HF free tier inference API?
         if self.model and not self.url:
@@ -205,6 +207,11 @@ class HuggingFaceTGIGenerator:
             A dictionary containing the generated replies and metadata. Both are lists of length n.
             - replies: A list of strings representing the generated replies.
         """
+        if not self.tokenizer:
+            raise RuntimeError(
+                "The component HuggingFaceTGIGenerator was not warmed up. Please call warm_up() before running LLM inference."
+            )
+
         # check generation kwargs given as parameters to override the default ones
         additional_params = ["n", "stop_words"]
         check_generation_params(generation_kwargs, additional_params)
@@ -214,9 +221,6 @@ class HuggingFaceTGIGenerator:
         num_responses = generation_kwargs.pop("n", 1)
         generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))
 
-        if self.tokenizer is None:
-            raise RuntimeError("Please call warm_up() before running LLM inference.")
-
         prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))
 
         if self.streaming_callback:
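HuggingFaceTGIGenerator uses `self.tokenizer` itself as the warmed-up sentinel, which is why the duplicate late check is deleted and the guard moves to the top of run(). A sentinel attribute gives the same behavior as a boolean flag without extra state; roughly (an illustrative stand-in class, not the haystack implementation):

```python
class TGILikeGenerator:  # illustrative: sentinel attribute instead of a boolean flag
    def __init__(self):
        self.tokenizer = None  # set by warm_up(); doubles as the "warmed up" marker

    def warm_up(self):
        if self.tokenizer is not None:
            return
        self.tokenizer = object()  # stand-in for AutoTokenizer.from_pretrained(...)

    def run(self, prompt: str):
        if not self.tokenizer:
            raise RuntimeError(
                "The component was not warmed up. Please call warm_up() before running LLM inference."
            )
        return {"replies": [prompt]}  # stand-in for real generation
```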
@@ -4,7 +4,7 @@
 
 from typing import Any, Dict, List, Literal, Optional
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
 
@@ -231,7 +231,15 @@ class SentenceTransformersDiversityRanker:
             - `documents`: List of Document objects that have been selected based on the diversity ranking.
 
         :raises ValueError: If the top_k value is less than or equal to 0.
+        :raises RuntimeError: If the component has not been warmed up.
         """
+        if self.model is None:
+            error_msg = (
+                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
+                "Run 'warm_up()' before calling 'run()'."
+            )
+            raise RuntimeError(error_msg)
+
         if not documents:
             return {"documents": []}
 
@@ -240,13 +248,6 @@ class SentenceTransformersDiversityRanker:
         elif top_k <= 0:
             raise ValueError(f"top_k must be > 0, but got {top_k}")
 
-        if self.model is None:
-            error_msg = (
-                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
-                "Run 'warm_up()' before calling 'run()'."
-            )
-            raise ComponentError(error_msg)
-
         diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
 
         return {"documents": diversity_sorted[:top_k]}
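Note the ordering change in the diversity ranker: the warm-up guard now runs before the top_k validation and empty-documents early return, so misuse is reported even when there is nothing to rank. After this commit, something like the following should raise (sketch, assuming no prior warm_up() call):

```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker()
try:
    ranker.run(query="test query", documents=[Document(content="doc1")])
except RuntimeError as err:
    print(err)  # "... wasn't warmed up. Run 'warm_up()' before calling 'run()'."
```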
@@ -5,7 +5,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
@@ -218,9 +218,15 @@ class TransformersSimilarityRanker:
         :raises ValueError:
             If `top_k` is not > 0.
             If `scale_score` is True and `calibration_factor` is not provided.
-        :raises ComponentError:
+        :raises RuntimeError:
             If the model is not loaded because `warm_up()` was not called before.
         """
+        # If a model path is provided but the model isn't loaded
+        if self.model is None:
+            raise RuntimeError(
+                "The component TransformersSimilarityRanker wasn't warmed up. Run 'warm_up()' before calling 'run()'."
+            )
+
         if not documents:
             return {"documents": []}
 
@@ -237,12 +243,6 @@ class TransformersSimilarityRanker:
                 f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
             )
 
-        # If a model path is provided but the model isn't loaded
-        if self.model is None:
-            raise ComponentError(
-                f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
-            )
-
         query_doc_pairs = []
         for doc in documents:
             meta_values_to_embed = [
@@ -7,7 +7,7 @@ import warnings
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
+from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
@@ -571,17 +571,19 @@ class ExtractiveReader:
         :returns:
             List of answers sorted by (desc.) answer score.
 
-        :raises ComponentError:
+        :raises RuntimeError:
             If the component was not warmed up by calling 'warm_up()' before.
         """
+        if self.model is None:
+            raise RuntimeError(
+                "The component ExtractiveReader was not warmed up. Run 'warm_up()' before calling 'run()'."
+            )
+
         if not documents:
             return {"answers": []}
 
         queries = [query]  # Temporary solution until we have decided what batching should look like in v2
         nested_documents = [documents]
-        if self.model is None:
-            raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.")
 
         top_k = top_k or self.top_k
         score_threshold = score_threshold or self.score_threshold
         max_seq_length = max_seq_length or self.max_seq_length
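Standalone callers of ExtractiveReader must warm up explicitly; inside a Pipeline, Haystack calls warm_up() on components that expose it when the pipeline runs, to the best of my knowledge of the 2.x behavior. A standalone sketch (query and document are illustrative):

```python
from haystack import Document
from haystack.components.readers import ExtractiveReader

reader = ExtractiveReader()
reader.warm_up()  # omitting this now raises RuntimeError from run()
result = reader.run(
    query="Who lives in Berlin?",
    documents=[Document(content="My name is Carla and I live in Berlin.")],
)
print(result["answers"][0].data)
```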
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Make `warm_up()` usage consistent across the codebase.
@@ -385,7 +385,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
         )
 
-        with pytest.raises(RuntimeError, match="The generation model has not been loaded."):
+        with pytest.raises(RuntimeError, match="The component HuggingFaceLocalGenerator was not warmed up"):
             generator.run(prompt="irrelevant")
 
     def test_stop_words_criteria_with_a_mocked_tokenizer(self):
@@ -424,6 +424,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-small", task="text2text-generation", stop_words=["world"]
         )
         generator.pipeline = Mock(return_value=[{"generated_text": "Hello world"}])
+        generator.stopping_criteria_list = Mock()
         results = generator.run(prompt="irrelevant")
         assert results == {"replies": ["Hello"]}
 
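The extra `generator.stopping_criteria_list = Mock()` is needed because `_warmed_up` is now a derived property: with stop_words configured, the component counts as warmed up only when both the pipeline and the stopping-criteria list are set, so a mocked test must fake both. Condensed from the test above:

```python
from unittest.mock import Mock

from haystack.components.generators import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    model="google/flan-t5-small", task="text2text-generation", stop_words=["world"]
)
generator.pipeline = Mock(return_value=[{"generated_text": "Hello world"}])
generator.stopping_criteria_list = Mock()  # without this, _warmed_up is False and run() raises
results = generator.run(prompt="irrelevant")
```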
@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, call, patch
 import pytest
 import torch
 
-from haystack import ComponentError, Document
+from haystack import Document
 from haystack.components.rankers import SentenceTransformersDiversityRanker
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
@@ -228,7 +228,7 @@ class TestSentenceTransformersDiversityRanker:
         documents = [Document(content="doc1"), Document(content="doc2")]
 
         error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up."
-        with pytest.raises(ComponentError, match=error_msg):
+        with pytest.raises(RuntimeError, match=error_msg):
             ranker.run(query="test query", documents=documents)
 
     @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
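The commit message mentions fixing a regex match against the new error message: `pytest.raises(match=...)` applies `re.search`, so messages containing regex metacharacters such as the parentheses in `warm_up()` must be escaped or trimmed (as the test above does by matching only the prefix). A safe pattern, as a sketch:

```python
import re

import pytest


def test_error_message_matches():
    expected = "The component SentenceTransformersDiversityRanker wasn't warmed up."
    # re.escape makes the literal message safe to use as a regex pattern
    with pytest.raises(RuntimeError, match=re.escape(expected)):
        raise RuntimeError(expected + " Run 'warm_up()' before calling 'run()'.")
```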
@@ -343,7 +343,7 @@ class TestSimilarityRanker:
     @pytest.mark.integration
     def test_raises_component_error_if_model_not_warmed_up(self):
         sampler = TransformersSimilarityRanker()
-        with pytest.raises(ComponentError):
+        with pytest.raises(RuntimeError):
             sampler.run(query="query", documents=[Document(content="document")])
 
     @pytest.mark.integration