2022-10-17 18:58:35 +02:00
|
|
|
import pytest
|
2023-04-28 17:08:41 +02:00
|
|
|
from unittest.mock import MagicMock
|
2022-10-17 18:58:35 +02:00
|
|
|
|
|
|
|
import haystack
|
2023-04-28 17:08:41 +02:00
|
|
|
from haystack.modeling.model.feature_extraction import FeatureExtractor
|
2022-10-17 18:58:35 +02:00
|
|
|
|
|
|
|
|
|
|
|
class MockedAutoTokenizer:
    """Stand-in for ``transformers.AutoTokenizer`` that records invocations.

    Every call to :meth:`from_pretrained` is forwarded to the shared
    class-level ``mocker`` so tests can assert on the exact arguments
    the code under test passed in.
    """

    # Shared by all instances: tests inspect mocker.from_pretrained.call_args.
    mocker: MagicMock = MagicMock()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Record the call on the class-level mock, then return a fresh instance."""
        cls.mocker.from_pretrained(*args, **kwargs)
        instance = cls()
        return instance
class MockedAutoConfig:
    """Stand-in for ``transformers.AutoConfig`` that records invocations.

    Reports ``model_type == "mocked"`` so the code under test resolves
    to the mocked feature-extractor class registered under that key.
    """

    # Shared by all instances: tests inspect mocker.from_pretrained.call_args.
    mocker: MagicMock = MagicMock()
    # Key used to look up the extractor class in FEATURE_EXTRACTORS.
    model_type: str = "mocked"

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Record the call on the class-level mock, then return a fresh instance."""
        cls.mocker.from_pretrained(*args, **kwargs)
        instance = cls()
        return instance
@pytest.fixture()
def mock_autotokenizer(monkeypatch):
    """Patch the feature-extraction module so no real model is ever loaded.

    Swaps ``FEATURE_EXTRACTORS``, ``AutoConfig`` and ``AutoTokenizer`` in
    ``haystack.modeling.model.feature_extraction`` for the mocked classes
    above; ``FeatureExtractor`` then resolves every model name to
    ``MockedAutoTokenizer``.
    """
    module = haystack.modeling.model.feature_extraction
    monkeypatch.setattr(module, "FEATURE_EXTRACTORS", {"mocked": MockedAutoTokenizer})
    monkeypatch.setattr(module, "AutoConfig", MockedAutoConfig)
    monkeypatch.setattr(module, "AutoTokenizer", MockedAutoTokenizer)
|
@pytest.mark.unit
def test_get_tokenizer_str(mock_autotokenizer):
    """A plain model-name string is forwarded to from_pretrained unchanged."""
    extractor = FeatureExtractor(pretrained_model_name_or_path="test-model-name")
    extractor.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name",
        revision=None,
        use_fast=True,
        use_auth_token=None,
    )
@pytest.mark.unit
def test_get_tokenizer_path(mock_autotokenizer, tmp_path):
    """A Path argument is converted to its string form before forwarding."""
    model_path = tmp_path / "test-path"
    extractor = FeatureExtractor(pretrained_model_name_or_path=model_path)
    extractor.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path=str(model_path),
        revision=None,
        use_fast=True,
        use_auth_token=None,
    )
@pytest.mark.unit
def test_get_tokenizer_keep_accents(mock_autotokenizer):
    """This model name additionally gets keep_accents=True.

    NOTE(review): keep_accents is presumably triggered by the "albert"
    substring in the model name — verify against FeatureExtractor.
    """
    extractor = FeatureExtractor(pretrained_model_name_or_path="test-model-name-albert")
    extractor.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name-albert",
        revision=None,
        use_fast=True,
        use_auth_token=None,
        keep_accents=True,
    )
|
@pytest.mark.unit
def test_get_tokenizer_mlm_warning(mock_autotokenizer, caplog):
    """An '-mlm' model still loads normally but logs an MLM-unsupported warning."""
    extractor = FeatureExtractor(pretrained_model_name_or_path="test-model-name-mlm")
    extractor.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name-mlm",
        revision=None,
        use_fast=True,
        use_auth_token=None,
    )
    expected_warning = "MLM part of codebert is currently not supported in Haystack"
    assert expected_warning.lower() in caplog.text.lower()
2022-10-17 18:58:35 +02:00
|
|
|
|
|
|
|
# Pretrained model names exercised by the integration round-trip test below.
FEATURE_EXTRACTORS_TO_TEST = ["bert-base-cased"]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("model_name", FEATURE_EXTRACTORS_TO_TEST)
def test_load_modify_save_load(tmp_path, model_name: str):
    """Tokens added to a tokenizer must survive a save/reload round trip."""
    # Load the base tokenizer.
    feature_extractor = FeatureExtractor(pretrained_model_name_or_path=model_name, do_lower_case=False)

    # Extend the vocabulary with a token the base model has never seen.
    feature_extractor.feature_extractor.add_tokens(new_tokens=["neverseentokens"])

    # Persist the modified tokenizer to disk...
    save_dir = tmp_path / "saved_tokenizer"
    feature_extractor.feature_extractor.save_pretrained(save_dir)

    # ...and load it back from that directory.
    reloaded = FeatureExtractor(pretrained_model_name_or_path=save_dir)

    # The reloaded tokenizer keeps the enlarged vocabulary size.
    assert len(reloaded.feature_extractor) == len(feature_extractor.feature_extractor)