mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-23 16:08:19 +00:00
Add save() and load() methods for DPR (#550)

* Add save and load methods for DPR
* Lower memory footprint for the test; rename methods to load() and save()
* Add test cases

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent
46530e86f8
commit
53be92c155
@ -35,8 +35,8 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
document_store: BaseDocumentStore,
|
document_store: BaseDocumentStore,
|
||||||
query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
|
query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
|
||||||
passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
|
passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
|
||||||
max_seq_len_query: int = 64,
|
max_seq_len_query: int = 64,
|
||||||
max_seq_len_passage: int = 256,
|
max_seq_len_passage: int = 256,
|
||||||
use_gpu: bool = True,
|
use_gpu: bool = True,
|
||||||
@ -288,6 +288,42 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
|
self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
|
||||||
self.processor.save(Path(save_dir))
|
self.processor.save(Path(save_dir))
|
||||||
|
|
||||||
|
def save(self, save_dir: Union[Path, str]):
|
||||||
|
save_dir = Path(save_dir)
|
||||||
|
self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
|
||||||
|
save_dir = str(save_dir)
|
||||||
|
self.query_tokenizer.save_pretrained(save_dir + "/query_encoder")
|
||||||
|
self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls,
|
||||||
|
load_dir: Union[Path, str],
|
||||||
|
document_store: BaseDocumentStore,
|
||||||
|
max_seq_len_query: int = 64,
|
||||||
|
max_seq_len_passage: int = 256,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
batch_size: int = 16,
|
||||||
|
embed_title: bool = True,
|
||||||
|
use_fast_tokenizers: bool = True,
|
||||||
|
similarity_function: str = "dot_product",
|
||||||
|
):
|
||||||
|
|
||||||
|
load_dir = Path(load_dir)
|
||||||
|
dpr = cls(
|
||||||
|
document_store=document_store,
|
||||||
|
query_embedding_model=Path(load_dir) / "query_encoder",
|
||||||
|
passage_embedding_model=Path(load_dir) / "passage_encoder",
|
||||||
|
max_seq_len_query=max_seq_len_query,
|
||||||
|
max_seq_len_passage=max_seq_len_passage,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
batch_size=batch_size,
|
||||||
|
embed_title=embed_title,
|
||||||
|
use_fast_tokenizers=use_fast_tokenizers,
|
||||||
|
similarity_function=similarity_function
|
||||||
|
)
|
||||||
|
|
||||||
|
return dpr
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRetriever(BaseRetriever):
|
class EmbeddingRetriever(BaseRetriever):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import time
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from haystack import Document
|
from haystack import Document
|
||||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||||
|
from haystack.retriever.dense import DensePassageRetriever
|
||||||
|
|
||||||
|
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.elasticsearch
|
@pytest.mark.elasticsearch
|
||||||
@ -60,3 +63,50 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
|
|||||||
assert res[0].embedding is not None
|
assert res[0].embedding is not None
|
||||||
else:
|
else:
|
||||||
assert res[0].embedding is None
|
assert res[0].embedding is None
|
||||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_dpr_saving_and_loading(retriever, document_store):
    """Round-trip a DPR retriever through save()/load() and verify that the
    encoder weights, config attributes, and tokenizers survive intact."""
    retriever.save("test_dpr_save")

    def sum_params(model):
        # Cheap fingerprint of a model's weights: the sum of all parameters.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    # Drop the original retriever before loading a second copy to keep the
    # test's memory footprint low.
    del retriever

    loaded_retriever = DensePassageRetriever.load("test_dpr_save", document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)

    # Parameter sums must match up to float32 accumulation noise. An exact
    # element-wise comparison (p1.data.ne(p2.data).sum() == 0 over zipped
    # parameters) would be stricter but is very RAM-intensive.
    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1

    # Config attributes round-trip with the defaults used by load()
    assert loaded_retriever.embed_title
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.max_seq_len_passage == 256
    assert loaded_retriever.max_seq_len_query == 64

    # Tokenizers: correct fast classes and the expected BERT-base vocab/config
    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
    for tokenizer in (loaded_retriever.passage_tokenizer, loaded_retriever.query_tokenizer):
        assert tokenizer.do_lower_case
        assert tokenizer.vocab_size == 30522
        assert tokenizer.max_len == 512
|
Loading…
x
Reference in New Issue
Block a user