Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-26 16:46:58 +00:00)
chore: make warm_up() usage consistent (#7752)
* make usage consistent
* fix error type
* release notes
* pylint fix
* change of plan
* revert
* fix test
* revert
* fix HF tests
* Apply suggestions from code review

  Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* fix formatting
* reformat
* fix regex match with the new error message
* fix integration test

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
parent 15aa4217bd
commit 0ceeb733ba
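The hunks below apply the same two conventions throughout. As a minimal sketch (the ExampleComponent here is illustrative, not part of this diff): warm_up() becomes idempotent via an early return, and run() raises RuntimeError, naming the component, when warm_up() was never called.

    from typing import Optional


    class ExampleComponent:
        """Illustrative stand-in for the components touched by this commit."""

        def __init__(self, model_name: str):
            self.model_name = model_name
            self._model: Optional[object] = None  # populated by warm_up()

        def warm_up(self) -> None:
            # Idempotent: a second call must not reload the model.
            if self._model is not None:
                return
            self._model = object()  # stand-in for the expensive model load

        def run(self, text: str) -> dict:
            # Consistent failure mode: RuntimeError (not ComponentError),
            # with the component name spelled out in the message.
            if self._model is None:
                raise RuntimeError(
                    "The component ExampleComponent was not warmed up. Run 'warm_up()' before calling 'run()'."
                )
            return {"result": text}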
@@ -6,7 +6,7 @@ import tempfile
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ByteStream
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice
@@ -113,7 +113,9 @@ class LocalWhisperTranscriber:
             alignment data and the path to the audio file used for the transcription.
         """
         if self._model is None:
-            raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.")
+            raise RuntimeError(
+                "The component LocalWhisperTranscriber was not warmed up. Run 'warm_up()' before calling 'run()'."
+            )
 
         if whisper_params is None:
             whisper_params = self.whisper_params
@@ -156,7 +158,7 @@ class LocalWhisperTranscriber:
             A dictionary mapping 'file_path' to 'transcription'.
         """
         if self._model is None:
-            raise ComponentError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
+            raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
 
         return_segments = kwargs.pop("return_segments", False)
         transcriptions: Dict[Path, Any] = {}
@@ -116,6 +116,9 @@ class SASEvaluator:
         """
         Initializes the component.
         """
+        if self._similarity_model:
+            return
+
         token = self._token.resolve_value() if self._token else None
         config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
         cross_encoder_used = False
@@ -134,6 +134,7 @@ class NamedEntityExtractor:
         backend = NamedEntityExtractorBackend.from_str(backend)
 
         self._backend: _NerBackend
+        self._warmed_up: bool = False
         device = ComponentDevice.resolve_device(device)
 
         if backend == NamedEntityExtractorBackend.HUGGING_FACE:
@@ -150,8 +151,12 @@ class NamedEntityExtractor:
         :raises ComponentError:
             If the backend fails to initialize successfully.
         """
+        if self._warmed_up:
+            return
+
         try:
             self._backend.initialize()
+            self._warmed_up = True
         except Exception as e:
             raise ComponentError(
                 f"Named entity extractor with backend '{self._backend.type} failed to initialize."
@@ -171,6 +176,10 @@ class NamedEntityExtractor:
         :raises ComponentError:
             If the backend fails to process a document.
         """
+        if not self._warmed_up:
+            msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component."
+            raise RuntimeError(msg)
+
         texts = [doc.content if doc.content is not None else "" for doc in documents]
         annotations = self._backend.annotate(texts, batch_size=batch_size)
 
@@ -265,12 +265,12 @@ class HuggingFaceLocalChatGenerator:
         if self.streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
-                logger.warning(
-                    "Streaming is enabled, but the number of responses is set to %d. "
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                     "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1.",
-                    num_responses,
+                    "Setting the number of responses to 1."
                 )
+                logger.warning(msg, num_responses=num_responses)
                 generation_kwargs["num_return_sequences"] = 1
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
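A note on the logging change above: the message now carries a {num_responses} placeholder and the value moves into a keyword argument, which matches the structured-logging style of Haystack's logging wrapper rather than printf-style %d interpolation. A standalone sketch of the call shape (the _DemoLogger below is a stand-in; the stdlib logging module would not format {num_responses} from kwargs):

    class _DemoLogger:
        """Stand-in that formats {placeholders} from kwargs, as a structured logger does."""

        def warning(self, msg: str, **kwargs) -> None:
            print("WARNING:", msg.format(**kwargs))


    logger = _DemoLogger()
    logger.warning(
        "Streaming is enabled, but the number of responses is set to {num_responses}. "
        "Streaming is only supported for single response generation. "
        "Setting the number of responses to 1.",
        num_responses=3,
    )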
@@ -150,6 +150,7 @@ class HuggingFaceTGIChatGenerator:
         self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.streaming_callback = streaming_callback
         self.tokenizer = None
+        self._warmed_up: bool = False
 
     def warm_up(self) -> None:
         """
@@ -158,6 +159,8 @@ class HuggingFaceTGIChatGenerator:
         If the url is not provided, check if the model is deployed on the free tier of the HF inference API.
         Load the tokenizer
         """
+        if self._warmed_up:
+            return
 
         # is this user using HF free tier inference API?
         if self.model and not self.url:
@@ -184,6 +187,8 @@ class HuggingFaceTGIChatGenerator:
                 model=self.model,
             )
 
+        self._warmed_up = True
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serialize this component to a dictionary.
@@ -229,6 +234,10 @@ class HuggingFaceTGIChatGenerator:
         :param generation_kwargs: Additional keyword arguments for text generation.
         :return: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._warmed_up:
+            raise RuntimeError(
+                "The component HuggingFaceTGIChatGenerator was not warmed up. Please call warm_up() before running."
+            )
 
         # check generation kwargs given as parameters to override the default ones
         additional_params = ["n", "stop_words"]
@@ -140,10 +140,19 @@ class HuggingFaceLocalGenerator:
             return {"model": self.huggingface_pipeline_kwargs["model"]}
         return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
 
+    @property
+    def _warmed_up(self) -> bool:
+        if self.stop_words:
+            return (self.pipeline is not None) and (self.stopping_criteria_list is not None)
+        return self.pipeline is not None
+
     def warm_up(self):
         """
         Initializes the component.
         """
+        if self._warmed_up:
+            return
+
         if self.pipeline is None:
             self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
 
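Unlike the TGI generators above, which add an explicit _warmed_up flag, HuggingFaceLocalGenerator derives the warm state from the fields warm_up() is responsible for populating, so the flag can never drift out of sync with the loaded pipeline. The same idea in isolation (names here are illustrative, not part of the diff):

    from typing import List, Optional


    class _Sketch:
        def __init__(self, stop_words: Optional[List[str]] = None):
            self.stop_words = stop_words
            self.pipeline: Optional[object] = None
            self.stopping_criteria_list: Optional[object] = None

        @property
        def _warmed_up(self) -> bool:
            # With stop words configured, warm_up() must also have built the
            # stopping criteria, so both fields are required.
            if self.stop_words:
                return self.pipeline is not None and self.stopping_criteria_list is not None
            return self.pipeline is not None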
@@ -209,8 +218,10 @@ class HuggingFaceLocalGenerator:
             A dictionary containing the generated replies.
             - replies: A list of strings representing the generated replies.
         """
-        if self.pipeline is None:
-            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
+        if not self._warmed_up:
+            raise RuntimeError(
+                "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
+            )
 
         if not prompt:
             return {"replies": []}
@@ -221,19 +232,19 @@ class HuggingFaceLocalGenerator:
         if self.streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
-                logger.warning(
-                    "Streaming is enabled, but the number of responses is set to %d. "
+                msg = (
+                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                     "Streaming is only supported for single response generation. "
-                    "Setting the number of responses to 1.",
-                    num_responses,
+                    "Setting the number of responses to 1."
                 )
+                logger.warning(msg, num_responses=num_responses)
                 updated_generation_kwargs["num_return_sequences"] = 1
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
-                self.pipeline.tokenizer, self.streaming_callback, self.stop_words
+                self.pipeline.tokenizer, self.streaming_callback, self.stop_words  # type: ignore
             )
 
-        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
+        output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)  # type: ignore
         replies = [o["generated_text"] for o in output if "generated_text" in o]
 
         if self.stop_words:
@@ -140,6 +140,8 @@ class HuggingFaceTGIGenerator:
         """
         Initializes the component.
         """
+        if self.tokenizer is not None:
+            return
 
         # is this user using HF free tier inference API?
         if self.model and not self.url:
@@ -205,6 +207,11 @@ class HuggingFaceTGIGenerator:
             A dictionary containing the generated replies and metadata. Both are lists of length n.
             - replies: A list of strings representing the generated replies.
         """
+        if not self.tokenizer:
+            raise RuntimeError(
+                "The component HuggingFaceTGIGenerator was not warmed up. Please call warm_up() before running LLM inference."
+            )
+
         # check generation kwargs given as parameters to override the default ones
         additional_params = ["n", "stop_words"]
         check_generation_params(generation_kwargs, additional_params)
@@ -214,9 +221,6 @@ class HuggingFaceTGIGenerator:
         num_responses = generation_kwargs.pop("n", 1)
         generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))
 
-        if self.tokenizer is None:
-            raise RuntimeError("Please call warm_up() before running LLM inference.")
-
         prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))
 
         if self.streaming_callback:
@@ -4,7 +4,7 @@
 
 from typing import Any, Dict, List, Literal, Optional
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
 
@@ -231,7 +231,15 @@ class SentenceTransformersDiversityRanker:
             - `documents`: List of Document objects that have been selected based on the diversity ranking.
 
         :raises ValueError: If the top_k value is less than or equal to 0.
+        :raises RuntimeError: If the component has not been warmed up.
         """
+        if self.model is None:
+            error_msg = (
+                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
+                "Run 'warm_up()' before calling 'run()'."
+            )
+            raise RuntimeError(error_msg)
+
         if not documents:
             return {"documents": []}
 
@@ -240,13 +248,6 @@ class SentenceTransformersDiversityRanker:
         elif top_k <= 0:
             raise ValueError(f"top_k must be > 0, but got {top_k}")
 
-        if self.model is None:
-            error_msg = (
-                "The component SentenceTransformersDiversityRanker wasn't warmed up. "
-                "Run 'warm_up()' before calling 'run()'."
-            )
-            raise ComponentError(error_msg)
-
         diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
 
         return {"documents": diversity_sorted[:top_k]}
@@ -5,7 +5,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
+from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
@@ -218,9 +218,15 @@ class TransformersSimilarityRanker:
         :raises ValueError:
             If `top_k` is not > 0.
             If `scale_score` is True and `calibration_factor` is not provided.
-        :raises ComponentError:
+        :raises RuntimeError:
             If the model is not loaded because `warm_up()` was not called before.
         """
+        # If a model path is provided but the model isn't loaded
+        if self.model is None:
+            raise RuntimeError(
+                "The component TransformersSimilarityRanker wasn't warmed up. Run 'warm_up()' before calling 'run()'."
+            )
+
         if not documents:
             return {"documents": []}
 
@@ -237,12 +243,6 @@ class TransformersSimilarityRanker:
                     f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
                 )
 
-        # If a model path is provided but the model isn't loaded
-        if self.model is None:
-            raise ComponentError(
-                f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'."
-            )
-
         query_doc_pairs = []
         for doc in documents:
             meta_values_to_embed = [
@@ -7,7 +7,7 @@ import warnings
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
+from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, DeviceMap, Secret, deserialize_secrets_inplace
 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs
@@ -571,17 +571,19 @@ class ExtractiveReader:
         :returns:
             List of answers sorted by (desc.) answer score.
 
-        :raises ComponentError:
+        :raises RuntimeError:
             If the component was not warmed up by calling 'warm_up()' before.
         """
+        if self.model is None:
+            raise RuntimeError(
+                "The component ExtractiveReader was not warmed up. Run 'warm_up()' before calling 'run()'."
+            )
+
         if not documents:
             return {"answers": []}
 
         queries = [query]  # Temporary solution until we have decided what batching should look like in v2
         nested_documents = [documents]
-        if self.model is None:
-            raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.")
-
         top_k = top_k or self.top_k
         score_threshold = score_threshold or self.score_threshold
         max_seq_length = max_seq_length or self.max_seq_length
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Make `warm_up()` usage consistent across the codebase.
@@ -385,7 +385,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
         )
 
-        with pytest.raises(RuntimeError, match="The generation model has not been loaded."):
+        with pytest.raises(RuntimeError, match="The component HuggingFaceLocalGenerator was not warmed up"):
             generator.run(prompt="irrelevant")
 
     def test_stop_words_criteria_with_a_mocked_tokenizer(self):
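The test above only needed its pattern updated because pytest.raises(..., match=...) applies the pattern as a regular-expression search against the string form of the exception, so a substring of the new message is enough. A standalone sketch:

    import pytest


    def test_match_is_a_regex_search():
        # match= does a re.search against str(exception), so a fragment of
        # the new error message is sufficient to match.
        with pytest.raises(RuntimeError, match="was not warmed up"):
            raise RuntimeError(
                "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
            )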
@@ -424,6 +424,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-small", task="text2text-generation", stop_words=["world"]
         )
         generator.pipeline = Mock(return_value=[{"generated_text": "Hello world"}])
+        generator.stopping_criteria_list = Mock()
         results = generator.run(prompt="irrelevant")
         assert results == {"replies": ["Hello"]}
 
@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, call, patch
 import pytest
 import torch
 
-from haystack import ComponentError, Document
+from haystack import Document
 from haystack.components.rankers import SentenceTransformersDiversityRanker
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
@@ -228,7 +228,7 @@ class TestSentenceTransformersDiversityRanker:
         documents = [Document(content="doc1"), Document(content="doc2")]
 
         error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up."
-        with pytest.raises(ComponentError, match=error_msg):
+        with pytest.raises(RuntimeError, match=error_msg):
             ranker.run(query="test query", documents=documents)
 
     @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
@@ -343,7 +343,7 @@ class TestSimilarityRanker:
     @pytest.mark.integration
     def test_raises_component_error_if_model_not_warmed_up(self):
         sampler = TransformersSimilarityRanker()
-        with pytest.raises(ComponentError):
+        with pytest.raises(RuntimeError):
             sampler.run(query="query", documents=[Document(content="document")])
 
     @pytest.mark.integration
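Taken together, the commit standardizes the standalone-usage contract: call warm_up() once before run(), and expect RuntimeError otherwise. A hedged usage sketch with one of the touched components (this assumes the haystack-ai package is installed; the model downloads on warm_up(), and running inside a Pipeline typically warms components up for you):

    from haystack import Document
    from haystack.components.readers import ExtractiveReader

    reader = ExtractiveReader(model="deepset/roberta-base-squad2")
    reader.warm_up()  # idempotent after this commit: a second call is a no-op
    result = reader.run(
        query="Who lives in Berlin?",
        documents=[Document(content="My friend Anna lives in Berlin.")],
    )
    print(result["answers"][0].data)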