fix: Use AutoTokenizer instead of DPR specific tokenizer (#4898)

* Use AutoTokenizer instead of DPR specific tokenizer

* Adapt TableTextRetriever

* Adapt tests

* Adapt tests
This commit is contained in:
bogdankostic 2023-05-17 18:54:34 +02:00 committed by GitHub
parent 34b7d1edb0
commit df46e7fadd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 22 deletions

View File

@@ -19,13 +19,7 @@ from torch.nn import DataParallel
from torch.utils.data.sampler import SequentialSampler
import pandas as pd
from huggingface_hub import hf_hub_download
from transformers import (
AutoConfig,
DPRContextEncoderTokenizerFast,
DPRQuestionEncoderTokenizerFast,
DPRContextEncoderTokenizer,
DPRQuestionEncoderTokenizer,
)
from transformers import AutoConfig, AutoTokenizer
from haystack.errors import HaystackError
from haystack.schema import Document, FilterType
@@ -191,7 +185,7 @@ class DensePassageRetriever(DenseRetriever):
)
# Init & Load Encoders
self.query_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(
self.query_tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=query_embedding_model,
revision=model_version,
do_lower_case=True,
@@ -203,7 +197,7 @@ class DensePassageRetriever(DenseRetriever):
model_type="DPRQuestionEncoder",
use_auth_token=use_auth_token,
)
self.passage_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
self.passage_tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=passage_embedding_model,
revision=model_version,
do_lower_case=True,
@@ -873,12 +867,8 @@ class TableTextRetriever(DenseRetriever):
self.embed_meta_fields = embed_meta_fields
self.scale_score = scale_score
query_tokenizer_class = DPRQuestionEncoderTokenizerFast if use_fast else DPRQuestionEncoderTokenizer
passage_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer
table_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer
# Init & Load Encoders
self.query_tokenizer = query_tokenizer_class.from_pretrained(
self.query_tokenizer = AutoTokenizer.from_pretrained(
query_embedding_model,
revision=model_version,
do_lower_case=True,
@@ -888,7 +878,7 @@ class TableTextRetriever(DenseRetriever):
self.query_encoder = get_language_model(
pretrained_model_name_or_path=query_embedding_model, revision=model_version, use_auth_token=use_auth_token
)
self.passage_tokenizer = passage_tokenizer_class.from_pretrained(
self.passage_tokenizer = AutoTokenizer.from_pretrained(
passage_embedding_model,
revision=model_version,
do_lower_case=True,
@@ -898,7 +888,7 @@ class TableTextRetriever(DenseRetriever):
self.passage_encoder = get_language_model(
pretrained_model_name_or_path=passage_embedding_model, revision=model_version, use_auth_token=use_auth_token
)
self.table_tokenizer = table_tokenizer_class.from_pretrained(
self.table_tokenizer = AutoTokenizer.from_pretrained(
table_embedding_model,
revision=model_version,
do_lower_case=True,

View File

@@ -10,7 +10,7 @@ import pandas as pd
import requests
from boilerpy3.extractors import ArticleExtractor
from pandas.testing import assert_frame_equal
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
from transformers import PreTrainedTokenizerFast
try:
@@ -578,8 +578,8 @@ def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
assert loaded_retriever.processor.max_seq_len_query == 64
# Tokenizer
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
assert loaded_retriever.passage_tokenizer.do_lower_case == True
assert loaded_retriever.query_tokenizer.do_lower_case == True
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
@@ -621,9 +621,9 @@ def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_s
assert loaded_retriever.processor.max_seq_len_query == 64
# Tokenizer
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
assert isinstance(loaded_retriever.table_tokenizer, DPRContextEncoderTokenizerFast)
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
assert isinstance(loaded_retriever.table_tokenizer, PreTrainedTokenizerFast)
assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
assert loaded_retriever.passage_tokenizer.do_lower_case == True
assert loaded_retriever.table_tokenizer.do_lower_case == True
assert loaded_retriever.query_tokenizer.do_lower_case == True