fix: lazy import transformers in TGI Generators (#6252)

* lazy import transformers in tgi

* fix pylint
This commit is contained in:
Stefano Fiorucci 2023-11-08 00:09:42 +01:00 committed by GitHub
parent 5497ca2a45
commit e2881e2ad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 12 deletions

View File

@@ -3,14 +3,16 @@ from dataclasses import asdict
from typing import Any, Dict, List, Optional, Iterable, Callable
from urllib.parse import urlparse
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import ChatMessage, StreamingChunk
from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.lazy_imports import LazyImport
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
@@ -108,6 +110,8 @@ class HuggingFaceTGIChatGenerator:
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""
transformers_import.check()
if url:
r = urlparse(url)
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])

View File

@@ -1,8 +1,11 @@
import inspect
from typing import Any, Dict, List, Optional
from huggingface_hub import InferenceClient, HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.lazy_imports import LazyImport
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient, HfApi
from huggingface_hub.utils import RepositoryNotFoundError
def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
@@ -13,6 +16,8 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
:raises ValueError: If any unknown text generation parameters are provided.
"""
transformers_import.check()
if kwargs:
accepted_params = {
param
@@ -37,6 +42,8 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
:param token: An optional string representing the authentication token.
:raises ValueError: If the model is not found or is not a text generation model.
"""
transformers_import.check()
api = HfApi()
try:
model_info = api.model_info(model_id, token=token)

View File

@@ -3,14 +3,17 @@ from dataclasses import asdict
from typing import Any, Dict, List, Optional, Iterable, Callable
from urllib.parse import urlparse
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
@@ -90,6 +93,8 @@ class HuggingFaceTGIGenerator:
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""
transformers_import.check()
if url:
r = urlparse(url)
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])