import re
from unittest.mock import MagicMock

import numpy as np
import pytest
from tokenizers.pre_tokenizers import WhitespaceSplit

import haystack
from haystack.modeling.model.feature_extraction import FeatureExtractor


BERT = "bert-base-cased"
ROBERTA = "roberta-base"
XLNET = "xlnet-base-cased"

TOKENIZERS_TO_TEST = [BERT, ROBERTA, XLNET]
TOKENIZERS_TO_TEST_WITH_TOKEN_MARKER = [(BERT, "##"), (ROBERTA, "Ġ"), (XLNET, "▁")]


REGULAR_SENTENCE = "This is a sentence"
GERMAN_SENTENCE = "Der entscheidende Pass"
OTHER_ALPHABETS = "力加勝北区ᴵᴺᵀᵃছজটডণত"
GIBBERISH_SENTENCE = "Thiso text is included tolod makelio sure Unicodeel is handled properly:"
SENTENCE_WITH_ELLIPSIS = "This is a sentence..."
SENTENCE_WITH_LINEBREAK_1 = "and another one\n\n\nwithout space"
SENTENCE_WITH_LINEBREAK_2 = """This is a sentence.
With linebreak"""
SENTENCE_WITH_LINEBREAKS = """Sentence
with
multiple
newlines
"""
SENTENCE_WITH_EXCESS_WHITESPACE = "This is a sentence with    multiple   spaces"  # intentionally contains runs of spaces
SENTENCE_WITH_TABS = "This is a sentence with\tmultiple\ttabs"  # intentionally contains tab characters
SENTENCE_WITH_CUSTOM_TOKEN = "Let's see all on this text and. !23# neverseenwordspossible"


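# Lightweight stand-in used by the unit tests below in place of the real AutoTokenizer:
# it records every from_pretrained() call on a shared MagicMock, so tests can assert on the
# exact loading arguments without downloading any real tokenizer.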
class AutoTokenizer:
    mocker: MagicMock = MagicMock()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        cls.mocker.from_pretrained(*args, **kwargs)
        return cls()


@pytest.fixture(autouse=True)
def mock_autotokenizer(request, monkeypatch):
    # Do not patch integration tests
    if "integration" in request.keywords:
        return
    monkeypatch.setattr(haystack.modeling.model.feature_extraction, "AutoTokenizer", AutoTokenizer)


def convert_offset_from_word_reference_to_text_reference(offsets, words, word_spans):
    """
    Token offsets are originally relative to the beginning of the word.
    We make them relative to the beginning of the sentence.

    Not a fixture, just a utility.
    """
    token_offsets = []
    for (start, end), word_index in zip(offsets, words):
        word_start = word_spans[word_index][0]
        token_offsets.append((start + word_start, end + word_start))
    return token_offsets


#
# Unit tests
#


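# The tests in this block run against the mocked AutoTokenizer installed by the autouse
# mock_autotokenizer fixture above, so they only verify which arguments FeatureExtractor
# forwards to from_pretrained(); nothing is downloaded or actually loaded.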
def test_get_tokenizer_str():
    tokenizer = FeatureExtractor(pretrained_model_name_or_path="test-model-name")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name", revision=None, use_fast=True, use_auth_token=None
    )


def test_get_tokenizer_path(tmp_path):
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=tmp_path / "test-path")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path=str(tmp_path / "test-path"), revision=None, use_fast=True, use_auth_token=None
    )


def test_get_tokenizer_keep_accents():
    tokenizer = FeatureExtractor(pretrained_model_name_or_path="test-model-name-albert")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name-albert",
        revision=None,
        use_fast=True,
        use_auth_token=None,
        keep_accents=True,
    )


def test_get_tokenizer_mlm_warning(caplog):
    tokenizer = FeatureExtractor(pretrained_model_name_or_path="test-model-name-mlm")
    tokenizer.mocker.from_pretrained.assert_called_with(
        pretrained_model_name_or_path="test-model-name-mlm", revision=None, use_fast=True, use_auth_token=None
    )
    assert "MLM part of codebert is currently not supported in Haystack".lower() in caplog.text.lower()


#
# Integration tests
#
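# NOTE: the tests below load real pretrained tokenizers (e.g. bert-base-cased) from the
# Hugging Face Hub; the autouse mock_autotokenizer fixture skips its AutoTokenizer patch
# for tests carrying the "integration" marker.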


@pytest.mark.integration
@pytest.mark.parametrize("model_name", TOKENIZERS_TO_TEST)
def test_save_load(tmp_path, model_name: str):
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=model_name, do_lower_case=False)
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    tokenizer.add_tokens(new_tokens=["neverseentokens"])
    original_encoding = tokenizer(text)

    save_dir = tmp_path / "saved_tokenizer"
    tokenizer.save_pretrained(save_dir)

    tokenizer_loaded = FeatureExtractor(pretrained_model_name_or_path=save_dir)
    new_encoding = tokenizer_loaded(text)

    assert original_encoding == new_encoding


@pytest.mark.integration
def test_tokenize_custom_vocab_bert():
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=BERT, do_lower_case=False)
    tokenizer.add_tokens(new_tokens=["neverseentokens"])
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    tokenized = tokenizer.tokenize(text)
    assert (
        tokenized == "Some Text with neverseentokens plus ! 215 ? # . and a combined - token _ with / ch ##ars".split()
    )


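# The following tests compare two tokenization paths for the same text: the full string at once,
# and the text pre-split into words with a WhitespaceSplit pre-tokenizer (passed with
# is_split_into_words=True). Both paths are expected to produce the same tokens once redundant
# whitespace is collapsed.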
@pytest.mark.integration
@pytest.mark.parametrize(
    "edge_case",
    [
        REGULAR_SENTENCE,
        OTHER_ALPHABETS,
        GIBBERISH_SENTENCE,
        SENTENCE_WITH_ELLIPSIS,
        SENTENCE_WITH_LINEBREAK_1,
        SENTENCE_WITH_LINEBREAK_2,
        SENTENCE_WITH_LINEBREAKS,
        SENTENCE_WITH_EXCESS_WHITESPACE,
        SENTENCE_WITH_TABS,
    ],
)
@pytest.mark.parametrize("model_name", TOKENIZERS_TO_TEST)
def test_tokenization_on_edge_cases_full_sequence_tokenization(model_name: str, edge_case: str):
    """
    Verify that tokenization of the full sequence is the same as tokenization of the whitespace-split words.
    """
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True)

    pre_tokenizer = WhitespaceSplit()
    words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
    words = [x[0] for x in words_and_spans]

    encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
    expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split()))  # remove multiple whitespaces

    assert encoded.tokens == expected_tokenization


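# RoBERTa is excluded from the next test: its byte-level BPE folds leading spaces into the tokens
# themselves, so pre-split words and the full sentence can tokenize differently for these cases.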
@pytest.mark.integration
@pytest.mark.parametrize("edge_case", [SENTENCE_WITH_CUSTOM_TOKEN, GERMAN_SENTENCE])
@pytest.mark.parametrize("model_name", [t for t in TOKENIZERS_TO_TEST if t != ROBERTA])
def test_tokenization_on_edge_cases_full_sequence_tokenization_roberta_exceptions(model_name: str, edge_case: str):
    """
    Verify that tokenization of the full sequence is the same as tokenization of the whitespace-split words.
    These test cases work for all tokenizers under test except for RoBERTa.
    """
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True)

    pre_tokenizer = WhitespaceSplit()
    words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
    words = [x[0] for x in words_and_spans]

    encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]
    expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split()))  # remove multiple whitespaces

    assert encoded.tokens == expected_tokenization


@pytest.mark.integration
@pytest.mark.parametrize(
    "edge_case",
    [
        REGULAR_SENTENCE,
        # OTHER_ALPHABETS, # contains [UNK] tokens that are impossible to match back to the original text
        GIBBERISH_SENTENCE,
        SENTENCE_WITH_ELLIPSIS,
        SENTENCE_WITH_LINEBREAK_1,
        SENTENCE_WITH_LINEBREAK_2,
        SENTENCE_WITH_LINEBREAKS,
        SENTENCE_WITH_EXCESS_WHITESPACE,
        SENTENCE_WITH_TABS,
    ],
)
@pytest.mark.parametrize("model_name,marker", TOKENIZERS_TO_TEST_WITH_TOKEN_MARKER)
def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str, marker: str, edge_case: str):
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True)

    pre_tokenizer = WhitespaceSplit()
    words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case)
    words = [x[0] for x in words_and_spans]
    word_spans = [x[1] for x in words_and_spans]

    encoded = tokenizer(words, is_split_into_words=True, add_special_tokens=False).encodings[0]

    # Subword tokens carry model-specific marker characters; strip them to align with the original text.
    tokens = [token.replace(marker, "") for token in encoded.tokens]
    token_offsets = convert_offset_from_word_reference_to_text_reference(encoded.offsets, encoded.words, word_spans)

    for token, (start, end) in zip(tokens, token_offsets):
        assert token == edge_case[start:end]


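# Round-trip check: join the BERT tokens back into text, drop the "##" continuation markers,
# re-tokenize the result, and verify that the token sequence is unchanged.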
@pytest.mark.integration
@pytest.mark.parametrize(
    "edge_case",
    [
        REGULAR_SENTENCE,
        GERMAN_SENTENCE,
        SENTENCE_WITH_EXCESS_WHITESPACE,
        OTHER_ALPHABETS,
        GIBBERISH_SENTENCE,
        SENTENCE_WITH_ELLIPSIS,
        SENTENCE_WITH_CUSTOM_TOKEN,
        SENTENCE_WITH_LINEBREAK_1,
        SENTENCE_WITH_LINEBREAK_2,
        SENTENCE_WITH_LINEBREAKS,
        SENTENCE_WITH_TABS,
    ],
)
def test_detokenization_for_bert(edge_case):
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=BERT, do_lower_case=False)

    encoded = tokenizer(edge_case, add_special_tokens=False).encodings[0]

    detokenized = " ".join(encoded.tokens)
    detokenized = re.sub(r"(^|\s+)(##)", "", detokenized)

    detokenized_ids = tokenizer(detokenized, add_special_tokens=False)["input_ids"]
    detokenized_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in detokenized_ids]

    assert encoded.tokens == detokenized_tokens


@pytest.mark.integration
def test_encode_plus_for_bert():
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=BERT, do_lower_case=False)
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    encoded_batch = tokenizer(text)
    encoded = encoded_batch.encodings[0]

    # Special tokens ([CLS], [SEP]) have no word id (None); replace them with -1 so that
    # np.ediff1d can be used to detect word boundaries below.
    words = np.array(encoded.words)
    words[0] = -1
    words[-1] = -1

    tokens = encoded.tokens
    offsets = [x[0] for x in encoded.offsets]
    start_of_word = [False] + list(np.ediff1d(words) > 0)

    assert list(zip(tokens, offsets, start_of_word)) == [
        ("[CLS]", 0, False),
        ("Some", 0, True),
        ("Text", 5, True),
        ("with", 10, True),
        ("never", 15, True),
        ("##see", 20, False),
        ("##nto", 23, False),
        ("##ken", 26, False),
        ("##s", 29, False),
        ("plus", 31, True),
        ("!", 36, True),
        ("215", 37, True),
        ("?", 40, True),
        ("#", 41, True),
        (".", 42, True),
        ("and", 44, True),
        ("a", 48, True),
        ("combined", 50, True),
        ("-", 58, True),
        ("token", 59, True),
        ("_", 64, True),
        ("with", 65, True),
        ("/", 69, True),
        ("ch", 70, True),
        ("##ars", 72, False),
        ("[SEP]", 0, False),
    ]


@pytest.mark.integration
def test_encode_custom_vocab_bert():
    # Same custom-vocab setup as test_tokenize_custom_vocab_bert above, but checks the full
    # encoding: offsets and start-of-word flags must line up with the added token.
    tokenizer = FeatureExtractor(pretrained_model_name_or_path=BERT, do_lower_case=False)

    tokenizer.add_tokens(new_tokens=["neverseentokens"])
    text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars"

    tokenized = tokenizer.tokenize(text)

    encoded = tokenizer(text, add_special_tokens=False).encodings[0]
    offsets = [x[0] for x in encoded.offsets]
    start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0)

    assert encoded.tokens == tokenized
    assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72]
    assert start_of_word_single == [True] * 19 + [False]