fix: put HFTokenStreamingHandler in a lazy_import block (#7403)

* put HFTokenStreamingHandler in a lazy_import block

* fix pylint
This commit is contained in:
Stefano Fiorucci 2024-03-22 09:43:45 +01:00 committed by GitHub
parent c789f905bc
commit d59444543a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -10,7 +10,7 @@ from haystack.utils import (
deserialize_secrets_inplace,
serialize_callable,
)
from haystack.utils.hf import HFTokenStreamingHandler, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
logger = logging.getLogger(__name__)
@ -19,7 +19,11 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from transformers import StoppingCriteriaList, pipeline
from haystack.utils.hf import StopWordsCriteria, resolve_hf_pipeline_kwargs # pylint: disable=ungrouped-imports
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
HFTokenStreamingHandler,
StopWordsCriteria,
resolve_hf_pipeline_kwargs,
)
@component

View File

@ -222,8 +222,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
transformers_import.check()
torch_import.check()
torch_and_transformers_import.check()
class StopWordsCriteria(StoppingCriteria):
"""