Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-04 05:43:29 +00:00)
Fix saving tokenizers in DPR training + unify save and load dirs (#682)
commit a9bcabc42d
parent 4c2804e38e
@@ -240,9 +240,9 @@ class DensePassageRetriever(BaseRetriever):
               grad_acc_steps: int = 1,
               optimizer_name: str = "TransformersAdamW",
               optimizer_correct_bias: bool = True,
-              save_dir: str = "../saved_models/dpr-tutorial",
-              query_encoder_save_dir: str = "lm1",
-              passage_encoder_save_dir: str = "lm2"
+              save_dir: str = "../saved_models/dpr",
+              query_encoder_save_dir: str = "query_encoder",
+              passage_encoder_save_dir: str = "passage_encoder"
               ):
         """
         train a DensePassageRetrieval model
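With the renamed defaults, train() now writes both encoders under directory names that save() and load() agree on. A minimal sketch of a training call under the new defaults (the data_dir and train_filename values are illustrative; any parameter not shown keeps its default):

    retriever.train(
        data_dir="data/retriever",                    # illustrative path
        train_filename="train.json",                  # illustrative file name
        save_dir="../saved_models/dpr",               # new default
        query_encoder_save_dir="query_encoder",       # was "lm1"
        passage_encoder_save_dir="passage_encoder",   # was "lm2"
    )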
@@ -317,20 +317,24 @@ class DensePassageRetriever(BaseRetriever):
         trainer.train()

         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
-        self.processor.save(Path(save_dir))
+        self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
+        self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")

-    def save(self, save_dir: Union[Path, str]):
+    def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
+             passage_encoder_dir: str = "passage_encoder"):
         """
         Save DensePassageRetriever to the specified directory.

         :param save_dir: Directory to save to.
+        :param query_encoder_dir: Directory in save_dir that contains query encoder model.
+        :param passage_encoder_dir: Directory in save_dir that contains passage encoder model.
         :return: None
         """
         save_dir = Path(save_dir)
-        self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
+        self.model.save(save_dir, lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir)
         save_dir = str(save_dir)
-        self.query_tokenizer.save_pretrained(save_dir + "/query_encoder")
-        self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder")
+        self.query_tokenizer.save_pretrained(save_dir + f"/{query_encoder_dir}")
+        self.passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}")

     @classmethod
     def load(cls,
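train() now persists the two tokenizers alongside their encoders via save_pretrained() instead of relying on self.processor.save(), and save() exposes the subdirectory names as parameters. A sketch of an explicit save, assuming retriever is an already trained DensePassageRetriever:

    retriever.save(
        save_dir="../saved_models/dpr",
        query_encoder_dir="query_encoder",      # defaults; any names work as long
        passage_encoder_dir="passage_encoder",  # as load() is given the same ones
    )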
@@ -343,6 +347,8 @@ class DensePassageRetriever(BaseRetriever):
              embed_title: bool = True,
              use_fast_tokenizers: bool = True,
              similarity_function: str = "dot_product",
+             query_encoder_dir: str = "query_encoder",
+             passage_encoder_dir: str = "passage_encoder"
              ):
         """
         Load DensePassageRetriever from the specified directory.
@@ -351,8 +357,8 @@ class DensePassageRetriever(BaseRetriever):
         load_dir = Path(load_dir)
         dpr = cls(
             document_store=document_store,
-            query_embedding_model=Path(load_dir) / "query_encoder",
-            passage_embedding_model=Path(load_dir) / "passage_encoder",
+            query_embedding_model=Path(load_dir) / query_encoder_dir,
+            passage_embedding_model=Path(load_dir) / passage_encoder_dir,
             max_seq_len_query=max_seq_len_query,
             max_seq_len_passage=max_seq_len_passage,
             use_gpu=use_gpu,
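The matching load call; the encoder directory names only have to agree with whatever was used at save time. A sketch, assuming an existing document_store:

    retriever = DensePassageRetriever.load(
        load_dir="../saved_models/dpr",
        document_store=document_store,
        query_encoder_dir="query_encoder",
        passage_encoder_dir="passage_encoder",
    )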