feat: enable loading tokenizer for models that are not supported by the transformers library (#5314)
* add tokenizer load
* change import order
* move imports
* refactor code
* import lib
* remove pretrainedmodel
* fix linting
* update patch
* fix order
* remove tokenizer class
* use tokenizer class
* no copy
* add case for model is an instance
* fix optional
* add ut
* set default to None
* change models
* Update haystack/nodes/prompt/invocation_layer/hugging_face.py (Co-authored-by: bogdankostic <bogdankostic@web.de>)
* Update haystack/nodes/prompt/invocation_layer/hugging_face.py (Co-authored-by: bogdankostic <bogdankostic@web.de>)
* add unit tests
* add unit tests
* remove lib
* formatting
* formatting
* formatting
* add release note
* Update releasenotes/notes/load-tokenizer-if-not-load-by-transformers-5841cdc9ff69bcc2.yaml (Co-authored-by: bogdankostic <bogdankostic@web.de>)

Co-authored-by: bogdankostic <bogdankostic@web.de>
Parent: 97e4522a83
Commit: f7fd5eeb4f
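In short: when `trust_remote_code=True` and the checkpoint's tokenizer class is unknown to transformers, the invocation layer now loads the tokenizer itself and hands it to the underlying pipeline. A minimal usage sketch (the checkpoint name is hypothetical, and `task_name` is shown because the task often cannot be inferred for custom-code models):

    from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

    # Hypothetical repository that ships its own modeling/tokenization code.
    layer = HFLocalInvocationLayer(
        model_name_or_path="my-org/custom-remote-code-model",
        trust_remote_code=True,  # lets transformers execute the repo's custom code
        task_name="text2text-generation",
    )
    # The underlying pipeline now has a usable tokenizer, so code that relies on
    # layer.pipe.tokenizer.model_max_length no longer fails.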
@@ -7,7 +7,6 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamin
from haystack.nodes.prompt.invocation_layer.utils import get_task
from haystack.lazy_imports import LazyImport

logger = logging.getLogger(__name__)

@@ -17,10 +16,14 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
        pipeline,
        StoppingCriteriaList,
        StoppingCriteria,
        GenerationConfig,
        PreTrainedTokenizer,
        PreTrainedTokenizerFast,
        GenerationConfig,
        PreTrainedModel,
        Pipeline,
        AutoTokenizer,
        AutoConfig,
        TOKENIZER_MAPPING,
    )
    from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
    from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
@@ -167,21 +170,35 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
        torch_dtype = self._extract_torch_dtype(**kwargs)
        # and the model (prefer model instance over model_name_or_path str identifier)
        model = kwargs.get("model") or kwargs.get("model_name_or_path")
        trust_remote_code = kwargs.get("trust_remote_code", False)
        hub_kwargs = {
            "revision": kwargs.get("revision", None),
            "use_auth_token": kwargs.get("use_auth_token", None),
            "trust_remote_code": trust_remote_code,
        }
        model_kwargs = kwargs.get("model_kwargs", {})
        tokenizer = kwargs.get("tokenizer", None)

        if tokenizer is None and trust_remote_code:
            # For models not yet supported by the transformers library, we must set `trust_remote_code=True` within
            # the underlying pipeline to ensure the model's successful loading. However, this does not guarantee the
            # tokenizer will be loaded alongside. Therefore, we need to add additional logic here to manually load the
            # tokenizer and pass it to transformers' pipeline.
            # Otherwise, calling `self.pipe.tokenizer.model_max_length` will return an error.
            tokenizer = self._prepare_tokenizer(model, hub_kwargs, model_kwargs)

        pipeline_kwargs = {
            "task": kwargs.get("task", None),
            "model": model,
            "config": kwargs.get("config", None),
            "tokenizer": kwargs.get("tokenizer", None),
            "tokenizer": tokenizer,
            "feature_extractor": kwargs.get("feature_extractor", None),
            "revision": kwargs.get("revision", None),
            "use_auth_token": kwargs.get("use_auth_token", None),
            "device_map": device_map,
            "device": device,
            "torch_dtype": torch_dtype,
            "trust_remote_code": kwargs.get("trust_remote_code", False),
            "model_kwargs": kwargs.get("model_kwargs", {}),
            "model_kwargs": model_kwargs,
            "pipeline_class": kwargs.get("pipeline_class", None),
            **hub_kwargs,
        }
        return pipeline_kwargs

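For context, the behaviour this hunk works around can be reproduced with plain transformers. A sketch under the assumption of a hypothetical custom-code checkpoint:

    from transformers import AutoTokenizer, pipeline

    model_name = "my-org/custom-remote-code-model"  # hypothetical custom-code checkpoint

    # For such models, pipeline() may leave pipe.tokenizer unset, so anything touching
    # pipe.tokenizer.model_max_length would raise. Loading the tokenizer explicitly and
    # passing it in, as the new pipeline_kwargs do, avoids that.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    pipe = pipeline(
        task="text2text-generation",
        model=model_name,
        tokenizer=tokenizer,
        trust_remote_code=True,
    )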
@@ -313,6 +330,44 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
            raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
        return torch_dtype_resolved

    def _prepare_tokenizer(
        self, model: Union[str, "PreTrainedModel"], hub_kwargs: Dict, model_kwargs: Optional[Dict] = None
    ) -> Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", None]:
        """
        Prepare the tokenizer before passing it to transformers' pipeline, so that the instantiated pipeline
        object has a working tokenizer.

        This method checks whether the pipeline function in the transformers library will load the tokenizer itself.
        - If yes, None is returned, because in this case the pipeline is able to load the tokenizer on its own.
        - If not, the tokenizer is loaded here and a tokenizer instance is returned.

        :param model: The name or path of the underlying model.
        :param hub_kwargs: Keyword arguments related to the Hugging Face Hub, including revision, trust_remote_code
            and use_auth_token.
        :param model_kwargs: Keyword arguments passed to the underlying model.
        """

        if isinstance(model, str):
            model_config = AutoConfig.from_pretrained(model, **hub_kwargs, **model_kwargs)
        else:
            model_config = model.config
            model = model_config._name_or_path
        # the will_load_tokenizer logic corresponds to this line in the transformers library
        # https://github.com/huggingface/transformers/blob/05cda5df3405e6a2ee4ecf8f7e1b2300ebda472e/src/transformers/pipelines/__init__.py#L805
        will_load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
        if not will_load_tokenizer:
            logger.warning(
                "The transformers library doesn't know which tokenizer class should be "
                "loaded for the model %s. Therefore, the tokenizer will be loaded in Haystack's "
                "invocation layer and then passed to the underlying pipeline. Alternatively, you could "
                "pass `tokenizer_class` to `model_kwargs` to workaround this, if your tokenizer is supported "
                "by the transformers library.",
                model,
            )
            tokenizer = AutoTokenizer.from_pretrained(model, **hub_kwargs, **model_kwargs)
        else:
            tokenizer = None
        return tokenizer

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        task_name: Optional[str] = kwargs.get("task_name", None)

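The core of `_prepare_tokenizer` is the `will_load_tokenizer` check. A minimal, self-contained restatement of that logic (the helper name and defaults are illustrative, not part of the PR):

    from transformers import TOKENIZER_MAPPING, AutoConfig, AutoTokenizer


    def load_tokenizer_if_needed(model_name_or_path: str, trust_remote_code: bool = False):
        """Return a tokenizer only when transformers' pipeline() would not load one itself."""
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
        # pipeline() can pick a tokenizer on its own if the config type is registered in
        # TOKENIZER_MAPPING or the config names a tokenizer_class explicitly.
        if type(config) in TOKENIZER_MAPPING or config.tokenizer_class is not None:
            return None
        return AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)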
@@ -0,0 +1,4 @@
---
enhancements:
  - |
    Allow loading tokenizers for prompt models not natively supported by transformers by setting `trust_remote_code` to `True`.
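From the user's side, the release note translates to something like the following sketch (the checkpoint name is hypothetical, and passing `trust_remote_code` through `model_kwargs` is the assumed route into the invocation layer):

    from haystack.nodes import PromptModel, PromptNode

    prompt_model = PromptModel(
        model_name_or_path="my-org/custom-remote-code-model",  # hypothetical checkpoint
        model_kwargs={"trust_remote_code": True, "task_name": "text2text-generation"},
    )
    prompt_node = PromptNode(prompt_model)
    print(prompt_node("What is the capital of Germany?"))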
@@ -1,5 +1,6 @@
from typing import List
from unittest.mock import MagicMock, patch, Mock
import logging

import pytest
import torch
@@ -68,14 +69,14 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
        "config",
        "tokenizer",
        "feature_extractor",
        "revision",
        "use_auth_token",
        "device_map",
        "device",
        "torch_dtype",
        "trust_remote_code",
        "model_kwargs",
        "pipeline_class",
        "revision",
        "use_auth_token",
        "trust_remote_code",
    ]

@@ -552,3 +553,65 @@ def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_aut
    result = layer._ensure_token_limit("I am a tokenized prompt of length eight")

    assert result == correct_result


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoConfig.from_pretrained")
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
def test_tokenizer_loading_unsupported_model(mock_tokenizer, mock_config, mock_pipeline, mock_get_task, caplog):
    """
    Test loading of tokenizers for models that are not natively supported by the transformers library.
    """
    mock_config.return_value = Mock(tokenizer_class=None)

    with caplog.at_level(logging.WARNING):
        invocation_layer = HFLocalInvocationLayer("unsupported_model", trust_remote_code=True)
        assert (
            "The transformers library doesn't know which tokenizer class should be "
            "loaded for the model unsupported_model. Therefore, the tokenizer will be loaded in Haystack's "
            "invocation layer and then passed to the underlying pipeline. Alternatively, you could "
            "pass `tokenizer_class` to `model_kwargs` to workaround this, if your tokenizer is supported "
            "by the transformers library."
        ) in caplog.text
        assert mock_tokenizer.called


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
def test_tokenizer_loading_unsupported_model_with_initialized_model(
    mock_tokenizer, mock_pipeline, mock_get_task, caplog
):
    """
    Test loading of tokenizers for models that are not natively supported by the transformers library. In this case,
    the model is already initialized and the model config is loaded from the model.
    """
    model = Mock()
    model.config = Mock(tokenizer_class=None, _name_or_path="unsupported_model")

    with caplog.at_level(logging.WARNING):
        invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported", model=model, trust_remote_code=True)
        assert (
            "The transformers library doesn't know which tokenizer class should be "
            "loaded for the model unsupported_model. Therefore, the tokenizer will be loaded in Haystack's "
            "invocation layer and then passed to the underlying pipeline. Alternatively, you could "
            "pass `tokenizer_class` to `model_kwargs` to workaround this, if your tokenizer is supported "
            "by the transformers library."
        ) in caplog.text
        assert mock_tokenizer.called


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoConfig.from_pretrained")
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
def test_tokenizer_loading_unsupported_model_with_tokenizer_class_in_config(
    mock_tokenizer, mock_config, mock_pipeline, mock_get_task, caplog
):
    """
    Test that the tokenizer is not loaded if tokenizer_class is set in the model config.
    """
    mock_config.return_value = Mock(tokenizer_class="Some-Supported-Tokenizer")

    with caplog.at_level(logging.WARNING):
        invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported_model", trust_remote_code=True)
        assert not mock_tokenizer.called
        assert not caplog.text
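Note that `mock_pipeline` and `mock_get_task` are pytest fixtures defined elsewhere in the test module and are not part of this diff. Purely as an assumption, they could look roughly like this:

    import pytest
    from unittest.mock import MagicMock, patch


    @pytest.fixture
    def mock_pipeline():
        # Replace the transformers pipeline factory as imported into the module under test.
        with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline", MagicMock()) as mock:
            yield mock


    @pytest.fixture
    def mock_get_task():
        # Pretend every model maps to a text2text-generation task.
        with patch(
            "haystack.nodes.prompt.invocation_layer.hugging_face.get_task",
            MagicMock(return_value="text2text-generation"),
        ) as mock:
            yield mock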