fix: put HFTokenStreamingHandler in a lazy_import block (#7403)

* put HFTokenStreamingHandler in a lazy_import block

* fix pylint
This commit is contained in:
Stefano Fiorucci 2024-03-22 09:43:45 +01:00 committed by GitHub
parent c789f905bc
commit d59444543a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -10,7 +10,7 @@ from haystack.utils import (
deserialize_secrets_inplace, deserialize_secrets_inplace,
serialize_callable, serialize_callable,
) )
from haystack.utils.hf import HFTokenStreamingHandler, deserialize_hf_model_kwargs, serialize_hf_model_kwargs from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,7 +19,11 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import: with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from transformers import StoppingCriteriaList, pipeline from transformers import StoppingCriteriaList, pipeline
from haystack.utils.hf import StopWordsCriteria, resolve_hf_pipeline_kwargs # pylint: disable=ungrouped-imports from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
HFTokenStreamingHandler,
StopWordsCriteria,
resolve_hf_pipeline_kwargs,
)
@component @component

View File

@ -222,8 +222,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import: with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
transformers_import.check() torch_and_transformers_import.check()
torch_import.check()
class StopWordsCriteria(StoppingCriteria): class StopWordsCriteria(StoppingCriteria):
""" """