mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-23 07:58:36 +00:00
Add save and load method for DPR (#550)
* Add save and load method for DPR * lower memory footprint for test. change names to load() and save() * add test cases Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent
46530e86f8
commit
53be92c155
@ -35,8 +35,8 @@ class DensePassageRetriever(BaseRetriever):
|
||||
|
||||
def __init__(self,
|
||||
document_store: BaseDocumentStore,
|
||||
query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
|
||||
passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
|
||||
query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
|
||||
passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
|
||||
max_seq_len_query: int = 64,
|
||||
max_seq_len_passage: int = 256,
|
||||
use_gpu: bool = True,
|
||||
@ -288,6 +288,42 @@ class DensePassageRetriever(BaseRetriever):
|
||||
self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
|
||||
self.processor.save(Path(save_dir))
|
||||
|
||||
def save(self, save_dir: Union[Path, str]):
|
||||
save_dir = Path(save_dir)
|
||||
self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
|
||||
save_dir = str(save_dir)
|
||||
self.query_tokenizer.save_pretrained(save_dir + "/query_encoder")
|
||||
self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder")
|
||||
|
||||
@classmethod
|
||||
def load(cls,
|
||||
load_dir: Union[Path, str],
|
||||
document_store: BaseDocumentStore,
|
||||
max_seq_len_query: int = 64,
|
||||
max_seq_len_passage: int = 256,
|
||||
use_gpu: bool = True,
|
||||
batch_size: int = 16,
|
||||
embed_title: bool = True,
|
||||
use_fast_tokenizers: bool = True,
|
||||
similarity_function: str = "dot_product",
|
||||
):
|
||||
|
||||
load_dir = Path(load_dir)
|
||||
dpr = cls(
|
||||
document_store=document_store,
|
||||
query_embedding_model=Path(load_dir) / "query_encoder",
|
||||
passage_embedding_model=Path(load_dir) / "passage_encoder",
|
||||
max_seq_len_query=max_seq_len_query,
|
||||
max_seq_len_passage=max_seq_len_passage,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=batch_size,
|
||||
embed_title=embed_title,
|
||||
use_fast_tokenizers=use_fast_tokenizers,
|
||||
similarity_function=similarity_function
|
||||
)
|
||||
|
||||
return dpr
|
||||
|
||||
|
||||
class EmbeddingRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
|
@ -1,9 +1,12 @@
|
||||
import pytest
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
|
||||
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@ -60,3 +63,50 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
|
||||
assert res[0].embedding is not None
|
||||
else:
|
||||
assert res[0].embedding is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
def test_dpr_saving_and_loading(retriever, document_store):
|
||||
retriever.save("test_dpr_save")
|
||||
def sum_params(model):
|
||||
s = []
|
||||
for p in model.parameters():
|
||||
n = p.cpu().data.numpy()
|
||||
s.append(np.sum(n))
|
||||
return sum(s)
|
||||
original_sum_query = sum_params(retriever.query_encoder)
|
||||
original_sum_passage = sum_params(retriever.passage_encoder)
|
||||
del retriever
|
||||
|
||||
loaded_retriever = DensePassageRetriever.load("test_dpr_save", document_store)
|
||||
|
||||
loaded_sum_query = sum_params(loaded_retriever.query_encoder)
|
||||
loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
|
||||
|
||||
assert abs(original_sum_query - loaded_sum_query) < 0.1
|
||||
assert abs(original_sum_passage - loaded_sum_passage) < 0.1
|
||||
|
||||
# comparison of weights (RAM intense!)
|
||||
# for p1, p2 in zip(retriever.query_encoder.parameters(), loaded_retriever.query_encoder.parameters()):
|
||||
# assert (p1.data.ne(p2.data).sum() == 0)
|
||||
#
|
||||
# for p1, p2 in zip(retriever.passage_encoder.parameters(), loaded_retriever.passage_encoder.parameters()):
|
||||
# assert (p1.data.ne(p2.data).sum() == 0)
|
||||
|
||||
# attributes
|
||||
assert loaded_retriever.embed_title == True
|
||||
assert loaded_retriever.batch_size == 16
|
||||
assert loaded_retriever.max_seq_len_passage == 256
|
||||
assert loaded_retriever.max_seq_len_query == 64
|
||||
|
||||
# Tokenizer
|
||||
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
|
||||
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
|
||||
assert loaded_retriever.passage_tokenizer.do_lower_case == True
|
||||
assert loaded_retriever.query_tokenizer.do_lower_case == True
|
||||
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
|
||||
assert loaded_retriever.query_tokenizer.vocab_size == 30522
|
||||
assert loaded_retriever.passage_tokenizer.max_len == 512
|
||||
assert loaded_retriever.query_tokenizer.max_len == 512
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user