reorganize imports in hf utils (#7414)

This commit is contained in:
Stefano Fiorucci 2024-03-25 11:41:16 +01:00 committed by GitHub
parent bfd0d3eacd
commit 41aa6f2b58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,7 +14,7 @@ from haystack.utils.device import ComponentDevice
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
import torch
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
from huggingface_hub import HfApi, InferenceClient, model_info
from huggingface_hub.utils import RepositoryNotFoundError
@ -120,7 +120,7 @@ def resolve_hf_pipeline_kwargs(
:param token: The token to use as HTTP bearer authorization for remote files.
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
"""
transformers_import.check()
huggingface_hub_import.check()
token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters
@ -173,7 +173,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se
:param token: The optional authentication token.
:raises ValueError: If the model is not found or is not an embedding model.
"""
transformers_import.check()
huggingface_hub_import.check()
api = HfApi()
try:
@ -202,7 +202,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
:raises ValueError: If any unknown text generation parameters are provided.
"""
transformers_import.check()
huggingface_hub_import.check()
if kwargs:
accepted_params = {
@ -219,10 +219,11 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
)
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
torch_and_transformers_import.check()
torch_import.check()
transformers_import.check()
class StopWordsCriteria(StoppingCriteria):
"""