Removed single_model_path; added infer_tokenizer to dpr load() (#1060)

Julian Risch 2021-06-14 14:14:46 +02:00 committed by GitHub
parent 1c31589b43
commit f6e70f0f3d

haystack/retriever/dense.py
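
After this change, the retriever always loads separate query and passage encoders; a combined single-model checkpoint is no longer supported by __init__. A minimal usage sketch of the new constructor call, assuming the Haystack ~0.9 import paths of this revision (the document store and model names below are only illustrative):

    # Construct a DPR retriever from two separate encoder checkpoints.
    from haystack.document_store import InMemoryDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    retriever = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        max_seq_len_query=64,
        max_seq_len_passage=256,
        infer_tokenizer_classes=False,  # keep the DPR-specific tokenizer classes
    )
    # Passing single_model_path=... now fails; load the two encoders separately instead.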

@@ -37,7 +37,6 @@ class DensePassageRetriever(BaseRetriever):
                  document_store: BaseDocumentStore,
                  query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
                  passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
-                 single_model_path: Optional[Union[Path, str]] = None,
                  model_version: Optional[str] = None,
                  max_seq_len_query: int = 64,
                  max_seq_len_passage: int = 256,
@@ -74,9 +73,6 @@ class DensePassageRetriever(BaseRetriever):
         :param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
                                         one used by hugging-face transformers' modelhub models
                                         Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
-        :param single_model_path: Local path or remote name of a query and passage embedder in one single model. Those
-                                  models are typically trained within FARM.
-                                  Currently available remote names: TODO add FARM DPR model to HF modelhub
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
         :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
@@ -101,7 +97,7 @@ class DensePassageRetriever(BaseRetriever):
         # save init parameters to enable export of component config as YAML
         self.set_config(
             document_store=document_store, query_embedding_model=query_embedding_model,
-            passage_embedding_model=passage_embedding_model, single_model_path=single_model_path,
+            passage_embedding_model=passage_embedding_model,
             model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
             top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
             use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
@@ -137,51 +133,42 @@ class DensePassageRetriever(BaseRetriever):
             tokenizers_default_classes["passage"] = None  # type: ignore

         # Init & Load Encoders
-        if single_model_path is None:
-            self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
-                                                  revision=model_version,
-                                                  do_lower_case=True,
-                                                  use_fast=use_fast_tokenizers,
-                                                  tokenizer_class=tokenizers_default_classes["query"])
-            self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
-                                                    revision=model_version,
-                                                    language_model_class="DPRQuestionEncoder")
-            self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                    revision=model_version,
-                                                    do_lower_case=True,
-                                                    use_fast=use_fast_tokenizers,
-                                                    tokenizer_class=tokenizers_default_classes["passage"])
-            self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                      revision=model_version,
-                                                      language_model_class="DPRContextEncoder")
-            self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
-                                                     passage_tokenizer=self.passage_tokenizer,
-                                                     max_seq_len_passage=max_seq_len_passage,
-                                                     max_seq_len_query=max_seq_len_query,
-                                                     label_list=["hard_negative", "positive"],
-                                                     metric="text_similarity_metric",
-                                                     embed_title=embed_title,
-                                                     num_hard_negatives=0,
-                                                     num_positives=1)
-            prediction_head = TextSimilarityHead(similarity_function=similarity_function)
-            self.model = BiAdaptiveModel(
-                language_model1=self.query_encoder,
-                language_model2=self.passage_encoder,
-                prediction_heads=[prediction_head],
-                embeds_dropout_prob=0.1,
-                lm1_output_types=["per_sequence"],
-                lm2_output_types=["per_sequence"],
-                device=self.device,
-            )
-        else:
-            self.processor = TextSimilarityProcessor.load_from_dir(single_model_path)
-            self.processor.max_seq_len_passage = max_seq_len_passage
-            self.processor.max_seq_len_query = max_seq_len_query
-            self.processor.embed_title = embed_title
-            self.processor.num_hard_negatives = 0
-            self.processor.num_positives = 1  # during indexing of documents only one embedding is created
-            self.model = BiAdaptiveModel.load(single_model_path, device=self.device)
+        self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
+                                              revision=model_version,
+                                              do_lower_case=True,
+                                              use_fast=use_fast_tokenizers,
+                                              tokenizer_class=tokenizers_default_classes["query"])
+        self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
+                                                revision=model_version,
+                                                language_model_class="DPRQuestionEncoder")
+        self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                revision=model_version,
+                                                do_lower_case=True,
+                                                use_fast=use_fast_tokenizers,
+                                                tokenizer_class=tokenizers_default_classes["passage"])
+        self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                  revision=model_version,
+                                                  language_model_class="DPRContextEncoder")
+        self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
+                                                 passage_tokenizer=self.passage_tokenizer,
+                                                 max_seq_len_passage=max_seq_len_passage,
+                                                 max_seq_len_query=max_seq_len_query,
+                                                 label_list=["hard_negative", "positive"],
+                                                 metric="text_similarity_metric",
+                                                 embed_title=embed_title,
+                                                 num_hard_negatives=0,
+                                                 num_positives=1)
+        prediction_head = TextSimilarityHead(similarity_function=similarity_function)
+        self.model = BiAdaptiveModel(
+            language_model1=self.query_encoder,
+            language_model2=self.passage_encoder,
+            prediction_heads=[prediction_head],
+            embeds_dropout_prob=0.1,
+            lm1_output_types=["per_sequence"],
+            lm2_output_types=["per_sequence"],
+            device=self.device,
+        )

         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
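
The tokenizers_default_classes context line at the top of this hunk belongs to the infer_tokenizer_classes handling a few lines above it. A sketch of that surrounding logic, reconstructed from this revision and made self-contained for illustration:

    # Reconstruction of the tokenizer-class defaults referenced by the hunk's
    # first context line; infer_tokenizer_classes is the __init__ argument.
    infer_tokenizer_classes = False  # constructor default
    tokenizers_default_classes = {
        "query": "DPRQuestionEncoderTokenizer",
        "passage": "DPRContextEncoderTokenizer",
    }
    if infer_tokenizer_classes:
        # Leave the class unset so Tokenizer.load() infers it from the model config.
        tokenizers_default_classes["query"] = None    # type: ignore
        tokenizers_default_classes["passage"] = None  # type: ignore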
@@ -426,7 +413,8 @@ class DensePassageRetriever(BaseRetriever):
              use_fast_tokenizers: bool = True,
              similarity_function: str = "dot_product",
              query_encoder_dir: str = "query_encoder",
-             passage_encoder_dir: str = "passage_encoder"
+             passage_encoder_dir: str = "passage_encoder",
+             infer_tokenizer_classes: bool = False
              ):
         """
         Load DensePassageRetriever from the specified directory.
@@ -443,7 +431,8 @@ class DensePassageRetriever(BaseRetriever):
             batch_size=batch_size,
             embed_title=embed_title,
             use_fast_tokenizers=use_fast_tokenizers,
-            similarity_function=similarity_function
+            similarity_function=similarity_function,
+            infer_tokenizer_classes=infer_tokenizer_classes
         )
         logger.info(f"DPR model loaded from {load_dir}")
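
With the new parameter, callers can forward tokenizer-class inference when restoring a saved retriever. A hedged example of calling the extended classmethod; the save directory is hypothetical and would contain the query_encoder/ and passage_encoder/ subdirectories written by save():

    # Restore a previously saved retriever; "saved_models/dpr" is a hypothetical path.
    from haystack.document_store import InMemoryDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    retriever = DensePassageRetriever.load(
        load_dir="saved_models/dpr",
        document_store=InMemoryDocumentStore(),
        infer_tokenizer_classes=True,  # new flag, forwarded to __init__ (default False)
    )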