mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-29 16:08:38 +00:00
fix: lazy import transformers in TGI Generators (#6252)
* lazy import transformers in tgi * fix pylint
This commit is contained in:
parent
5497ca2a45
commit
e2881e2ad3
@ -3,14 +3,16 @@ from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional, Iterable, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from haystack.preview import component, default_to_dict, default_from_dict
|
||||
from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params
|
||||
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
|
||||
from haystack.preview.dataclasses import ChatMessage, StreamingChunk
|
||||
from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -108,6 +110,8 @@ class HuggingFaceTGIChatGenerator:
|
||||
:param stop_words: An optional list of strings representing the stop words.
|
||||
:param streaming_callback: An optional callable for handling streaming responses.
|
||||
"""
|
||||
transformers_import.check()
|
||||
|
||||
if url:
|
||||
r = urlparse(url)
|
||||
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from huggingface_hub import InferenceClient, HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient, HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
|
||||
def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
|
||||
@ -13,6 +16,8 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
|
||||
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
|
||||
:raises ValueError: If any unknown text generation parameters are provided.
|
||||
"""
|
||||
transformers_import.check()
|
||||
|
||||
if kwargs:
|
||||
accepted_params = {
|
||||
param
|
||||
@ -37,6 +42,8 @@ def check_valid_model(model_id: str, token: Optional[str]) -> None:
|
||||
:param token: An optional string representing the authentication token.
|
||||
:raises ValueError: If the model is not found or is not a text generation model.
|
||||
"""
|
||||
transformers_import.check()
|
||||
|
||||
api = HfApi()
|
||||
try:
|
||||
model_info = api.model_info(model_id, token=token)
|
||||
|
||||
@ -3,14 +3,17 @@ from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional, Iterable, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from haystack.preview import component, default_to_dict, default_from_dict
|
||||
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model
|
||||
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
|
||||
from haystack.preview.dataclasses import StreamingChunk
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -90,6 +93,8 @@ class HuggingFaceTGIGenerator:
|
||||
:param stop_words: An optional list of strings representing the stop words.
|
||||
:param streaming_callback: An optional callable for handling streaming responses.
|
||||
"""
|
||||
transformers_import.check()
|
||||
|
||||
if url:
|
||||
r = urlparse(url)
|
||||
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user