diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 81354ed3c..289bb7de1 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -35,8 +35,8 @@ class DensePassageRetriever(BaseRetriever):
 
     def __init__(self,
                  document_store: BaseDocumentStore,
-                 query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
-                 passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
+                 query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
+                 passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
                  max_seq_len_query: int = 64,
                  max_seq_len_passage: int = 256,
                  use_gpu: bool = True,
@@ -288,6 +288,41 @@ class DensePassageRetriever(BaseRetriever):
         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
         self.processor.save(Path(save_dir))
 
+    def save(self, save_dir: Union[Path, str]):
+        """Save the query/passage encoders and their tokenizers into *save_dir*."""
+        save_dir = Path(save_dir)
+        self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
+        self.query_tokenizer.save_pretrained(str(save_dir / "query_encoder"))
+        self.passage_tokenizer.save_pretrained(str(save_dir / "passage_encoder"))
+
+    @classmethod
+    def load(cls,
+             load_dir: Union[Path, str],
+             document_store: BaseDocumentStore,
+             max_seq_len_query: int = 64,
+             max_seq_len_passage: int = 256,
+             use_gpu: bool = True,
+             batch_size: int = 16,
+             embed_title: bool = True,
+             use_fast_tokenizers: bool = True,
+             similarity_function: str = "dot_product",
+             ):
+        """Load a DensePassageRetriever from a directory written by save()."""
+        load_dir = Path(load_dir)
+        dpr = cls(
+            document_store=document_store,
+            query_embedding_model=load_dir / "query_encoder",
+            passage_embedding_model=load_dir / "passage_encoder",
+            max_seq_len_query=max_seq_len_query,
+            max_seq_len_passage=max_seq_len_passage,
+            use_gpu=use_gpu,
+            batch_size=batch_size,
+            embed_title=embed_title,
+            use_fast_tokenizers=use_fast_tokenizers,
+            similarity_function=similarity_function
+        )
+        return dpr
+
 
 class EmbeddingRetriever(BaseRetriever):
 
     def __init__(
diff --git a/test/test_dpr_retriever.py b/test/test_dpr_retriever.py
index f1e892066..1cc7c07e9 100644
--- a/test/test_dpr_retriever.py
+++ b/test/test_dpr_retriever.py
@@ -1,9 +1,12 @@
 import pytest
 import time
+import numpy as np
 
 from haystack import Document
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
+from haystack.retriever.dense import DensePassageRetriever
+from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
 
 
 @pytest.mark.slow
 @pytest.mark.elasticsearch
@@ -60,3 +63,51 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
         assert res[0].embedding is not None
     else:
         assert res[0].embedding is None
+
+
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_saving_and_loading(retriever, document_store):
+    retriever.save("test_dpr_save")
+
+    def sum_params(model):
+        s = []
+        for p in model.parameters():
+            n = p.cpu().data.numpy()
+            s.append(np.sum(n))
+        return sum(s)
+
+    original_sum_query = sum_params(retriever.query_encoder)
+    original_sum_passage = sum_params(retriever.passage_encoder)
+    del retriever
+
+    loaded_retriever = DensePassageRetriever.load("test_dpr_save", document_store)
+
+    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
+    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
+
+    assert abs(original_sum_query - loaded_sum_query) < 0.1
+    assert abs(original_sum_passage - loaded_sum_passage) < 0.1
+
+    # comparison of weights (RAM intense!)
+    # for p1, p2 in zip(retriever.query_encoder.parameters(), loaded_retriever.query_encoder.parameters()):
+    #     assert (p1.data.ne(p2.data).sum() == 0)
+    #
+    # for p1, p2 in zip(retriever.passage_encoder.parameters(), loaded_retriever.passage_encoder.parameters()):
+    #     assert (p1.data.ne(p2.data).sum() == 0)
+
+    # attributes
+    assert loaded_retriever.embed_title
+    assert loaded_retriever.batch_size == 16
+    assert loaded_retriever.max_seq_len_passage == 256
+    assert loaded_retriever.max_seq_len_query == 64
+
+    # Tokenizer
+    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
+    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
+    assert loaded_retriever.passage_tokenizer.do_lower_case
+    assert loaded_retriever.query_tokenizer.do_lower_case
+    assert loaded_retriever.passage_tokenizer.vocab_size == 30522
+    assert loaded_retriever.query_tokenizer.vocab_size == 30522
+    assert loaded_retriever.passage_tokenizer.max_len == 512
+    assert loaded_retriever.query_tokenizer.max_len == 512