Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-27 06:58:35 +00:00)
fix: allow BiAdaptive & TriAdaptive to work with EarlyStopping (#4033)
* fix: allow str when saving tri/bi-adaptive models
* fix: make trainer model loading class-agnostic
* test: add test for DPR with EarlyStopping
* refactor: simplify model reloading via classmethod

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
parent a092eac2c7
commit f006eded7d
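In practice, the fix lets EarlyStopping hand its (string) save_dir to the DPR models during DensePassageRetriever.train(). A minimal usage sketch, not part of the commit, with placeholder paths/filenames and an already constructed `retriever` (a DensePassageRetriever):

from haystack.utils.early_stopping import EarlyStopping

# A plain str works after this fix; previously BiAdaptiveModel.save() assumed a
# pathlib.Path and failed when the directory came through as a string.
save_dir = "saved_models/dpr_with_early_stopping"  # placeholder path

retriever.train(                      # `retriever` is an existing DensePassageRetriever (assumed)
    data_dir="data/dpr",              # placeholder
    train_filename="train.json",      # placeholder
    dev_filename="dev.json",          # placeholder
    n_epochs=1,
    evaluate_every=1,
    save_dir=save_dir,
    early_stopping=EarlyStopping(save_dir=save_dir),
)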
@@ -92,19 +92,18 @@ class BiAdaptiveModel(nn.Module):
             loss_aggregation_fn = loss_per_head_sum
         self.loss_aggregation_fn = loss_aggregation_fn
 
-    def save(self, save_dir: Path, lm1_name: str = "lm1", lm2_name: str = "lm2"):
+    def save(self, save_dir: Union[str, Path], lm1_name: str = "lm1", lm2_name: str = "lm2"):
         """
         Saves the 2 language model weights and respective config_files in directories lm1 and lm2 within save_dir.
 
-        :param save_dir: Path to save the BiAdaptiveModel to.
+        :param save_dir: Path | str to save the BiAdaptiveModel to.
         """
         os.makedirs(save_dir, exist_ok=True)
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm1_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm1_name)))
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm2_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm2_name)))
-        self.language_model1.save(Path.joinpath(save_dir, Path(lm1_name)))
-        self.language_model2.save(Path.joinpath(save_dir, Path(lm2_name)))
+        for name, model in zip([lm1_name, lm2_name], [self.language_model1, self.language_model2]):
+            model_save_dir = Path.joinpath(Path(save_dir), Path(name))
+            os.makedirs(model_save_dir, exist_ok=True)
+            model.save(model_save_dir)
+
         for i, ph in enumerate(self.prediction_heads):
             logger.info("prediction_head saving")
             ph.save(save_dir, i)
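The change above matters because the old code called Path.joinpath on save_dir directly, which only works when save_dir is already a pathlib.Path. A small standalone sketch, not part of the commit, of the difference (paths are placeholders):

from pathlib import Path

save_dir = "saved_models/dpr"  # a plain str, as it arrives from early stopping

# Old pattern: Path.joinpath(save_dir, ...) treats the str as `self` and raises
# AttributeError, because a plain str has none of Path's methods.
# Path.joinpath(save_dir, Path("lm1"))

# New pattern: wrap save_dir in Path() first, so both str and Path work.
lm1_dir = Path.joinpath(Path(save_dir), Path("lm1"))
print(lm1_dir)  # saved_models/dpr/lm1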
@@ -105,22 +105,20 @@ class TriAdaptiveModel(nn.Module):
             loss_aggregation_fn = loss_per_head_sum
         self.loss_aggregation_fn = loss_aggregation_fn
 
-    def save(self, save_dir: Path, lm1_name: str = "lm1", lm2_name: str = "lm2", lm3_name: str = "lm3"):
+    def save(self, save_dir: Union[str, Path], lm1_name: str = "lm1", lm2_name: str = "lm2", lm3_name: str = "lm3"):
         """
         Saves the 3 language model weights and respective config_files in directories lm1 and lm2 within save_dir.
 
-        :param save_dir: Path to save the TriAdaptiveModel to.
+        :param save_dir: Path | str to save the TriAdaptiveModel to.
         """
         os.makedirs(save_dir, exist_ok=True)
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm1_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm1_name)))
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm2_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm2_name)))
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm3_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm3_name)))
-        self.language_model1.save(Path.joinpath(save_dir, Path(lm1_name)))
-        self.language_model2.save(Path.joinpath(save_dir, Path(lm2_name)))
-        self.language_model3.save(Path.joinpath(save_dir, Path(lm3_name)))
+        for name, model in zip(
+            [lm1_name, lm2_name, lm3_name], [self.language_model1, self.language_model2, self.language_model3]
+        ):
+            model_save_dir = Path.joinpath(Path(save_dir), Path(name))
+            os.makedirs(model_save_dir, exist_ok=True)
+            model.save(model_save_dir)
+
         for i, ph in enumerate(self.prediction_heads):
             logger.info("prediction_head saving")
             ph.save(save_dir, i)
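Both save() methods also drop the explicit os.path.exists checks in favour of os.makedirs(..., exist_ok=True). A small standalone illustration, not part of the commit, of why the single call is enough (the path is a placeholder):

import os
from pathlib import Path

target = Path("saved_models/dpr") / "lm3"  # placeholder path

# Old style: check, then create; more verbose and racy if two writers run at once.
if not os.path.exists(target):
    os.makedirs(target)

# New style: one idempotent call, safe to repeat.
os.makedirs(target, exist_ok=True)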
@@ -281,7 +281,7 @@ class Trainer:
         # With early stopping we want to restore the best model
         if self.early_stopping and self.early_stopping.save_dir:
             logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir))
-            self.model = AdaptiveModel.load(self.early_stopping.save_dir, self.device)
+            self.model = self.model.load(self.early_stopping.save_dir, self.device)
             self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
 
         # Eval on test set
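The Trainer change replaces the hard-coded AdaptiveModel.load() with a call through the model instance, so the classmethod dispatches on the actual model class. A minimal sketch with stubbed classes (not Haystack's real signatures) of why that restores the right subclass after early stopping:

class AdaptiveModel:
    @classmethod
    def load(cls, load_dir, device):
        # `cls` is whichever class the call was dispatched through
        print(f"loading a {cls.__name__} from {load_dir}")
        return cls()

class BiAdaptiveModel(AdaptiveModel):
    pass

model = BiAdaptiveModel()

AdaptiveModel.load("best_model", device="cpu")  # old behaviour: always rebuilds an AdaptiveModel
model.load("best_model", device="cpu")          # new behaviour: rebuilds a BiAdaptiveModel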
@@ -20,6 +20,7 @@ from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.nodes.retriever.dense import DensePassageRetriever
 
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
+from haystack.utils.early_stopping import EarlyStopping
 
 from ..conftest import SAMPLES_PATH
 
@@ -1003,6 +1004,34 @@ def test_dpr_training(document_store, tmp_path):
     )
 
 
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training_with_earlystopping(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=1,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+        early_stopping=EarlyStopping(save_dir=save_dir),
+    )
+
+
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 #     batch_size = 1