diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index a0a48d902..b96af10ad 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -7,7 +7,6 @@ from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamin
 from haystack.nodes.prompt.invocation_layer.utils import get_task
 from haystack.lazy_imports import LazyImport
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -17,10 +16,14 @@ with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_a
         pipeline,
         StoppingCriteriaList,
         StoppingCriteria,
+        GenerationConfig,
         PreTrainedTokenizer,
         PreTrainedTokenizerFast,
-        GenerationConfig,
+        PreTrainedModel,
         Pipeline,
+        AutoTokenizer,
+        AutoConfig,
+        TOKENIZER_MAPPING,
     )
     from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
     from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
@@ -167,21 +170,35 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         torch_dtype = self._extract_torch_dtype(**kwargs)
         # and the model (prefer model instance over model_name_or_path str identifier)
         model = kwargs.get("model") or kwargs.get("model_name_or_path")
+        trust_remote_code = kwargs.get("trust_remote_code", False)
+        hub_kwargs = {
+            "revision": kwargs.get("revision", None),
+            "use_auth_token": kwargs.get("use_auth_token", None),
+            "trust_remote_code": trust_remote_code,
+        }
+        model_kwargs = kwargs.get("model_kwargs", {})
+        tokenizer = kwargs.get("tokenizer", None)
+
+        if tokenizer is None and trust_remote_code:
+            # For models not yet supported by the transformers library, we must set `trust_remote_code=True` for
+            # the underlying pipeline so that the model loads successfully. However, this does not guarantee that
+            # the tokenizer is loaded alongside it. Therefore, we add extra logic here to load the tokenizer
+            # manually and pass it to transformers' pipeline.
+            # Otherwise, calling `self.pipe.tokenizer.model_max_length` would raise an error.
+            tokenizer = self._prepare_tokenizer(model, hub_kwargs, model_kwargs)
 
         pipeline_kwargs = {
             "task": kwargs.get("task", None),
             "model": model,
             "config": kwargs.get("config", None),
-            "tokenizer": kwargs.get("tokenizer", None),
+            "tokenizer": tokenizer,
             "feature_extractor": kwargs.get("feature_extractor", None),
-            "revision": kwargs.get("revision", None),
-            "use_auth_token": kwargs.get("use_auth_token", None),
             "device_map": device_map,
             "device": device,
             "torch_dtype": torch_dtype,
-            "trust_remote_code": kwargs.get("trust_remote_code", False),
-            "model_kwargs": kwargs.get("model_kwargs", {}),
+            "model_kwargs": model_kwargs,
             "pipeline_class": kwargs.get("pipeline_class", None),
+            **hub_kwargs,
         }
         return pipeline_kwargs
 
@@ -313,6 +330,44 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
                 raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
         return torch_dtype_resolved
 
+    def _prepare_tokenizer(
+        self, model: Union[str, "PreTrainedModel"], hub_kwargs: Dict, model_kwargs: Optional[Dict] = None
+    ) -> Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", None]:
+        """
+        This method prepares the tokenizer before passing it to transformers' pipeline, so that the
+        instantiated pipeline object has a working tokenizer.
+
+        It checks whether the `pipeline` function in the transformers library will load the tokenizer:
+        - If yes, None is returned, because in this case the pipeline is able to load the tokenizer by itself.
+        - If not, we load the tokenizer here and return the tokenizer instance.
+
+        :param model: The name or path of the underlying model.
+        :param hub_kwargs: Hugging Face Hub keyword arguments: revision, trust_remote_code, and use_auth_token.
+        :param model_kwargs: Keyword arguments passed to the underlying model.
+        """
+
+        if isinstance(model, str):
+            model_config = AutoConfig.from_pretrained(model, **hub_kwargs, **model_kwargs)
+        else:
+            model_config = model.config
+            model = model_config._name_or_path
+        # the will_load_tokenizer logic corresponds to this line in the transformers library:
+        # https://github.com/huggingface/transformers/blob/05cda5df3405e6a2ee4ecf8f7e1b2300ebda472e/src/transformers/pipelines/__init__.py#L805
+        will_load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
+        if not will_load_tokenizer:
+            logger.warning(
+                "The transformers library doesn't know which tokenizer class should be "
+                "loaded for the model %s. Therefore, the tokenizer will be loaded in Haystack's "
+                "invocation layer and then passed to the underlying pipeline. Alternatively, you could "
+                "pass `tokenizer_class` to `model_kwargs` to work around this, if your tokenizer is supported "
+                "by the transformers library.",
+                model,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model, **hub_kwargs, **model_kwargs)
+        else:
+            tokenizer = None
+        return tokenizer
+
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
         task_name: Optional[str] = kwargs.get("task_name", None)
diff --git a/releasenotes/notes/load-tokenizer-if-not-load-by-transformers-5841cdc9ff69bcc2.yaml b/releasenotes/notes/load-tokenizer-if-not-load-by-transformers-5841cdc9ff69bcc2.yaml
new file mode 100644
index 000000000..3904010f9
--- /dev/null
+++ b/releasenotes/notes/load-tokenizer-if-not-load-by-transformers-5841cdc9ff69bcc2.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Allow loading tokenizers for prompt models not natively supported by transformers by setting `trust_remote_code` to `True`.
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index e5ca043ef..5e45d39d0 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -1,5 +1,6 @@
 from typing import List
 from unittest.mock import MagicMock, patch, Mock
+import logging
 
 import pytest
 import torch
@@ -68,14 +69,14 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task):
         "config",
         "tokenizer",
         "feature_extractor",
-        "revision",
-        "use_auth_token",
         "device_map",
         "device",
         "torch_dtype",
-        "trust_remote_code",
         "model_kwargs",
         "pipeline_class",
+        "revision",
+        "use_auth_token",
+        "trust_remote_code",
     ]
 
 
@@ -552,3 +553,65 @@ def test_ensure_token_limit_negative_mock(mock_pipeline, mock_get_task, mock_aut
     result = layer._ensure_token_limit("I am a tokenized prompt of length eight")
 
     assert result == correct_result
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoConfig.from_pretrained")
+@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
+def test_tokenizer_loading_unsupported_model(mock_tokenizer, mock_config, mock_pipeline, mock_get_task, caplog):
+    """
+    Test loading of tokenizers for models that are not natively supported by the transformers library.
+    """
+    mock_config.return_value = Mock(tokenizer_class=None)
+
+    with caplog.at_level(logging.WARNING):
+        invocation_layer = HFLocalInvocationLayer("unsupported_model", trust_remote_code=True)
+    assert (
+        "The transformers library doesn't know which tokenizer class should be "
+        "loaded for the model unsupported_model. Therefore, the tokenizer will be loaded in Haystack's "
+        "invocation layer and then passed to the underlying pipeline. Alternatively, you could "
+        "pass `tokenizer_class` to `model_kwargs` to work around this, if your tokenizer is supported "
+        "by the transformers library."
+    ) in caplog.text
+    assert mock_tokenizer.called
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
+def test_tokenizer_loading_unsupported_model_with_initialized_model(
+    mock_tokenizer, mock_pipeline, mock_get_task, caplog
+):
+    """
+    Test loading of tokenizers for models that are not natively supported by the transformers library. In this case,
+    the model is already initialized and the model config is loaded from the model.
+    """
+    model = Mock()
+    model.config = Mock(tokenizer_class=None, _name_or_path="unsupported_model")
+
+    with caplog.at_level(logging.WARNING):
+        invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported", model=model, trust_remote_code=True)
+    assert (
+        "The transformers library doesn't know which tokenizer class should be "
+        "loaded for the model unsupported_model. Therefore, the tokenizer will be loaded in Haystack's "
+        "invocation layer and then passed to the underlying pipeline. Alternatively, you could "
+        "pass `tokenizer_class` to `model_kwargs` to work around this, if your tokenizer is supported "
+        "by the transformers library."
+    ) in caplog.text
+    assert mock_tokenizer.called
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoConfig.from_pretrained")
+@patch("haystack.nodes.prompt.invocation_layer.hugging_face.AutoTokenizer.from_pretrained")
+def test_tokenizer_loading_unsupported_model_with_tokenizer_class_in_config(
+    mock_tokenizer, mock_config, mock_pipeline, mock_get_task, caplog
+):
+    """
+    Test that the tokenizer is not loaded if tokenizer_class is set in the model config.
+    """
+    mock_config.return_value = Mock(tokenizer_class="Some-Supported-Tokenizer")
+
+    with caplog.at_level(logging.WARNING):
+        invocation_layer = HFLocalInvocationLayer(model_name_or_path="unsupported_model", trust_remote_code=True)
+    assert not mock_tokenizer.called
+    assert not caplog.text
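
For reviewers who want to exercise the new code path outside the unit tests, here is a minimal, hypothetical usage sketch (not part of the PR). The model name is a placeholder for any Hugging Face Hub model that ships custom code and whose tokenizer class transformers cannot infer on its own; it assumes `farm-haystack[inference]` is installed and that downloading the model is acceptable.

from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

# "my-org/my-custom-model" is a placeholder, not a model used in this PR.
layer = HFLocalInvocationLayer(
    model_name_or_path="my-org/my-custom-model",
    trust_remote_code=True,  # triggers the manual tokenizer loading added above
)

# Because the tokenizer is now attached to the underlying pipeline, prompt-length
# checks such as `self.pipe.tokenizer.model_max_length` work, and invoke()
# returns the generated texts as a list of strings.
print(layer.invoke(prompt="Briefly explain what a tokenizer does."))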