feat: enable loading tokenizer for models that are not supported by the transformers library (#5314)

* add tokenizer load

* change import order

* move imports

* refactor code

* import lib

* remove pretrainedmodel

* fix linting

* update patch

* fix order

* remove tokenizer class

* use tokenizer class

* no copy

* add case for model is an instance

* fix optional

* add ut

* set default to None

* change models

* Update haystack/nodes/prompt/invocation_layer/hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* Update haystack/nodes/prompt/invocation_layer/hugging_face.py

Co-authored-by: bogdankostic <bogdankostic@web.de>

* add unit tests

* add unit tests

* remove lib

* formatting

* formatting

* formatting

* add release note

* Update releasenotes/notes/load-tokenizer-if-not-load-by-transformers-5841cdc9ff69bcc2.yaml

Co-authored-by: bogdankostic <bogdankostic@web.de>

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
Fanli Lin 2023-08-02 17:42:23 +08:00 committed by GitHub
parent 97e4522a83
commit f7fd5eeb4f
3 changed files with 132 additions and 10 deletions

haystack/nodes/prompt/invocation_layer/hugging_face.py

@@ -7,7 +7,6 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.utils import get_task
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
@@ -17,10 +16,14 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import
pipeline,
StoppingCriteriaList,
StoppingCriteria,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
GenerationConfig,
PreTrainedModel,
Pipeline,
AutoTokenizer,
AutoConfig,
TOKENIZER_MAPPING,
)
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
@@ -167,21 +170,35 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
torch_dtype = self._extract_torch_dtype(**kwargs)
# and the model (prefer model instance over model_name_or_path str identifier)
model = kwargs.get("model") or kwargs.get("model_name_or_path")
trust_remote_code = kwargs.get("trust_remote_code", False)
hub_kwargs = {
"revision": kwargs.get("revision", None),
"use_auth_token": kwargs.get("use_auth_token", None),
"trust_remote_code": trust_remote_code,
}
model_kwargs = kwargs.get("model_kwargs", {})
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is None and trust_remote_code:
# For models not yet supported by the transformers library, we must set `trust_remote_code=True` within
# the underlying pipeline to ensure the model loads successfully. However, this does not guarantee that the
# tokenizer is loaded alongside it, so we add logic here to manually load the tokenizer and pass it to
# transformers' pipeline.
# Otherwise, calling `self.pipe.tokenizer.model_max_length` would raise an error.
tokenizer = self._prepare_tokenizer(model, hub_kwargs, model_kwargs)
pipeline_kwargs = {
"task": kwargs.get("task", None),
"model": model,
"config": kwargs.get("config", None),
"tokenizer": kwargs.get("tokenizer", None),
"tokenizer": tokenizer,
"feature_extractor": kwargs.get("feature_extractor", None),
"revision": kwargs.get("revision", None),
"use_auth_token": kwargs.get("use_auth_token", None),
"device_map": device_map,
"device": device,
"torch_dtype": torch_dtype,
"trust_remote_code": kwargs.get("trust_remote_code", False),
"model_kwargs": kwargs.get("model_kwargs", {}),
"model_kwargs": model_kwargs,
"pipeline_class": kwargs.get("pipeline_class", None),
**hub_kwargs,
}
return pipeline_kwargs
@@ -313,6 +330,44 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
return torch_dtype_resolved
def _prepare_tokenizer(
self, model: Union[str, "PreTrainedModel"], hub_kwargs: Dict, model_kwargs: Optional[Dict] = None
) -> Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", None]:
"""
this method prepares the tokenizer before passing it to transformers' pipeline, so that the instantiated pipeline
object has a working tokenizer.
It basically check whether the pipeline method in the transformers library will load the tokenizer.
- If yes, None will be returned, because in this case, the pipeline is intelligent enough to load the tokenizer by itself
- If not, we will load the tokenizer and an tokenizer instance is returned
:param model: the name or path of the underlying model
:hub_kwargs: keyword argument related to hugging face hub, including revision, trust_remote_code and use_auth_token
:model_kwargs: keyword arguments passed to the underlying model
"""
if isinstance(model, str):
model_config = AutoConfig.from_pretrained(model, **hub_kwargs, **model_kwargs)
else:
model_config = model.config
model = model_config._name_or_path
# The will_load_tokenizer logic corresponds to this line in the transformers library:
# https://github.com/huggingface/transformers/blob/05cda5df3405e6a2ee4ecf8f7e1b2300ebda472e/src/transformers/pipelines/__init__.py#L805
will_load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
if not will_load_tokenizer:
logger.warning(
"The transformers library doesn't know which tokenizer class should be "
"loaded for the model %s. Therefore, the tokenizer will be loaded in Haystack's "
"invocation layer and then passed to the underlying pipeline. Alternatively, you could "
"pass `tokenizer_class` to `model_kwargs` to workaround this, if your tokenizer is supported "
"by the transformers library.",
model,
)
tokenizer = AutoTokenizer.from_pretrained(model, **hub_kwargs, **model_kwargs)
else:
tokenizer = None
return tokenizer
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
task_name: Optional[str] = kwargs.get("task_name", None)
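
For readers who want to see the decision in isolation: the sketch below mirrors the will_load_tokenizer check added above, outside of Haystack. It is a minimal sketch assuming only that transformers is installed; load_tokenizer_if_needed and "some-org/custom-model" are illustrative names, not part of this commit.

from typing import Optional, Union

from transformers import (
    AutoConfig,
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TOKENIZER_MAPPING,
)


def load_tokenizer_if_needed(
    model_name_or_path: str, trust_remote_code: bool = True
) -> Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
    """Return a tokenizer only if transformers' pipeline() would not load one by itself."""
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
    # pipeline() loads a tokenizer when the config class is registered in TOKENIZER_MAPPING
    # or when the config explicitly names a tokenizer class.
    if type(config) in TOKENIZER_MAPPING or config.tokenizer_class is not None:
        return None  # the pipeline handles it
    return AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)

Example: load_tokenizer_if_needed("some-org/custom-model") returns a tokenizer instance only for checkpoints whose config class is unknown to transformers and whose config does not name a tokenizer class.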

releasenotes/notes/load-tokenizer-if-not-load-by-transformers-5841cdc9ff69bcc2.yaml

@@ -0,0 +1,4 @@
---
enhancements:
- |
Allow loading tokenizers for prompt models that are not natively supported by the transformers library by setting `trust_remote_code` to `True`.
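
As a usage sketch (not part of this commit), the behaviour would typically be exercised as follows; "my-org/remote-code-model" is a placeholder for a checkpoint whose config class is not in transformers' TOKENIZER_MAPPING.

from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

# trust_remote_code=True is needed to load the custom model code, and with this change it also
# triggers the manual tokenizer loading in the invocation layer.
layer = HFLocalInvocationLayer(
    model_name_or_path="my-org/remote-code-model",  # placeholder model id
    trust_remote_code=True,
)
# Before this change, self.pipe.tokenizer could be None for such models and the call below failed.
print(layer.pipe.tokenizer.model_max_length)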


@@ -1,5 +1,6 @@
from typing import List
from unittest.mock import MagicMock, patch, Mock
import logging
import pytest
import torch
@@ -68,14 +69,14 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
"config",
"tokenizer",
"feature_extractor",
"revision",
"use_auth_token",
"device_map",
"device",
"torch_dtype",
"trust_remote_code",
"model_kwargs",
"pipeline_class",
"revision",
"use_auth_token",
"trust_remote_code",
]
@@ -552,3 +553,65 @@ def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
result = layer._ensure_token_limit("I am a tokenized prompt of length eight")
assert result == correct_result
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoConfig.from_pretrained")
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
def test_tokenizer_loading_unsupported_model(mock_tokenizer, mock_config, mock_pipeline, mock_get_task, caplog):
"""
Test loading of tokenizers for models that are not natively supported by the transformers library.
"""
mock_config.return_value = Mock(tokenizer_class=None)
with caplog.at_level(logging.WARNING):
invocation_layer = HFLocalInvocationLayer("unsupported_model", trust_remote_code=True)
assert (
"The transformers library doesn't know which tokenizer class should be "
"loaded for the model unsupported_model. Therefore, the tokenizer will be loaded in Haystack's "
"invocation layer and then passed to the underlying pipeline. Alternatively, you could "
"pass `tokenizer_class` to `model_kwargs` to workaround this, if your tokenizer is supported "
"by the transformers library."
) in caplog.text
assert mock_tokenizer.called
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
def test_tokenizer_loading_unsupported_model_with_initialized_model(
mock_tokenizer, mock_pipeline, mock_get_task, caplog
):
"""
Test loading of tokenizers for models that are not natively supported by the transformers library. In this case,
the model is already initialized and the model config is loaded from the model.
"""
model = Mock()
model.config = Mock(tokenizer_class=None, _name_or_path="unsupported_model")
with caplog.at_level(logging.WARNING):
invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported", model=model, trust_remote_code=True)
assert (
"The transformers library doesn't know which tokenizer class should be "
"loaded for the model unsupported_model. Therefore, the tokenizer will be loaded in Haystack's "
"invocation layer and then passed to the underlying pipeline. Alternatively, you could "
"pass `tokenizer_class` to `model_kwargs` to workaround this, if your tokenizer is supported "
"by the transformers library."
) in caplog.text
assert mock_tokenizer.called
@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoConfig.from_pretrained")
@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
def test_tokenizer_loading_unsupported_model_with_tokenizer_class_in_config(
mock_tokenizer, mock_config, mock_pipeline, mock_get_task, caplog
):
"""
Test that tokenizer is not loaded if tokenizer_class is set in model config.
"""
mock_config.return_value = Mock(tokenizer_class="Some-Supported-Tokenizer")
with caplog.at_level(logging.WARNING):
invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported_model", trust_remote_code=True)
assert not mock_tokenizer.called
assert not caplog.text
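
The tests above rely on mock_pipeline and mock_get_task fixtures defined elsewhere in the test suite. A minimal sketch of what such fixtures could look like is given below; the patch targets match the module under test, but the fixture bodies and return values are assumptions for illustration only.

import pytest
from unittest.mock import MagicMock, patch


@pytest.fixture
def mock_get_task():
    # Pretend every model maps to a text2text-generation task so no Hub call is made.
    with patch("haystack.nodes.prompt.invocation_layer.hugging_face.get_task") as mock:
        mock.return_value = "text2text-generation"
        yield mock


@pytest.fixture
def mock_pipeline():
    # Replace transformers.pipeline (as imported in hugging_face.py) with a mock whose
    # tokenizer reports a model_max_length, which HFLocalInvocationLayer reads at init time.
    with patch("haystack.nodes.prompt.invocation_layer.hugging_face.pipeline") as mock:
        mock.return_value = MagicMock(tokenizer=MagicMock(model_max_length=512))
        yield mock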