chore: merge hf utils modules into one (#6921)

* merge hf utils modules

* relnotes

* lint

* Update releasenotes/notes/merge-hf-utils-modules-5c16e04025123568.yaml

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
Massimiliano Pippi 2024-02-06 09:59:25 +01:00 committed by GitHub
parent b9d7a98359
commit 7d29ddba42
11 changed files with 161 additions and 179 deletions


@@ -1,32 +0,0 @@
from typing import Optional
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
"""
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
Also check if the model is an embedding model.
:param model_id: A string representing the HuggingFace model ID.
:param token: The optional authentication token.
:raises ValueError: If the model is not found or is not an embedding model.
"""
transformers_import.check()
api = HfApi()
try:
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
except RepositoryNotFoundError as e:
raise ValueError(
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
) from e
allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
if not allowed_model:
raise ValueError(f"Model {model_id} is not a embedding model. Please provide a embedding model.")


@@ -4,10 +4,10 @@ from urllib.parse import urlparse
from tqdm import tqdm
from haystack.components.embedders.hf_utils import check_valid_model
from haystack.dataclasses import Document
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType
from haystack import component, default_to_dict, default_from_dict
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
@@ -116,7 +116,7 @@ class HuggingFaceTEIDocumentEmbedder:
if not is_valid_url:
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
check_valid_model(model, token)
check_valid_model(model, HFModelType.EMBEDDING, token)
self.model = model
self.url = url


@@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from haystack import component, default_to_dict, default_from_dict
from haystack.components.embedders.hf_utils import check_valid_model
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@@ -98,7 +98,7 @@ class HuggingFaceTEITextEmbedder:
if not is_valid_url:
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
check_valid_model(model, token)
check_valid_model(model, HFModelType.EMBEDDING, token)
self.model = model
self.url = url


@@ -2,8 +2,6 @@ import logging
import sys
from typing import Any, Dict, List, Literal, Optional, Union, Callable
from haystack.components.generators.hf_utils import PIPELINE_SUPPORTED_TASKS
from haystack import component, default_to_dict, default_from_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
@@ -16,11 +14,15 @@ logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
from huggingface_hub import model_info
from transformers import StoppingCriteriaList, pipeline, PreTrainedTokenizer, PreTrainedTokenizerFast
from haystack.components.generators.hf_utils import ( # pylint: disable=ungrouped-imports
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
StopWordsCriteria,
HFTokenStreamingHandler,
serialize_hf_model_kwargs,
deserialize_hf_model_kwargs,
)
from haystack.utils.hf import serialize_hf_model_kwargs, deserialize_hf_model_kwargs
PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
@component


@@ -6,9 +6,9 @@ from urllib.parse import urlparse
from haystack import component, default_to_dict, default_from_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@@ -120,7 +120,7 @@ class HuggingFaceTGIChatGenerator:
if not is_valid_url:
raise ValueError(f"Invalid TGI endpoint URL provided: {url}")
check_valid_model(model, token)
check_valid_model(model, HFModelType.GENERATION, token)
# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}


@@ -1,131 +0,0 @@
import inspect
from typing import Any, Dict, List, Optional, Union, Callable
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient, HfApi
from huggingface_hub.utils import RepositoryNotFoundError
PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
"""
Check the provided generation parameters for validity.
:param kwargs: A dictionary containing the generation parameters.
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
:raises ValueError: If any unknown text generation parameters are provided.
"""
transformers_import.check()
if kwargs:
accepted_params = {
param
for param in inspect.signature(InferenceClient.text_generation).parameters.keys()
if param not in ["self", "prompt"]
}
if additional_accepted_params:
accepted_params.update(additional_accepted_params)
unknown_params = set(kwargs.keys()) - accepted_params
if unknown_params:
raise ValueError(
f"Unknown text generation parameters: {unknown_params}. The valid parameters are: {accepted_params}."
)
def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
"""
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
Also check if the model is a text generation model.
:param model_id: A string representing the HuggingFace model ID.
:param token: An optional authentication token.
:raises ValueError: If the model is not found or is not a text generation model.
"""
transformers_import.check()
api = HfApi()
try:
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
except RepositoryNotFoundError as e:
raise ValueError(
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
) from e
allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
if not allowed_model:
raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
import torch
from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer
transformers_import.check()
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
Note: When a stop word is encountered, the generation of new text is stopped.
However, if the stop word is in the prompt itself, it can stop generating new text
prematurely after the first token. This is particularly important for LLMs designed
for dialogue generation. For such models, for example mosaicml/mpt-7b-chat,
the output includes both the new text and the original prompt. Therefore, it's important
to make sure your prompt has no stop words.
"""
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stop_words: List[str],
device: Union[str, torch.device] = "cpu",
):
super().__init__()
# check if tokenizer is a valid tokenizer
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
raise ValueError(
f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
f"Please provide a valid tokenizer from the HuggingFace Transformers library."
)
if not tokenizer.pad_token:
if tokenizer.eos_token:
tokenizer.pad_token = tokenizer.eos_token
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
self.stop_ids = encoded_stop_words.input_ids.to(device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_id in self.stop_ids:
found_stop_word = self.is_stop_word_found(input_ids, stop_id)
if found_stop_word:
return True
return False
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
generated_text_ids = generated_text_ids[-1]
len_generated_text_ids = generated_text_ids.size(0)
len_stop_id = stop_id.size(0)
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
return result
class HFTokenStreamingHandler(TextStreamer):
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: Callable[[StreamingChunk], None],
stop_words: Optional[List[str]] = None,
):
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
self.token_handler = stream_handler
self.stop_words = stop_words or []
def on_finalized_text(self, word: str, stream_end: bool = False):
word_to_send = word + "\n" if stream_end else word
if word_to_send.strip() not in self.stop_words:
self.token_handler(StreamingChunk(content=word_to_send))


@@ -15,7 +15,7 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from huggingface_hub import model_info
from transformers import StoppingCriteriaList, pipeline
from haystack.components.generators.hf_utils import StopWordsCriteria # pylint: disable=ungrouped-imports
from haystack.utils.hf import StopWordsCriteria # pylint: disable=ungrouped-imports
@component


@@ -6,9 +6,9 @@ from urllib.parse import urlparse
from haystack import component, default_to_dict, default_from_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk
from haystack.components.generators.hf_utils import check_generation_params, check_valid_model
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import check_valid_model, HFModelType, check_generation_params
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@@ -104,7 +104,7 @@ class HuggingFaceTGIGenerator:
if not is_valid_url:
raise ValueError(f"Invalid TGI endpoint URL provided: {url}")
check_valid_model(model, token)
check_valid_model(model, HFModelType.GENERATION, token)
# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}


@@ -1,17 +1,29 @@
import copy
import inspect
import logging
from typing import Any, Dict, Optional
from enum import Enum
from typing import Any, Dict, Optional, List, Union, Callable
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.device import ComponentDevice
from haystack.utils.auth import Secret
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
import torch
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub.utils import RepositoryNotFoundError
from huggingface_hub import InferenceClient, HfApi
logger = logging.getLogger(__name__)
class HFModelType(Enum):
EMBEDDING = 1
GENERATION = 2
def serialize_hf_model_kwargs(kwargs: Dict[str, Any]):
"""
Recursively serialize HuggingFace specific model keyword arguments
@@ -78,3 +90,129 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
model_kwargs["device_map"] = device_map
return model_kwargs
def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Secret]) -> None:
"""
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
Also check if the model is an embedding or generation model.
:param model_id: A string representing the HuggingFace model ID.
:param model_type: The model type, either HFModelType.EMBEDDING or HFModelType.GENERATION.
:param token: The optional authentication token.
:raises ValueError: If the model is not found or does not match the given model type.
"""
transformers_import.check()
api = HfApi()
try:
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
except RepositoryNotFoundError as e:
raise ValueError(
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
) from e
if model_type == HFModelType.EMBEDDING:
allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
error_msg = f"Model {model_id} is not a embedding model. Please provide a embedding model."
elif model_type == HFModelType.GENERATION:
allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
error_msg = f"Model {model_id} is not a text generation model. Please provide a text generation model."
if not allowed_model:
raise ValueError(error_msg)
def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None):
"""
Check the provided generation parameters for validity.
:param kwargs: A dictionary containing the generation parameters.
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
:raises ValueError: If any unknown text generation parameters are provided.
"""
transformers_import.check()
if kwargs:
accepted_params = {
param
for param in inspect.signature(InferenceClient.text_generation).parameters.keys()
if param not in ["self", "prompt"]
}
if additional_accepted_params:
accepted_params.update(additional_accepted_params)
unknown_params = set(kwargs.keys()) - accepted_params
if unknown_params:
raise ValueError(
f"Unknown text generation parameters: {unknown_params}. The valid parameters are: {accepted_params}."
)
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
from transformers import StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer
transformers_import.check()
torch_import.check()
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
Note: When a stop word is encountered, the generation of new text is stopped.
However, if the stop word is in the prompt itself, it can stop generating new text
prematurely after the first token. This is particularly important for LLMs designed
for dialogue generation. For such models, for example mosaicml/mpt-7b-chat,
the output includes both the new text and the original prompt. Therefore, it's important
to make sure your prompt has no stop words.
"""
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stop_words: List[str],
device: Union[str, torch.device] = "cpu",
):
super().__init__()
# check if tokenizer is a valid tokenizer
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
raise ValueError(
f"Invalid tokenizer provided for StopWordsCriteria - {tokenizer}. "
f"Please provide a valid tokenizer from the HuggingFace Transformers library."
)
if not tokenizer.pad_token:
if tokenizer.eos_token:
tokenizer.pad_token = tokenizer.eos_token
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
self.stop_ids = encoded_stop_words.input_ids.to(device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_id in self.stop_ids:
found_stop_word = self.is_stop_word_found(input_ids, stop_id)
if found_stop_word:
return True
return False
def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
generated_text_ids = generated_text_ids[-1]
len_generated_text_ids = generated_text_ids.size(0)
len_stop_id = stop_id.size(0)
result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
return result
class HFTokenStreamingHandler(TextStreamer):
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: Callable[[StreamingChunk], None],
stop_words: Optional[List[str]] = None,
):
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
self.token_handler = stream_handler
self.stop_words = stop_words or []
def on_finalized_text(self, word: str, stream_end: bool = False):
word_to_send = word + "\n" if stream_end else word
if word_to_send.strip() not in self.stop_words:
self.token_handler(StreamingChunk(content=word_to_send))
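
The merged haystack.utils.hf module shown above groups the Hub validation helpers with the local-generation utilities. Below is a minimal usage sketch, not part of this commit, assuming transformers[torch] is installed and the HuggingFace Hub is reachable; the model IDs and the stop word are illustrative only.

from transformers import AutoTokenizer, StoppingCriteriaList
from haystack.utils.hf import (
    HFModelType,
    StopWordsCriteria,
    check_generation_params,
    check_valid_model,
)

# Raises ValueError if the repo is missing on the Hub or is not tagged as a
# text generation model ("text-generation" / "text2text-generation").
check_valid_model("gpt2", HFModelType.GENERATION, token=None)

# Raises ValueError for kwargs that InferenceClient.text_generation does not
# accept (e.g. "max_tokens"); valid kwargs pass through silently.
check_generation_params({"max_new_tokens": 120, "temperature": 0.7})

# Stop local generation as soon as a stop word is produced; keep stop words
# out of the prompt itself, as the class docstring above explains.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
stopping_criteria = StoppingCriteriaList(
    [StopWordsCriteria(tokenizer, stop_words=["Observation:"])]
)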


@@ -0,0 +1,5 @@
---
enhancements:
- |
Code from different "hf_utils.py" modules spread across different packages was
merged into `haystack.utils.hf`.
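
For code that imported these helpers from the old per-package modules, the release note above amounts to an import path change plus the new HFModelType argument on check_valid_model. A hypothetical before/after sketch (the model ID is only an example):

# Before this commit (module deleted above):
#   from haystack.components.embedders.hf_utils import check_valid_model
#   check_valid_model("sentence-transformers/all-MiniLM-L6-v2", token=None)

# After this commit:
from haystack.utils.hf import HFModelType, check_valid_model

check_valid_model("sentence-transformers/all-MiniLM-L6-v2", HFModelType.EMBEDDING, token=None)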


@@ -1,6 +1,6 @@
import pytest
from haystack.components.generators.hf_utils import check_generation_params
from haystack.utils.hf import check_generation_params
def test_empty_dictionary():