From d59444543a3eaf1a3abebc78b58f95b397bce0cb Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 22 Mar 2024 09:43:45 +0100 Subject: [PATCH] fix: put `HFTokenStreamingHandler` in a lazy_import block (#7403) * put HFTokenStreamingHandler in a lazy_import block * fix pylint --- haystack/components/generators/hugging_face_local.py | 8 ++++++-- haystack/utils/hf.py | 3 +-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py index 33cc69063..9e0c607de 100644 --- a/haystack/components/generators/hugging_face_local.py +++ b/haystack/components/generators/hugging_face_local.py @@ -10,7 +10,7 @@ from haystack.utils import ( deserialize_secrets_inplace, serialize_callable, ) -from haystack.utils.hf import HFTokenStreamingHandler, deserialize_hf_model_kwargs, serialize_hf_model_kwargs +from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs logger = logging.getLogger(__name__) @@ -19,7 +19,11 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"] with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import: from transformers import StoppingCriteriaList, pipeline - from haystack.utils.hf import StopWordsCriteria, resolve_hf_pipeline_kwargs # pylint: disable=ungrouped-imports + from haystack.utils.hf import ( # pylint: disable=ungrouped-imports + HFTokenStreamingHandler, + StopWordsCriteria, + resolve_hf_pipeline_kwargs, + ) @component diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 6e53e3bfc..b3afe20c1 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -222,8 +222,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer - transformers_import.check() - torch_import.check() + torch_and_transformers_import.check() class StopWordsCriteria(StoppingCriteria): """