Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-07-27 02:40:41 +00:00

* Simplification of language_model.py and tokenization.py to remove code duplication (Co-authored-by: vblagoje <dovlex@gmail.com>)
326 lines · 11 KiB · Python
from typing import Tuple

import re

import pytest
import numpy as np
from unittest.mock import MagicMock

from tokenizers.pre_tokenizers import WhitespaceSplit

import haystack
from haystack.modeling.model.tokenization import get_tokenizer


BERT = "bert-base-cased"
ROBERTA = "roberta-base"
XLNET = "xlnet-base-cased"

TOKENIZERS_TO_TEST = [BERT, ROBERTA, XLNET]
TOKENIZERS_TO_TEST_WITH_TOKEN_MARKER = [(BERT, "##"), (ROBERTA, "Ġ"), (XLNET, "▁")]


REGULAR_SENTENCE = "This is a sentence"
GERMAN_SENTENCE = "Der entscheidende Pass"
OTHER_ALPHABETS = "力加勝北区ᴵᴺᵀᵃছজটডণত"
GIBBERISH_SENTENCE = "Thiso text is included tolod makelio sure Unicodeel is handled properly:"
SENTENCE_WITH_ELLIPSIS = "This is a sentence..."
SENTENCE_WITH_LINEBREAK_1 = "and another one\n\n\nwithout space"
SENTENCE_WITH_LINEBREAK_2 = """This is a sentence.
With linebreak"""
SENTENCE_WITH_LINEBREAKS = """Sentence
with
multiple
newlines
"""
SENTENCE_WITH_EXCESS_WHITESPACE = "This is a sentence  with   multiple    spaces"
SENTENCE_WITH_TABS = "This is a sentence\twith\tmultiple\ttabs"
SENTENCE_WITH_CUSTOM_TOKEN = "Let's see all on this text and. !23# neverseenwordspossible"


class AutoTokenizer:
    """
    Stand-in for transformers.AutoTokenizer used by the unit tests below: it records every
    from_pretrained() call on a shared MagicMock so the call arguments can be asserted,
    and it never downloads a real model.
    """

    mocker: MagicMock = MagicMock()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        cls.mocker.from_pretrained(*args, **kwargs)
        return cls()


@pytest.fixture(autouse=True)
def mock_autotokenizer(request, monkeypatch):
    # Do not patch integration tests: they are expected to load real tokenizers.
    if "integration" in request.keywords:
        return
    monkeypatch.setattr(haystack.modeling.model.tokenization, "AutoTokenizer", AutoTokenizer)


def convert_offset_from_word_reference_to_text_reference(offsets, words, word_spans):
    """
    Token offsets are originally relative to the beginning of the word.
    We make them relative to the beginning of the sentence.

    Not a fixture, just a utility.
    """
    token_offsets = []
    for (start, end), word_index in zip(offsets, words):
        word_start = word_spans[word_index][0]
        token_offsets.append((start + word_start, end + word_start))
    return token_offsets

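
# Illustrative example (not executed): for the text "ab cd", word_spans = [(0, 2), (3, 5)];
# a token covering the whole second word has word-relative offsets (0, 2), which the helper
# above converts to (3, 5) relative to the full text.
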
#
# Unit tests
#

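# The unit tests below pass made-up model names to get_tokenizer(). Thanks to the autouse
# fixture above, get_tokenizer() sees the mocked AutoTokenizer, so each test can assert on the
# exact keyword arguments forwarded to from_pretrained() via the shared `mocker` class
# attribute, without downloading any model.

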
def test_get_tokenizer_str():
    tokenizer = get_tokenizer(pretrained_model_name_or_path="test-model-name")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name", revision=None, use_fast=True, use_auth_token=None
    )


def test_get_tokenizer_path(tmp_path):
    tokenizer = get_tokenizer(pretrained_model_name_or_path=tmp_path / "test-path")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path=str(tmp_path / "test-path"), revision=None, use_fast=True, use_auth_token=None
    )


def test_get_tokenizer_keep_accents():
    # Model names containing "albert" should additionally get keep_accents=True.
    tokenizer = get_tokenizer(pretrained_model_name_or_path="test-model-name-albert")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name-albert",
        revision=None,
        use_fast=True,
        use_auth_token=None,
        keep_accents=True,
    )


def test_get_tokenizer_mlm_warning(caplog):
    # Model names containing "mlm" should log that the MLM part of codebert is unsupported.
    tokenizer = get_tokenizer(pretrained_model_name_or_path="test-model-name-mlm")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name-mlm", revision=None, use_fast=True, use_auth_token=None
    )
    assert "MLM part of codebert is currently not supported in Haystack".lower() in caplog.text.lower()


#
# Integration tests
#


@pytest.mark.integration
@pytest.mark.parametrize("model_name", TOKENIZERS_TO_TEST)
def test_save_load(tmp_path, model_name: str):
    tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False)
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    tokenizer.add_tokens(new_tokens=["neverseentokens"])
    original_encoding = tokenizer.encode_plus(text)

    save_dir = tmp_path / "saved_tokenizer"
    tokenizer.save_pretrained(save_dir)

    tokenizer_loaded = get_tokenizer(pretrained_model_name_or_path=save_dir)
    new_encoding = tokenizer_loaded.encode_plus(text)

    assert original_encoding == new_encoding


@pytest.mark.integration
def test_tokenize_custom_vocab_bert():
    tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)
    tokenizer.add_tokens(new_tokens=["neverseentokens"])
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    tokenized = tokenizer.tokenize(text)
    assert (
        tokenized == "Some Text with neverseentokens plus ! 215 ? # . and a combined - token _ with / ch ##ars".split()
    )


@pytest.mark.integration
@pytest.mark.parametrize(
    "edge_case",
    [
        REGULAR_SENTENCE,
        OTHER_ALPHABETS,
        GIBBERISH_SENTENCE,
        SENTENCE_WITH_ELLIPSIS,
        SENTENCE_WITH_LINEBREAK_1,
        SENTENCE_WITH_LINEBREAK_2,
        SENTENCE_WITH_LINEBREAKS,
        SENTENCE_WITH_EXCESS_WHITESPACE,
        SENTENCE_WITH_TABS,
    ],
)
@pytest.mark.parametrize("model_name", TOKENIZERS_TO_TEST)
def test_tokenization_on_edge_cases_full_sequence_tokenization(model_name: str, edge_case: str):
    """
    Verify that tokenizing the full sequence yields the same tokens as tokenizing whitespace-split words.
    """
    tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True)

    pre_tokenizer = WhitespaceSplit()
    words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
    words = [x[0] for x in words_and_spans]

    encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
    expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split()))  # remove multiple whitespaces

    assert encoded.tokens == expected_tokenization


@pytest.mark.integration
@pytest.mark.parametrize("edge_case", [SENTENCE_WITH_CUSTOM_TOKEN, GERMAN_SENTENCE])
@pytest.mark.parametrize("model_name", [t for t in TOKENIZERS_TO_TEST if t != ROBERTA])
def test_tokenization_on_edge_cases_full_sequence_tokenization_roberta_exceptions(model_name: str, edge_case: str):
    """
    Verify that tokenizing the full sequence yields the same tokens as tokenizing whitespace-split words.
    These edge cases hold for all tokenizers under test except RoBERTa, which is therefore excluded.
    """
    tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True)

    pre_tokenizer = WhitespaceSplit()
    words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
    words = [x[0] for x in words_and_spans]

    encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
    expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split()))  # remove multiple whitespaces

    assert encoded.tokens == expected_tokenization


@pytest.mark.integration
@pytest.mark.parametrize(
    "edge_case",
    [
        REGULAR_SENTENCE,
        # OTHER_ALPHABETS, # contains [UNK] that are impossible to match back to original text space
        GIBBERISH_SENTENCE,
        SENTENCE_WITH_ELLIPSIS,
        SENTENCE_WITH_LINEBREAK_1,
        SENTENCE_WITH_LINEBREAK_2,
        SENTENCE_WITH_LINEBREAKS,
        SENTENCE_WITH_EXCESS_WHITESPACE,
        SENTENCE_WITH_TABS,
    ],
)
@pytest.mark.parametrize("model_name,marker", TOKENIZERS_TO_TEST_WITH_TOKEN_MARKER)
def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str, marker: str, edge_case: str):
    tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True)

    pre_tokenizer = WhitespaceSplit()
    words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
    words = [x[0] for x in words_and_spans]
    word_spans = [x[1] for x in words_and_spans]

    encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]

    # Tokens carry a model-specific marker ("##" continuations for BERT, "Ġ"/"▁" word-start markers
    # for RoBERTa/XLNet). Strip it so each token can be matched against the original text.
    tokens = [token.replace(marker, "") for token in encoded.tokens]
    token_offsets = convert_offset_from_word_reference_to_text_reference(encoded.offsets, encoded.words, word_spans)

    for token, (start, end) in zip(tokens, token_offsets):
        assert token == edge_case[start:end]


@pytest.mark.integration
@pytest.mark.parametrize(
    "edge_case",
    [
        REGULAR_SENTENCE,
        GERMAN_SENTENCE,
        SENTENCE_WITH_EXCESS_WHITESPACE,
        OTHER_ALPHABETS,
        GIBBERISH_SENTENCE,
        SENTENCE_WITH_ELLIPSIS,
        SENTENCE_WITH_CUSTOM_TOKEN,
        SENTENCE_WITH_LINEBREAK_1,
        SENTENCE_WITH_LINEBREAK_2,
        SENTENCE_WITH_LINEBREAKS,
        SENTENCE_WITH_TABS,
    ],
)
def test_detokenization_for_bert(edge_case):
    tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)

    encoded = tokenizer.encode_plus(edge_case, add_special_tokens=False).encodings[0]

    detokenized = " ".join(encoded.tokens)
    detokenized = re.sub(r"(^|\s+)(##)", "", detokenized)
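    # The substitution above glues word pieces back onto their stem, e.g. (illustrative values)
    # "never ##see ##nto ##ken ##s" -> "neverseentokens".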

    detokenized_ids = tokenizer(detokenized, add_special_tokens=False)["input_ids"]
    detokenized_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in detokenized_ids]

    assert encoded.tokens == detokenized_tokens


@pytest.mark.integration
def test_encode_plus_for_bert():
    tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    encoded_batch = tokenizer.encode_plus(text)
    encoded = encoded_batch.encodings[0]

    # The special tokens [CLS] and [SEP] have no word id (None); replace it with -1 so the
    # start-of-word computation below can compare consecutive word ids numerically.
    words = np.array(encoded.words)
    words[0] = -1
    words[-1] = -1

    tokens = encoded.tokens
    offsets = [x[0] for x in encoded.offsets]
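    # After the -1 substitution the word ids look like [-1, 0, 1, 2, 3, 3, 3, 3, 3, 4, ...]:
    # a positive difference between consecutive ids marks the first sub-token of a new word,
    # and the leading False accounts for [CLS].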
    start_of_word = [False] + list(np.ediff1d(words) > 0)

    assert list(zip(tokens, offsets, start_of_word)) == [
        ("[CLS]", 0, False),
        ("Some", 0, True),
        ("Text", 5, True),
        ("with", 10, True),
        ("never", 15, True),
        ("##see", 20, False),
        ("##nto", 23, False),
        ("##ken", 26, False),
        ("##s", 29, False),
        ("plus", 31, True),
        ("!", 36, True),
        ("215", 37, True),
        ("?", 40, True),
        ("#", 41, True),
        (".", 42, True),
        ("and", 44, True),
        ("a", 48, True),
        ("combined", 50, True),
        ("-", 58, True),
        ("token", 59, True),
        ("_", 64, True),
        ("with", 65, True),
        ("/", 69, True),
        ("ch", 70, True),
        ("##ars", 72, False),
        ("[SEP]", 0, False),
    ]


@pytest.mark.integration
def test_encode_plus_for_bert_with_custom_vocab():
    # Renamed from test_tokenize_custom_vocab_bert to avoid shadowing the earlier test of the same name.
    tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False)

    tokenizer.add_tokens(new_tokens=["neverseentokens"])
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    tokenized = tokenizer.tokenize(text)

    encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0]
    offsets = [x[0] for x in encoded.offsets]
    start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)

    assert encoded.tokens == tokenized
    assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72]
    assert start_of_word_single == [True] * 19 + [False]