fix: allow BiAdaptive & TriAdaptive models to work with EarlyStopping (#4033)

* fix: allow str when saving tri/bi-adaptive models

* fix: make trainer model loading class-agnostic

* test: add test for DPR with EarlyStopping

* refactor: simplify model reloading via classmethod

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
commit f006eded7d
parent a092eac2c7
Author: Jack Butler (committed by GitHub)
Date: 2023-02-03 10:13:18 +00:00
4 changed files with 46 additions and 20 deletions


@@ -92,19 +92,18 @@ class BiAdaptiveModel(nn.Module):
             loss_aggregation_fn = loss_per_head_sum
         self.loss_aggregation_fn = loss_aggregation_fn
 
-    def save(self, save_dir: Path, lm1_name: str = "lm1", lm2_name: str = "lm2"):
+    def save(self, save_dir: Union[str, Path], lm1_name: str = "lm1", lm2_name: str = "lm2"):
         """
         Saves the 2 language model weights and respective config_files in directories lm1 and lm2 within save_dir.
 
-        :param save_dir: Path to save the BiAdaptiveModel to.
+        :param save_dir: Path | str to save the BiAdaptiveModel to.
         """
         os.makedirs(save_dir, exist_ok=True)
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm1_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm1_name)))
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm2_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm2_name)))
-        self.language_model1.save(Path.joinpath(save_dir, Path(lm1_name)))
-        self.language_model2.save(Path.joinpath(save_dir, Path(lm2_name)))
+        for name, model in zip([lm1_name, lm2_name], [self.language_model1, self.language_model2]):
+            model_save_dir = Path.joinpath(Path(save_dir), Path(name))
+            os.makedirs(model_save_dir, exist_ok=True)
+            model.save(model_save_dir)
         for i, ph in enumerate(self.prediction_heads):
             logger.info("prediction_head saving")
             ph.save(save_dir, i)
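Why this hunk matters: EarlyStopping supplies save_dir as a plain str (see the new test below), and the old code passed it straight into Path.joinpath(), which requires a Path as its first argument, so checkpointing the best model crashed. A minimal repro sketch of the old failure and the fix, with a hypothetical path:

    from pathlib import Path

    save_dir = "saved_models/best"  # a plain str, as EarlyStopping supplies it
    try:
        Path.joinpath(save_dir, Path("lm1"))  # old code path: str lacks Path's joinpath machinery
    except AttributeError as err:
        print(f"old behaviour: {err!r}")
    print(Path(save_dir).joinpath("lm1"))  # the fix: normalize to Path first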


@@ -105,22 +105,20 @@ class TriAdaptiveModel(nn.Module):
             loss_aggregation_fn = loss_per_head_sum
         self.loss_aggregation_fn = loss_aggregation_fn
 
-    def save(self, save_dir: Path, lm1_name: str = "lm1", lm2_name: str = "lm2", lm3_name: str = "lm3"):
+    def save(self, save_dir: Union[str, Path], lm1_name: str = "lm1", lm2_name: str = "lm2", lm3_name: str = "lm3"):
         """
         Saves the 3 language model weights and respective config_files in directories lm1 and lm2 within save_dir.
 
-        :param save_dir: Path to save the TriAdaptiveModel to.
+        :param save_dir: Path | str to save the TriAdaptiveModel to.
         """
         os.makedirs(save_dir, exist_ok=True)
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm1_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm1_name)))
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm2_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm2_name)))
-        if not os.path.exists(Path.joinpath(save_dir, Path(lm3_name))):
-            os.makedirs(Path.joinpath(save_dir, Path(lm3_name)))
-        self.language_model1.save(Path.joinpath(save_dir, Path(lm1_name)))
-        self.language_model2.save(Path.joinpath(save_dir, Path(lm2_name)))
-        self.language_model3.save(Path.joinpath(save_dir, Path(lm3_name)))
+        for name, model in zip(
+            [lm1_name, lm2_name, lm3_name], [self.language_model1, self.language_model2, self.language_model3]
+        ):
+            model_save_dir = Path.joinpath(Path(save_dir), Path(name))
+            os.makedirs(model_save_dir, exist_ok=True)
+            model.save(model_save_dir)
         for i, ph in enumerate(self.prediction_heads):
             logger.info("prediction_head saving")
             ph.save(save_dir, i)
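For orientation, the on-disk layout save() now produces under save_dir (a sketch; lm1/lm2/lm3 are the default names from the signature above, and the prediction-head file naming is left open because it is up to each head's own save()):

    save_dir/
        lm1/    <- language_model1 weights and config
        lm2/    <- language_model2 weights and config
        lm3/    <- language_model3 weights and config
        ...     <- plus one saved file per prediction head, written by ph.save()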


@@ -281,7 +281,7 @@ class Trainer:
         # With early stopping we want to restore the best model
         if self.early_stopping and self.early_stopping.save_dir:
             logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir))
-            self.model = AdaptiveModel.load(self.early_stopping.save_dir, self.device)
+            self.model = self.model.load(self.early_stopping.save_dir, self.device)
             self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
 
         # Eval on test set
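Why the one-line Trainer change is enough: load() is a classmethod on the adaptive model classes (per the "simplify model reloading via classmethod" note in the commit message), so calling it through the trainer's model instance dispatches on that instance's concrete class instead of being pinned to AdaptiveModel. A simplified stand-in sketch, not the real Haystack classes:

    class AdaptiveModel:
        @classmethod
        def load(cls, load_dir, device):
            # the real implementations rebuild the model from files in load_dir
            print(f"loading {cls.__name__} from {load_dir} onto {device}")
            return cls()

    class BiAdaptiveModel(AdaptiveModel):
        pass

    model = BiAdaptiveModel()
    restored = model.load("best_checkpoint", "cpu")  # loads BiAdaptiveModel, not AdaptiveModel
    assert isinstance(restored, BiAdaptiveModel)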


@@ -20,6 +20,7 @@
 from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
+from haystack.utils.early_stopping import EarlyStopping
 
 from ..conftest import SAMPLES_PATH

@@ -1003,6 +1004,34 @@ def test_dpr_training(document_store, tmp_path):
     )
 
 
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training_with_earlystopping(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=1,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+        early_stopping=EarlyStopping(save_dir=save_dir),
+    )
+
+
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 #     batch_size = 1
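The new test exercises the whole loop end to end: train, evaluate every step, checkpoint the best model through EarlyStopping, and restore it via the class-agnostic load above. A minimal usage sketch (only save_dir is taken from the diff; any further EarlyStopping parameters such as patience or metric are assumptions about its wider signature):

    from haystack.utils.early_stopping import EarlyStopping

    # The best model seen during training is checkpointed into save_dir;
    # the Trainer reloads it from there once training finishes.
    early_stopping = EarlyStopping(save_dir="checkpoints/best_dpr")  # hypothetical path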