Add save and load method for DPR (#550)

* Add save and load method for DPR

* Lower memory footprint for test; change names to load() and save()

* add test cases

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
bogdankostic authored 2020-11-05 13:29:23 +01:00, committed by GitHub
parent 46530e86f8
commit 53be92c155
2 changed files with 88 additions and 2 deletions

haystack/retriever/dense.py

@@ -35,8 +35,8 @@ class DensePassageRetriever(BaseRetriever):
     def __init__(self,
                  document_store: BaseDocumentStore,
-                 query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
-                 passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
+                 query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
+                 passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
                  max_seq_len_query: int = 64,
                  max_seq_len_passage: int = 256,
                  use_gpu: bool = True,
@@ -288,6 +288,42 @@ class DensePassageRetriever(BaseRetriever):
         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
         self.processor.save(Path(save_dir))
 
+    def save(self, save_dir: Union[Path, str]):
+        save_dir = Path(save_dir)
+        self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
+        save_dir = str(save_dir)
+        self.query_tokenizer.save_pretrained(save_dir + "/query_encoder")
+        self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder")
+
+    @classmethod
+    def load(cls,
+             load_dir: Union[Path, str],
+             document_store: BaseDocumentStore,
+             max_seq_len_query: int = 64,
+             max_seq_len_passage: int = 256,
+             use_gpu: bool = True,
+             batch_size: int = 16,
+             embed_title: bool = True,
+             use_fast_tokenizers: bool = True,
+             similarity_function: str = "dot_product",
+             ):
+        load_dir = Path(load_dir)
+        dpr = cls(
+            document_store=document_store,
+            query_embedding_model=Path(load_dir) / "query_encoder",
+            passage_embedding_model=Path(load_dir) / "passage_encoder",
+            max_seq_len_query=max_seq_len_query,
+            max_seq_len_passage=max_seq_len_passage,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            embed_title=embed_title,
+            use_fast_tokenizers=use_fast_tokenizers,
+            similarity_function=similarity_function
+        )
+        return dpr
+
 
 class EmbeddingRetriever(BaseRetriever):
     def __init__(
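
For orientation, a minimal usage sketch of the save/load round trip added above. This is a sketch, not part of the commit: the InMemoryDocumentStore import path is assumed from the document_store package layout visible in the test imports below, and the target directory name is illustrative.

    from haystack.document_store.memory import InMemoryDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    document_store = InMemoryDocumentStore()  # any BaseDocumentStore works

    # Build a retriever from the default pretrained DPR encoders.
    retriever = DensePassageRetriever(document_store=document_store)

    # save() writes both encoders plus their tokenizers into
    # query_encoder/ and passage_encoder/ subdirectories of save_dir.
    retriever.save("saved_dpr")

    # load() rebuilds the retriever from exactly that directory layout.
    retriever = DensePassageRetriever.load("saved_dpr", document_store=document_store)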


@@ -1,9 +1,12 @@
 import pytest
 import time
+import numpy as np
 
 from haystack import Document
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 from haystack.retriever.dense import DensePassageRetriever
+from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
+
 
 @pytest.mark.slow
 @pytest.mark.elasticsearch
@@ -60,3 +63,50 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
         assert res[0].embedding is not None
     else:
         assert res[0].embedding is None
+
+
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_saving_and_loading(retriever, document_store):
+    retriever.save("test_dpr_save")
+
+    def sum_params(model):
+        s = []
+        for p in model.parameters():
+            n = p.cpu().data.numpy()
+            s.append(np.sum(n))
+        return sum(s)
+
+    original_sum_query = sum_params(retriever.query_encoder)
+    original_sum_passage = sum_params(retriever.passage_encoder)
+    del retriever
+
+    loaded_retriever = DensePassageRetriever.load("test_dpr_save", document_store)
+
+    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
+    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
+
+    assert abs(original_sum_query - loaded_sum_query) < 0.1
+    assert abs(original_sum_passage - loaded_sum_passage) < 0.1
+
+    # comparison of weights (RAM intensive!)
+    # for p1, p2 in zip(retriever.query_encoder.parameters(), loaded_retriever.query_encoder.parameters()):
+    #     assert (p1.data.ne(p2.data).sum() == 0)
+    #
+    # for p1, p2 in zip(retriever.passage_encoder.parameters(), loaded_retriever.passage_encoder.parameters()):
+    #     assert (p1.data.ne(p2.data).sum() == 0)
+
+    # attributes
+    assert loaded_retriever.embed_title == True
+    assert loaded_retriever.batch_size == 16
+    assert loaded_retriever.max_seq_len_passage == 256
+    assert loaded_retriever.max_seq_len_query == 64
+
+    # Tokenizer
+    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
+    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
+    assert loaded_retriever.passage_tokenizer.do_lower_case == True
+    assert loaded_retriever.query_tokenizer.do_lower_case == True
+    assert loaded_retriever.passage_tokenizer.vocab_size == 30522
+    assert loaded_retriever.query_tokenizer.vocab_size == 30522
+    assert loaded_retriever.passage_tokenizer.max_len == 512
+    assert loaded_retriever.query_tokenizer.max_len == 512
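
To run just this new test locally, something along these lines should work (a sketch: the exact test module path is not shown in the diff, so selecting by name with -k avoids guessing it; the "memory" document store parametrization keeps Elasticsearch out of the loop):

    pytest -k test_dpr_saving_and_loading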