Fix DPR tests + add Tokenizer tests (#1429)

* Fix DPR tests

* Add Tokenizer tests
bogdankostic 2021-09-09 12:56:44 +02:00 committed by GitHub
parent 23338f1b74
commit 2626388961
5 changed files with 408 additions and 24 deletions


@@ -411,17 +411,26 @@ class DataSilo:
return train, test
def _calculate_statistics(self):
"""
Calculate and log simple summary statistics of the datasets.
"""
""" Calculate and log simple summary statistics of the datasets """
logger.info("")
logger.info("DATASETS SUMMARY")
logger.info("================")
self.counts = {}
clipped = -1
ave_len = -1
if self.data["train"]:
self.counts["train"] = len(self.data["train"])
if "input_ids" in self.tensor_names:
clipped, ave_len, seq_lens, max_seq_len = self._calc_length_stats_single_encoder()
elif "query_input_ids" in self.tensor_names and "passage_input_ids" in self.tensor_names:
clipped, ave_len, seq_lens, max_seq_len = self._calc_length_stats_biencoder()
else:
logger.warning(
f"Could not compute length statistics because 'input_ids' or 'query_input_ids' and 'passage_input_ids' are missing.")
clipped = -1
ave_len = -1
else:
self.counts["train"] = 0
@@ -435,29 +444,31 @@ class DataSilo:
else:
self.counts["test"] = 0
seq_lens = []
if self.data["train"]:
for dataset in self.data["train"].datasets:
train_input_numpy = dataset[:][0].numpy()
seq_lens.extend(np.sum(train_input_numpy != self.processor.tokenizer.pad_token_id, axis=1))
max_seq_len = dataset[:][0].shape[1]
self.clipped = np.mean(np.array(seq_lens) == max_seq_len) if seq_lens else 0
self.ave_len = np.mean(seq_lens) if seq_lens else 0
logger.info("Examples in train: {}".format(self.counts["train"]))
logger.info("Examples in dev : {}".format(self.counts["dev"]))
logger.info("Examples in test : {}".format(self.counts["test"]))
logger.info("Total examples : {}".format(self.counts["train"] + self.counts["dev"] + self.counts["test"]))
logger.info("")
if self.data["train"]:
logger.info("Longest sequence length observed after clipping: {}".format(max(seq_lens)))
logger.info("Average sequence length after clipping: {}".format(self.ave_len))
logger.info("Proportion clipped: {}".format(self.clipped))
if self.clipped > 0.5:
logger.info("[Farmer's Tip] {}% of your samples got cut down to {} tokens. "
"Consider increasing max_seq_len. "
"This will lead to higher memory consumption but is likely to "
"improve your model performance".format(round(self.clipped * 100, 1), max_seq_len))
if "input_ids" in self.tensor_names:
logger.info("Longest sequence length observed after clipping: {}".format(max(seq_lens)))
logger.info("Average sequence length after clipping: {}".format(ave_len))
logger.info("Proportion clipped: {}".format(clipped))
if clipped > 0.5:
logger.info("[Farmer's Tip] {}% of your samples got cut down to {} tokens. "
"Consider increasing max_seq_len. "
"This will lead to higher memory consumption but is likely to "
"improve your model performance".format(round(clipped * 100, 1), max_seq_len))
elif "query_input_ids" in self.tensor_names and "passage_input_ids" in self.tensor_names:
logger.info("Longest query length observed after clipping: {} - for max_query_len: {}".format(
max(seq_lens[0]), max_seq_len[0]))
logger.info("Average query length after clipping: {}".format(ave_len[0]))
logger.info("Proportion queries clipped: {}".format(clipped[0]))
logger.info("")
logger.info("Longest passage length observed after clipping: {} - for max_passage_len: {}".format(
max(seq_lens[1]), max_seq_len[1]))
logger.info("Average passage length after clipping: {}".format(ave_len[1]))
logger.info("Proportion passages clipped: {}".format(clipped[1]))
MlLogger.log_params(
{
@@ -465,11 +476,44 @@ class DataSilo:
"n_samples_dev": self.counts["dev"],
"n_samples_test": self.counts["test"],
"batch_size": self.batch_size,
"ave_seq_len": self.ave_len,
"clipped": self.clipped,
"ave_seq_len": ave_len,
"clipped": clipped,
}
)
def _calc_length_stats_single_encoder(self):
seq_lens = []
for dataset in self.data["train"].datasets:
train_input_numpy = dataset[:][self.tensor_names.index("input_ids")].numpy()
seq_lens.extend(np.sum(train_input_numpy != self.processor.tokenizer.pad_token_id, axis=1))
max_seq_len = dataset[:][self.tensor_names.index("input_ids")].shape[1]
clipped = np.mean(np.array(seq_lens) == max_seq_len) if seq_lens else 0
ave_len = np.mean(seq_lens) if seq_lens else 0
return clipped, ave_len, seq_lens, max_seq_len
def _calc_length_stats_biencoder(self):
seq_lens = [[], []]
for dataset in self.data["train"].datasets:
query_input_numpy = dataset[:][self.tensor_names.index("query_input_ids")].numpy()
num_passages = dataset[:][self.tensor_names.index("passage_input_ids")].shape[1]
bs = dataset[:][self.tensor_names.index("passage_input_ids")].shape[0]
passage_input_numpy = dataset[:][self.tensor_names.index("passage_input_ids")].numpy().reshape((bs, -1),
order='C')
qlen = np.sum(query_input_numpy != self.processor.query_tokenizer.pad_token_id, axis=1)
plen = np.sum(passage_input_numpy != self.processor.passage_tokenizer.pad_token_id, axis=1) / num_passages
seq_lens[0].extend(qlen)
seq_lens[1].extend(plen)
q_max_seq_len = dataset[:][self.tensor_names.index("query_input_ids")].shape[1]
p_max_seq_len = dataset[:][self.tensor_names.index("passage_input_ids")].shape[2]
clipped_q = np.mean(np.array(seq_lens[0]) == q_max_seq_len) if seq_lens[0] else 0
ave_len_q = np.mean(seq_lens[0]) if seq_lens[0] else 0
clipped_p = np.mean(np.array(seq_lens[1]) == p_max_seq_len) if seq_lens[1] else 0
ave_len_p = np.mean(seq_lens[1]) if seq_lens[1] else 0
clipped = [clipped_q, clipped_p]
ave_len = [ave_len_q, ave_len_p]
max_seq_len = [q_max_seq_len, p_max_seq_len]
return clipped, ave_len, seq_lens, max_seq_len
def calculate_class_weights(self, task_name: str, source: str = "train"):
"""
For imbalanced datasets, we can calculate class weights that can be used later in the
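
For context on the refactor above: the new `_calc_length_stats_single_encoder` helper derives the clipping proportion and average length from the padded `input_ids` tensor, and `_calc_length_stats_biencoder` applies the same idea separately to queries and reshaped passages, averaging passage lengths over the number of passages per sample. A minimal, standalone sketch of the core computation, using a hypothetical toy batch rather than a real DataSilo:

import numpy as np

# Hypothetical padded batch: pad_token_id = 0, max_seq_len = 6.
# The second row fills the whole window, i.e. it may have been clipped.
input_ids = np.array([
    [101, 7592, 2088,  102,    0,   0],
    [101, 2023, 2003, 1037, 6251, 102],
])
pad_token_id = 0

seq_lens = np.sum(input_ids != pad_token_id, axis=1)  # non-pad tokens per sample -> [4, 6]
max_seq_len = input_ids.shape[1]                      # 6
clipped = np.mean(seq_lens == max_seq_len)            # 0.5 -> half the samples hit the limit
ave_len = np.mean(seq_lens)                           # 5.0
print(f"clipped={clipped}, ave_len={ave_len}")
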


@@ -842,7 +842,7 @@ class TextSimilarityProcessor(Processor):
self.max_seq_len_passage = max_seq_len_passage
super(TextSimilarityProcessor, self).__init__(
tokenizer = None, #type: ignore
tokenizer = None, # type: ignore
max_seq_len = 0,
train_filename=train_filename,
dev_filename=dev_filename,


@@ -67,6 +67,7 @@ class BiAdaptiveModel(nn.Module):
Note: The loss at this stage is per sample, i.e one tensor of
shape (batchsize) per prediction head.
"""
super(BiAdaptiveModel, self).__init__()
self.device = device
self.language_model1 = language_model1.to(device)
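
The line added in this hunk appears to be the `super(BiAdaptiveModel, self).__init__()` call. PyTorch requires `nn.Module.__init__` to run before any submodule is assigned to `self`; otherwise the assignment fails. A minimal sketch with hypothetical classes (not the real BiAdaptiveModel):

import torch
from torch import nn

class Broken(nn.Module):
    def __init__(self):
        # No super().__init__() call: the very next line fails with
        # "AttributeError: cannot assign module before Module.__init__() call"
        self.layer = nn.Linear(4, 2)

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()            # registers _parameters/_modules before submodules are set
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

model = Fixed()                        # fine
print(model(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])
# Broken()                             # would raise the AttributeError above
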

File diff suppressed because one or more lines are too long

test/test_tokenization.py (new file, 338 lines)

@@ -0,0 +1,338 @@
import logging
import pytest
import re
from transformers import BertTokenizer, BertTokenizerFast, RobertaTokenizer, RobertaTokenizerFast, \
XLNetTokenizer, XLNetTokenizerFast, ElectraTokenizerFast
from tokenizers.pre_tokenizers import WhitespaceSplit
from haystack.basics.modeling.tokenization import Tokenizer
import numpy as np
TEXTS = [
"This is a sentence",
"Der entscheidende Pass",
"This is a sentence with multiple spaces",
"力加勝北区ᴵᴺᵀᵃছজটডণত",
"Thiso text is included tolod makelio sure Unicodeel is handled properly:",
"This is a sentence...",
"Let's see all on this text and. !23# neverseenwordspossible",
"""This is a sentence.
With linebreak""",
"""Sentence with multiple
newlines
""",
"and another one\n\n\nwithout space",
"This is a sentence with tab",
"This is a sentence with multiple tabs",
]
def test_basic_loading(caplog):
caplog.set_level(logging.CRITICAL)
# slow tokenizers
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="bert-base-cased",
do_lower_case=True,
use_fast=False,
)
assert type(tokenizer) == BertTokenizer
assert tokenizer.basic_tokenizer.do_lower_case == True
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="xlnet-base-cased",
do_lower_case=True,
use_fast=False
)
assert type(tokenizer) == XLNetTokenizer
assert tokenizer.do_lower_case == True
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="roberta-base",
use_fast=False
)
assert type(tokenizer) == RobertaTokenizer
# fast tokenizers
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="bert-base-cased",
do_lower_case=True
)
assert type(tokenizer) == BertTokenizerFast
assert tokenizer.do_lower_case == True
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="xlnet-base-cased",
do_lower_case=True
)
assert type(tokenizer) == XLNetTokenizerFast
assert tokenizer.do_lower_case == True
tokenizer = Tokenizer.load(
pretrained_model_name_or_path="roberta-base"
)
assert type(tokenizer) == RobertaTokenizerFast
def test_bert_tokenizer_all_meta(caplog):
caplog.set_level(logging.CRITICAL)
lang_model = "bert-base-cased"
tokenizer = Tokenizer.load(
pretrained_model_name_or_path=lang_model,
do_lower_case=False
)
basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
tokenized = tokenizer.tokenize(basic_text)
assert tokenized == ['Some', 'Text', 'with', 'never', '##see', '##nto', '##ken', '##s', 'plus', '!', '215', '?', '#', '.', 'and', 'a', 'combined', '-', 'token', '_', 'with', '/', 'ch', '##ars']
encoded_batch = tokenizer.encode_plus(basic_text)
encoded = encoded_batch.encodings[0]
words = np.array(encoded.words)
words[words == None] = -1
start_of_word_single = [False] + list(np.ediff1d(words) > 0)
assert encoded.tokens == ['[CLS]', 'Some', 'Text', 'with', 'never', '##see', '##nto', '##ken', '##s', 'plus', '!', '215', '?', '#', '.', 'and', 'a', 'combined', '-', 'token', '_', 'with', '/', 'ch', '##ars', '[SEP]']
assert [x[0] for x in encoded.offsets] == [0, 0, 5, 10, 15, 20, 23, 26, 29, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72, 0]
assert start_of_word_single == [False, True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False]
def test_save_load(caplog):
caplog.set_level(logging.CRITICAL)
lang_names = ["bert-base-cased", "roberta-base", "xlnet-base-cased"]
tokenizers = []
for lang_name in lang_names:
if "xlnet" in lang_name.lower():
t = Tokenizer.load(lang_name, lower_case=False, use_fast=True, from_slow=True)
else:
t = Tokenizer.load(lang_name, lower_case=False)
t.add_tokens(new_tokens=["neverseentokens"])
tokenizers.append(t)
basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
for tokenizer in tokenizers:
tokenizer_type = tokenizer.__class__.__name__
save_dir = f"testsave/{tokenizer_type}"
tokenizer.save_pretrained(save_dir)
tokenizer_loaded = Tokenizer.load(save_dir, tokenizer_class=tokenizer_type)
encoded_before = tokenizer.encode_plus(basic_text).encodings[0]
encoded_after = tokenizer_loaded.encode_plus(basic_text).encodings[0]
data_before = {"tokens": encoded_before.tokens,
"offsets": encoded_before.offsets,
"words": encoded_before.words}
data_after = {"tokens": encoded_after.tokens,
"offsets": encoded_after.offsets,
"words": encoded_after.words}
assert data_before == data_after
@pytest.mark.parametrize("model_name", ["bert-base-german-cased",
"google/electra-small-discriminator",
])
def test_fast_tokenizer_with_examples(caplog, model_name):
fast_tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=True)
tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=False)
for text in TEXTS:
# plain tokenize function
tokenized = tokenizer.tokenize(text)
fast_tokenized = fast_tokenizer.tokenize(text)
assert tokenized == fast_tokenized
def test_all_tokenizer_on_special_cases(caplog):
caplog.set_level(logging.CRITICAL)
lang_names = ["bert-base-cased", "roberta-base", "xlnet-base-cased"]
tokenizers = []
for lang_name in lang_names:
if "roberta" in lang_name:
add_prefix_space = True
else:
add_prefix_space = False
t = Tokenizer.load(lang_name, lower_case=False, add_prefix_space=add_prefix_space)
tokenizers.append(t)
texts = [
"This is a sentence",
"Der entscheidende Pass",
"力加勝北区ᴵᴺᵀᵃছজটডণত",
"Thiso text is included tolod makelio sure Unicodeel is handled properly:",
"This is a sentence...",
"Let's see all on this text and. !23# neverseenwordspossible"
"This is a sentence with multiple spaces",
"""This is a sentence.
With linebreak""",
"""Sentence with multiple
newlines
""",
"and another one\n\n\nwithout space",
"This is a sentence with multiple tabs",
]
expected_to_fail = [(1, 1), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (2, 1), (2, 5)]
for i_tok, tokenizer in enumerate(tokenizers):
for i_text, text in enumerate(texts):
# Important: we don't assume to preserve whitespaces after tokenization.
# This means: \t, \n " " etc will all resolve to a single " ".
# This doesn't make a difference for BERT + XLNet but it does for roBERTa
test_passed = True
# 1. original tokenize function from transformer repo on full sentence
standardized_whitespace_text = ' '.join(text.split()) # remove multiple whitespaces
tokenized = tokenizer.tokenize(standardized_whitespace_text)
# 2. Our tokenization method using a pretokenizer which can normalize multiple white spaces
# This approach is used in NER
pre_tokenizer = WhitespaceSplit()
words_and_spans = pre_tokenizer.pre_tokenize_str(text)
words = [x[0] for x in words_and_spans]
word_spans = [x[1] for x in words_and_spans]
encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
# verify that tokenization on full sequence is the same as the one on "whitespace tokenized words"
if encoded.tokens != tokenized:
test_passed = False
# token offsets are originally relative to the beginning of the word
# These lines convert them so they are relative to the beginning of the sentence
token_offsets = []
for (start, end), w_index, in zip(encoded.offsets, encoded.words):
word_start_ch = word_spans[w_index][0]
token_offsets.append((start + word_start_ch, end + word_start_ch))
if getattr(tokenizer, "add_prefix_space", None):
token_offsets = [(start-1, end) for start, end in token_offsets]
# verify that offsets align back to original text
if text == "力加勝北区ᴵᴺᵀᵃছজটডণত":
# contains [UNK] that are impossible to match back to original text space
continue
for tok, (start, end) in zip(encoded.tokens, token_offsets):
#subword-tokens have special chars depending on model type. In order to align with original text we need to get rid of them
tok = re.sub(r"^(##|Ġ|▁)", "", tok)
#tok = tokenizer.decode(tokenizer.convert_tokens_to_ids(tok))
original_tok = text[start: end]
if tok != original_tok:
test_passed = False
if (i_tok, i_text) in expected_to_fail:
assert not test_passed, f"Behaviour of {tokenizer.__class__.__name__} has changed on text {text}'"
else:
assert test_passed, f"Behaviour of {tokenizer.__class__.__name__} has changed on text {text}'"
def test_bert_custom_vocab(caplog):
caplog.set_level(logging.CRITICAL)
lang_model = "bert-base-cased"
tokenizer = Tokenizer.load(
pretrained_model_name_or_path=lang_model,
do_lower_case=False
)
#deprecated: tokenizer.add_custom_vocab("samples/tokenizer/custom_vocab.txt")
tokenizer.add_tokens(new_tokens=["neverseentokens"])
basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
# original tokenizer from transformer repo
tokenized = tokenizer.tokenize(basic_text)
assert tokenized == ['Some', 'Text', 'with', 'neverseentokens', 'plus', '!', '215', '?', '#', '.', 'and', 'a', 'combined', '-', 'token', '_', 'with', '/', 'ch', '##ars']
# ours with metadata
encoded = tokenizer.encode_plus(basic_text, add_special_tokens=False).encodings[0]
offsets = [x[0] for x in encoded.offsets]
start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)
assert encoded.tokens == tokenized
assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72]
assert start_of_word_single == [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]
def test_fast_bert_custom_vocab(caplog):
caplog.set_level(logging.CRITICAL)
lang_model = "bert-base-cased"
tokenizer = Tokenizer.load(
pretrained_model_name_or_path=lang_model,
do_lower_case=False, use_fast=True
)
#deprecated: tokenizer.add_custom_vocab("samples/tokenizer/custom_vocab.txt")
tokenizer.add_tokens(new_tokens=["neverseentokens"])
basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"
# original tokenizer from transformer repo
tokenized = tokenizer.tokenize(basic_text)
assert tokenized == ['Some', 'Text', 'with', 'neverseentokens', 'plus', '!', '215', '?', '#', '.', 'and', 'a', 'combined', '-', 'token', '_', 'with', '/', 'ch', '##ars']
# ours with metadata
encoded = tokenizer.encode_plus(basic_text, add_special_tokens=False).encodings[0]
offsets = [x[0] for x in encoded.offsets]
start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)
assert encoded.tokens == tokenized
assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72]
assert start_of_word_single == [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]
@pytest.mark.parametrize("model_name, tokenizer_type", [
("bert-base-german-cased", BertTokenizerFast),
("google/electra-small-discriminator", ElectraTokenizerFast),
])
def test_fast_tokenizer_type(caplog, model_name, tokenizer_type):
caplog.set_level(logging.CRITICAL)
tokenizer = Tokenizer.load(model_name, use_fast=True)
assert type(tokenizer) is tokenizer_type
# See discussion in https://github.com/deepset-ai/FARM/pull/624 for reason to remove the test
# def test_fast_bert_tokenizer_strip_accents(caplog):
# caplog.set_level(logging.CRITICAL)
#
# tokenizer = Tokenizer.load("dbmdz/bert-base-german-uncased",
# use_fast=True,
# strip_accents=False)
# assert type(tokenizer) is BertTokenizerFast
# assert tokenizer.do_lower_case
# assert tokenizer._tokenizer._parameters['strip_accents'] is False
def test_fast_electra_tokenizer(caplog):
caplog.set_level(logging.CRITICAL)
tokenizer = Tokenizer.load("dbmdz/electra-base-german-europeana-cased-discriminator",
use_fast=True)
assert type(tokenizer) is ElectraTokenizerFast
@pytest.mark.parametrize("model_name", ["bert-base-cased", "distilbert-base-uncased", "deepset/electra-base-squad2"])
def test_detokenization_in_fast_tokenizers(model_name):
tokenizer = Tokenizer.load(
pretrained_model_name_or_path=model_name,
use_fast=True
)
for text in TEXTS:
encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0]
detokenized = " ".join(encoded.tokens)
detokenized = re.sub(r"(^|\s+)(##)", "", detokenized)
detokenized_ids = tokenizer(detokenized, add_special_tokens=False)["input_ids"]
detokenized_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in detokenized_ids]
assert encoded.tokens == detokenized_tokens
if __name__ == "__main__":
test_all_tokenizer_on_special_cases()
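
One pattern worth calling out from these tests is how start-of-word flags are derived: a fast tokenizer's encoding carries a word index per token, and `np.ediff1d(words) > 0` is True exactly where that index increases, i.e. at the first sub-token of each new word. A small standalone sketch of just that step, loading the tokenizer directly from transformers rather than via `Tokenizer.load`:

import numpy as np
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
encoded = tokenizer("Some Text with chars", add_special_tokens=False).encodings[0]

print(encoded.tokens)                       # e.g. ['Some', 'Text', 'with', 'ch', '##ars']
words = np.array(encoded.words)             # word index per token, e.g. [0, 1, 2, 3, 3]
start_of_word = [True] + list(np.ediff1d(words) > 0)
print(start_of_word)                        # [True, True, True, True, False] -> '##ars' continues 'chars'
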