diff --git a/haystack/preview/components/generators/chat/hugging_face_tgi.py b/haystack/preview/components/generators/chat/hugging_face_tgi.py index 0e129915f..a5681e380 100644 --- a/haystack/preview/components/generators/chat/hugging_face_tgi.py +++ b/haystack/preview/components/generators/chat/hugging_face_tgi.py @@ -3,14 +3,16 @@ from dataclasses import asdict from typing import Any, Dict, List, Optional, Iterable, Callable from urllib.parse import urlparse -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token -from transformers import AutoTokenizer - from haystack.preview import component, default_to_dict, default_from_dict -from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler from haystack.preview.dataclasses import ChatMessage, StreamingChunk +from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install transformers'") as transformers_import: + from huggingface_hub import InferenceClient + from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token + from transformers import AutoTokenizer logger = logging.getLogger(__name__) @@ -108,6 +110,8 @@ class HuggingFaceTGIChatGenerator: :param stop_words: An optional list of strings representing the stop words. :param streaming_callback: An optional callable for handling streaming responses. """ + transformers_import.check() + if url: r = urlparse(url) is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) diff --git a/haystack/preview/components/generators/hf_utils.py b/haystack/preview/components/generators/hf_utils.py index 9107b4152..832a99628 100644 --- a/haystack/preview/components/generators/hf_utils.py +++ b/haystack/preview/components/generators/hf_utils.py @@ -1,8 +1,11 @@ import inspect from typing import Any, Dict, List, Optional -from huggingface_hub import InferenceClient, HfApi -from huggingface_hub.utils import RepositoryNotFoundError +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install transformers'") as transformers_import: + from huggingface_hub import InferenceClient, HfApi + from huggingface_hub.utils import RepositoryNotFoundError def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None): @@ -13,6 +16,8 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte :param additional_accepted_params: An optional list of strings representing additional accepted parameters. :raises ValueError: If any unknown text generation parameters are provided. """ + transformers_import.check() + if kwargs: accepted_params = { param @@ -37,6 +42,8 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None: :param token: An optional string representing the authentication token. :raises ValueError: If the model is not found or is not a text generation model. """ + transformers_import.check() + api = HfApi() try: model_info = api.model_info(model_id, token=token) diff --git a/haystack/preview/components/generators/hugging_face_tgi.py b/haystack/preview/components/generators/hugging_face_tgi.py index ed0d230ea..5d3526f03 100644 --- a/haystack/preview/components/generators/hugging_face_tgi.py +++ b/haystack/preview/components/generators/hugging_face_tgi.py @@ -3,14 +3,17 @@ from dataclasses import asdict from typing import Any, Dict, List, Optional, Iterable, Callable from urllib.parse import urlparse -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token -from transformers import AutoTokenizer - from haystack.preview import component, default_to_dict, default_from_dict -from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler from haystack.preview.dataclasses import StreamingChunk +from haystack.lazy_imports import LazyImport +from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model + +with LazyImport(message="Run 'pip install transformers'") as transformers_import: + from huggingface_hub import InferenceClient + from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token + from transformers import AutoTokenizer + logger = logging.getLogger(__name__) @@ -90,6 +93,8 @@ class HuggingFaceTGIGenerator: :param stop_words: An optional list of strings representing the stop words. :param streaming_callback: An optional callable for handling streaming responses. """ + transformers_import.check() + if url: r = urlparse(url) is_valid_url = all([r.scheme in ["http", "https"], r.netloc])