Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-08 04:56:45 +00:00
chore: merge hf utils modules into one (#6921)
* merge hf utils modules
* relnotes
* lint
* Update releasenotes/notes/merge-hf-utils-modules-5c16e04025123568.yaml

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
parent b9d7a98359
commit 7d29ddba42
@@ -1,32 +0,0 @@
-from typing import Optional
-
-from haystack.lazy_imports import LazyImport
-from haystack.utils.auth import Secret
-
-with LazyImport(message="Run 'pip install transformers'") as transformers_import:
-    from huggingface_hub import HfApi
-    from huggingface_hub.utils import RepositoryNotFoundError
-
-
-def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
-    """
-    Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
-    Also check if the model is a embedding model.
-
-    :param model_id: A string representing the HuggingFace model ID.
-    :param token: The optional authentication token.
-    :raises ValueError: If the model is not found or is not a embedding model.
-    """
-    transformers_import.check()
-
-    api = HfApi()
-    try:
-        model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
-    except RepositoryNotFoundError as e:
-        raise ValueError(
-            f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
-        ) from e
-
-    allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
-    if not allowed_model:
-        raise ValueError(f"Model {model_id} is not a embedding model. Please provide a embedding model.")
@@ -4,10 +4,10 @@ from urllib.parse import urlparse
 
 from tqdm import tqdm
 
-from haystack.components.embedders.hf_utils import check_valid_model
 from haystack.dataclasses import Document
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack.utils.hf import check_valid_model, HFModelType
 from haystack import component, default_to_dict, default_from_dict
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
@@ -116,7 +116,7 @@ class HuggingFaceTEIDocumentEmbedder:
         if not is_valid_url:
             raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
 
-        check_valid_model(model, token)
+        check_valid_model(model, HFModelType.EMBEDDING, token)
 
         self.model = model
         self.url = url
@@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse
 
 from haystack import component, default_to_dict, default_from_dict
-from haystack.components.embedders.hf_utils import check_valid_model
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack.utils.hf import check_valid_model, HFModelType
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient
@@ -98,7 +98,7 @@ class HuggingFaceTEITextEmbedder:
         if not is_valid_url:
             raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
 
-        check_valid_model(model, token)
+        check_valid_model(model, HFModelType.EMBEDDING, token)
 
         self.model = model
         self.url = url
@@ -2,8 +2,6 @@ import logging
 import sys
 from typing import Any, Dict, List, Literal, Optional, Union, Callable
 
-from haystack.components.generators.hf_utils import PIPELINE_SUPPORTED_TASKS
-
 from haystack import component, default_to_dict, default_from_dict
 from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
 from haystack.dataclasses import ChatMessage, StreamingChunk
@@ -16,11 +14,15 @@ logger = logging.getLogger(__name__)
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
     from huggingface_hub import model_info
     from transformers import StoppingCriteriaList, pipeline, PreTrainedTokenizer, PreTrainedTokenizerFast
-    from haystack.components.generators.hf_utils import (  # pylint: disable=ungrouped-imports
+    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
         StopWordsCriteria,
         HFTokenStreamingHandler,
+        serialize_hf_model_kwargs,
+        deserialize_hf_model_kwargs,
     )
-    from haystack.utils.hf import serialize_hf_model_kwargs, deserialize_hf_model_kwargs
 
 
+PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+
+
 @component
@@ -6,9 +6,9 @@ from urllib.parse import urlparse
 from haystack import component, default_to_dict, default_from_dict
 from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
 from haystack.dataclasses import ChatMessage, StreamingChunk
-from haystack.components.generators.hf_utils import check_valid_model, check_generation_params
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient
@@ -120,7 +120,7 @@ class HuggingFaceTGIChatGenerator:
         if not is_valid_url:
             raise ValueError(f"Invalid TGI endpoint URL provided: {url}")
 
-        check_valid_model(model, token)
+        check_valid_model(model, HFModelType.GENERATION, token)
 
         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -1,131 +0,0 @@
-import inspect
-from typing import Any, Dict, List, Optional, Union, Callable
-
-from haystack.dataclasses import StreamingChunk
-from haystack.lazy_imports import LazyImport
-from haystack.utils import Secret
-
-with LazyImport(message="Run 'pip install transformers'") as transformers_import:
-    from huggingface_hub import InferenceClient, HfApi
-    from huggingface_hub.utils import RepositoryNotFoundError
-
-PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
-
-
-def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
-    """
-    Check the provided generation parameters for validity.
-
-    :param kwargs: A dictionary containing the generation parameters.
-    :param additional_accepted_params: An optional list of strings representing additional accepted parameters.
-    :raises ValueError: If any unknown text generation parameters are provided.
-    """
-    transformers_import.check()
-
-    if kwargs:
-        accepted_params = {
-            param
-            for param in inspect.signature(InferenceClient.text_generation).parameters.keys()
-            if param not in ["self", "prompt"]
-        }
-        if additional_accepted_params:
-            accepted_params.update(additional_accepted_params)
-        unknown_params = set(kwargs.keys()) - accepted_params
-        if unknown_params:
-            raise ValueError(
-                f"Unknown text generation parameters: {unknown_params}. The valid parameters are: {accepted_params}."
-            )
-
-
-def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
-    """
-    Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
-    Also check if the model is a text generation model.
-
-    :param model_id: A string representing the HuggingFace model ID.
-    :param token: An optional authentication token.
-    :raises ValueError: If the model is not found or is not a text generation model.
-    """
-    transformers_import.check()
-
-    api = HfApi()
-    try:
-        model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
-    except RepositoryNotFoundError as e:
-        raise ValueError(
-            f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
-        ) from e
-
-    allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
-    if not allowed_model:
-        raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
-
-
-with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
-    import torch
-    from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer
-
-    transformers_import.check()
-
-    class StopWordsCriteria(StoppingCriteria):
-        """
-        Stops text generation if any one of the stop words is generated.
-
-        Note: When a stop word is encountered, the generation of new text is stopped.
-        However, if the stop word is in the prompt itself, it can stop generating new text
-        prematurely after the first token. This is particularly important for LLMs designed
-        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-        the output includes both the new text and the original prompt. Therefore, it's important
-        to make sure your prompt has no stop words.
-        """
-
-        def __init__(
-            self,
-            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-            stop_words: List[str],
-            device: Union[str, torch.device] = "cpu",
-        ):
-            super().__init__()
-            # check if tokenizer is a valid tokenizer
-            if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
-                raise ValueError(
-                    f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
-                    f"Please provide a valid tokenizer from the HuggingFace Transformers library."
-                )
-            if not tokenizer.pad_token:
-                if tokenizer.eos_token:
-                    tokenizer.pad_token = tokenizer.eos_token
-                else:
-                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-            self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-            for stop_id in self.stop_ids:
-                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-                if found_stop_word:
-                    return True
-            return False
-
-        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
-            generated_text_ids = generated_text_ids[-1]
-            len_generated_text_ids = generated_text_ids.size(0)
-            len_stop_id = stop_id.size(0)
-            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-            return result
-
-    class HFTokenStreamingHandler(TextStreamer):
-        def __init__(
-            self,
-            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-            stream_handler: Callable[[StreamingChunk], None],
-            stop_words: Optional[List[str]] = None,
-        ):
-            super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
-            self.token_handler = stream_handler
-            self.stop_words = stop_words or []
-
-        def on_finalized_text(self, word: str, stream_end: bool = False):
-            word_to_send = word + "\n" if stream_end else word
-            if word_to_send.strip() not in self.stop_words:
-                self.token_handler(StreamingChunk(content=word_to_send))
@@ -15,7 +15,7 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
 with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
     from huggingface_hub import model_info
     from transformers import StoppingCriteriaList, pipeline
-    from haystack.components.generators.hf_utils import StopWordsCriteria  # pylint: disable=ungrouped-imports
+    from haystack.utils.hf import StopWordsCriteria  # pylint: disable=ungrouped-imports
 
 
 @component
@@ -6,9 +6,9 @@ from urllib.parse import urlparse
 from haystack import component, default_to_dict, default_from_dict
 from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
 from haystack.dataclasses import StreamingChunk
-from haystack.components.generators.hf_utils import check_generation_params, check_valid_model
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient
@@ -104,7 +104,7 @@ class HuggingFaceTGIGenerator:
         if not is_valid_url:
             raise ValueError(f"Invalid TGI endpoint URL provided: {url}")
 
-        check_valid_model(model, token)
+        check_valid_model(model, HFModelType.GENERATION, token)
 
         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -1,17 +1,29 @@
 import copy
 import inspect
 import logging
-from typing import Any, Dict, Optional
+from enum import Enum
+from typing import Any, Dict, Optional, List, Union, Callable
 
+from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.device import ComponentDevice
+from haystack.utils.auth import Secret
 
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
     import torch
 
+with LazyImport(message="Run 'pip install transformers'") as transformers_import:
+    from huggingface_hub.utils import RepositoryNotFoundError
+    from huggingface_hub import InferenceClient, HfApi
+
 logger = logging.getLogger(__name__)
 
 
+class HFModelType(Enum):
+    EMBEDDING = 1
+    GENERATION = 2
+
+
 def serialize_hf_model_kwargs(kwargs: Dict[str, Any]):
     """
     Recursively serialize HuggingFace specific model keyword arguments
@@ -78,3 +90,129 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
         model_kwargs["device_map"] = device_map
 
     return model_kwargs
+
+
+def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Secret]) -> None:
+    """
+    Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
+    Also check if the model is an embedding or generation model.
+
+    :param model_id: A string representing the HuggingFace model ID.
+    :param model_type: the model type, HFModelType.EMBEDDING or HFModelType.GENERATION
+    :param token: The optional authentication token.
+    :raises ValueError: If the model is not found or is not a embedding model.
+    """
+    transformers_import.check()
+
+    api = HfApi()
+    try:
+        model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
+    except RepositoryNotFoundError as e:
+        raise ValueError(
+            f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
+        ) from e
+
+    if model_type == HFModelType.EMBEDDING:
+        allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
+        error_msg = f"Model {model_id} is not a embedding model. Please provide a embedding model."
+    elif model_type == HFModelType.GENERATION:
+        allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
+        error_msg = f"Model {model_id} is not a text generation model. Please provide a text generation model."
+
+    if not allowed_model:
+        raise ValueError(error_msg)
+
+
+def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
+    """
+    Check the provided generation parameters for validity.
+
+    :param kwargs: A dictionary containing the generation parameters.
+    :param additional_accepted_params: An optional list of strings representing additional accepted parameters.
+    :raises ValueError: If any unknown text generation parameters are provided.
+    """
+    transformers_import.check()
+
+    if kwargs:
+        accepted_params = {
+            param
+            for param in inspect.signature(InferenceClient.text_generation).parameters.keys()
+            if param not in ["self", "prompt"]
+        }
+        if additional_accepted_params:
+            accepted_params.update(additional_accepted_params)
+        unknown_params = set(kwargs.keys()) - accepted_params
+        if unknown_params:
+            raise ValueError(
+                f"Unknown text generation parameters: {unknown_params}. The valid parameters are: {accepted_params}."
+            )
+
+
+with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
+    from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer
+
+    transformers_import.check()
+    torch_import.check()
+
+    class StopWordsCriteria(StoppingCriteria):
+        """
+        Stops text generation if any one of the stop words is generated.
+
+        Note: When a stop word is encountered, the generation of new text is stopped.
+        However, if the stop word is in the prompt itself, it can stop generating new text
+        prematurely after the first token. This is particularly important for LLMs designed
+        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+        the output includes both the new text and the original prompt. Therefore, it's important
+        to make sure your prompt has no stop words.
+        """
+
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stop_words: List[str],
+            device: Union[str, torch.device] = "cpu",
+        ):
+            super().__init__()
+            # check if tokenizer is a valid tokenizer
+            if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+                raise ValueError(
+                    f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
+                    f"Please provide a valid tokenizer from the HuggingFace Transformers library."
+                )
+            if not tokenizer.pad_token:
+                if tokenizer.eos_token:
+                    tokenizer.pad_token = tokenizer.eos_token
+                else:
+                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+            self.stop_ids = encoded_stop_words.input_ids.to(device)
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in self.stop_ids:
+                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+                if found_stop_word:
+                    return True
+            return False
+
+        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+            generated_text_ids = generated_text_ids[-1]
+            len_generated_text_ids = generated_text_ids.size(0)
+            len_stop_id = stop_id.size(0)
+            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+            return result
+
+    class HFTokenStreamingHandler(TextStreamer):
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stream_handler: Callable[[StreamingChunk], None],
+            stop_words: Optional[List[str]] = None,
+        ):
+            super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
+            self.token_handler = stream_handler
+            self.stop_words = stop_words or []
+
+        def on_finalized_text(self, word: str, stream_end: bool = False):
+            word_to_send = word + "\n" if stream_end else word
+            if word_to_send.strip() not in self.stop_words:
+                self.token_handler(StreamingChunk(content=word_to_send))
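For orientation, here is a minimal usage sketch (not part of this diff) of how the merged StopWordsCriteria and HFTokenStreamingHandler can be plugged into a plain Transformers generate() call. The model name and stop word are placeholder assumptions.

    # Sketch only: "sshleifer/tiny-gpt2" and the stop word are illustrative placeholders.
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    from haystack.utils.hf import HFTokenStreamingHandler, StopWordsCriteria

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    # Stop once the stop word is generated, and stream every finalized token to stdout.
    criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer=tokenizer, stop_words=["\n\n"], device="cpu")])
    streamer = HFTokenStreamingHandler(tokenizer, lambda chunk: print(chunk.content, end=""))

    inputs = tokenizer("Haystack is", return_tensors="pt")
    model.generate(**inputs, max_new_tokens=20, stopping_criteria=criteria, streamer=streamer)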
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    Code from different "hf_utils.py" modules spread across different packages was
+    merged into `haystack.utils.hf`.
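With the modules merged, components and user code validate models and generation parameters through a single import path. A small sketch, assuming the placeholder model IDs below (they are not taken from the diff):

    from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model

    # Raises ValueError if the repo does not exist or its pipeline tag does not match the expected type.
    check_valid_model("sentence-transformers/all-MiniLM-L6-v2", HFModelType.EMBEDDING, token=None)
    check_valid_model("mistralai/Mistral-7B-v0.1", HFModelType.GENERATION, token=None)

    # Raises ValueError for kwargs that InferenceClient.text_generation does not accept.
    check_generation_params({"max_new_tokens": 100, "temperature": 0.7})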
@@ -1,6 +1,6 @@
 import pytest
 
-from haystack.components.generators.hf_utils import check_generation_params
+from haystack.utils.hf import check_generation_params
 
 
 def test_empty_dictionary():