mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-28 17:45:32 +00:00
reorganize imports in hf utils (#7414)
This commit is contained in:
parent
bfd0d3eacd
commit
41aa6f2b58
@ -14,7 +14,7 @@ from haystack.utils.device import ComponentDevice
|
|||||||
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
|
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
|
||||||
from huggingface_hub import HfApi, InferenceClient, model_info
|
from huggingface_hub import HfApi, InferenceClient, model_info
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
@ -120,7 +120,7 @@ def resolve_hf_pipeline_kwargs(
|
|||||||
:param token: The token to use as HTTP bearer authorization for remote files.
|
:param token: The token to use as HTTP bearer authorization for remote files.
|
||||||
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
|
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
|
||||||
"""
|
"""
|
||||||
transformers_import.check()
|
huggingface_hub_import.check()
|
||||||
|
|
||||||
token = token.resolve_value() if token else None
|
token = token.resolve_value() if token else None
|
||||||
# check if the huggingface_pipeline_kwargs contain the essential parameters
|
# check if the huggingface_pipeline_kwargs contain the essential parameters
|
||||||
@ -173,7 +173,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se
|
|||||||
:param token: The optional authentication token.
|
:param token: The optional authentication token.
|
||||||
:raises ValueError: If the model is not found or is not a embedding model.
|
:raises ValueError: If the model is not found or is not a embedding model.
|
||||||
"""
|
"""
|
||||||
transformers_import.check()
|
huggingface_hub_import.check()
|
||||||
|
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
try:
|
try:
|
||||||
@ -202,7 +202,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
|
|||||||
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
|
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
|
||||||
:raises ValueError: If any unknown text generation parameters are provided.
|
:raises ValueError: If any unknown text generation parameters are provided.
|
||||||
"""
|
"""
|
||||||
transformers_import.check()
|
huggingface_hub_import.check()
|
||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
accepted_params = {
|
accepted_params = {
|
||||||
@ -219,10 +219,11 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
|
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
|
||||||
|
|
||||||
torch_and_transformers_import.check()
|
torch_import.check()
|
||||||
|
transformers_import.check()
|
||||||
|
|
||||||
class StopWordsCriteria(StoppingCriteria):
|
class StopWordsCriteria(StoppingCriteria):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user