mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
fix: Use AutoTokenizer instead of DPR-specific tokenizer (#4898)
* Use AutoTokenizer instead of DPR-specific tokenizer
* Adapt TableTextRetriever
* Adapt tests
* Adapt tests
This commit is contained in:
parent
34b7d1edb0
commit
df46e7fadd
@ -19,13 +19,7 @@ from torch.nn import DataParallel
|
||||
from torch.utils.data.sampler import SequentialSampler
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
DPRQuestionEncoderTokenizerFast,
|
||||
DPRContextEncoderTokenizer,
|
||||
DPRQuestionEncoderTokenizer,
|
||||
)
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document, FilterType
|
||||
@ -191,7 +185,7 @@ class DensePassageRetriever(DenseRetriever):
|
||||
)
|
||||
|
||||
# Init & Load Encoders
|
||||
self.query_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(
|
||||
self.query_tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=query_embedding_model,
|
||||
revision=model_version,
|
||||
do_lower_case=True,
|
||||
@ -203,7 +197,7 @@ class DensePassageRetriever(DenseRetriever):
|
||||
model_type="DPRQuestionEncoder",
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
self.passage_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
|
||||
self.passage_tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=passage_embedding_model,
|
||||
revision=model_version,
|
||||
do_lower_case=True,
|
||||
@ -873,12 +867,8 @@ class TableTextRetriever(DenseRetriever):
|
||||
self.embed_meta_fields = embed_meta_fields
|
||||
self.scale_score = scale_score
|
||||
|
||||
query_tokenizer_class = DPRQuestionEncoderTokenizerFast if use_fast else DPRQuestionEncoderTokenizer
|
||||
passage_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer
|
||||
table_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer
|
||||
|
||||
# Init & Load Encoders
|
||||
self.query_tokenizer = query_tokenizer_class.from_pretrained(
|
||||
self.query_tokenizer = AutoTokenizer.from_pretrained(
|
||||
query_embedding_model,
|
||||
revision=model_version,
|
||||
do_lower_case=True,
|
||||
@ -888,7 +878,7 @@ class TableTextRetriever(DenseRetriever):
|
||||
self.query_encoder = get_language_model(
|
||||
pretrained_model_name_or_path=query_embedding_model, revision=model_version, use_auth_token=use_auth_token
|
||||
)
|
||||
self.passage_tokenizer = passage_tokenizer_class.from_pretrained(
|
||||
self.passage_tokenizer = AutoTokenizer.from_pretrained(
|
||||
passage_embedding_model,
|
||||
revision=model_version,
|
||||
do_lower_case=True,
|
||||
@ -898,7 +888,7 @@ class TableTextRetriever(DenseRetriever):
|
||||
self.passage_encoder = get_language_model(
|
||||
pretrained_model_name_or_path=passage_embedding_model, revision=model_version, use_auth_token=use_auth_token
|
||||
)
|
||||
self.table_tokenizer = table_tokenizer_class.from_pretrained(
|
||||
self.table_tokenizer = AutoTokenizer.from_pretrained(
|
||||
table_embedding_model,
|
||||
revision=model_version,
|
||||
do_lower_case=True,
|
||||
|
||||
@ -10,7 +10,7 @@ import pandas as pd
|
||||
import requests
|
||||
from boilerpy3.extractors import ArticleExtractor
|
||||
from pandas.testing import assert_frame_equal
|
||||
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
try:
|
||||
@ -578,8 +578,8 @@ def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
|
||||
assert loaded_retriever.processor.max_seq_len_query == 64
|
||||
|
||||
# Tokenizer
|
||||
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
|
||||
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
|
||||
assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
|
||||
assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
|
||||
assert loaded_retriever.passage_tokenizer.do_lower_case == True
|
||||
assert loaded_retriever.query_tokenizer.do_lower_case == True
|
||||
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
|
||||
@ -621,9 +621,9 @@ def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_s
|
||||
assert loaded_retriever.processor.max_seq_len_query == 64
|
||||
|
||||
# Tokenizer
|
||||
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
|
||||
assert isinstance(loaded_retriever.table_tokenizer, DPRContextEncoderTokenizerFast)
|
||||
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
|
||||
assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
|
||||
assert isinstance(loaded_retriever.table_tokenizer, PreTrainedTokenizerFast)
|
||||
assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
|
||||
assert loaded_retriever.passage_tokenizer.do_lower_case == True
|
||||
assert loaded_retriever.table_tokenizer.do_lower_case == True
|
||||
assert loaded_retriever.query_tokenizer.do_lower_case == True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user