fix: Ensure eval mode for farm and transformer models for predictions (#3791)

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Sebastian 2023-06-06 13:06:30 +02:00 committed by GitHub
parent 97d5db3b9c
commit 1777b22fcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 0 deletions

View File

@ -290,6 +290,7 @@ class Trainer:
)
self.test_result = evaluator_test.eval(self.model)
evaluator_test.log_results(self.test_result, "Test", self.global_step)
self.model.eval()
return self.model
def compute_loss(self, batch: dict, step: int) -> torch.Tensor:

View File

@ -431,6 +431,36 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader):
assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual
@pytest.mark.integration
def test_reader_training_returns_eval(tmp_path, samples_path):
max_seq_len = 16
max_query_length = 8
reader = FARMReader(
model_name_or_path="deepset/tinyroberta-squad2",
use_gpu=False,
num_processes=0,
max_seq_len=max_seq_len,
doc_stride=2,
max_query_length=max_query_length,
)
save_dir = f"{tmp_path}/test_dpr_training"
reader.train(
data_dir=str(samples_path / "squad"),
train_filename="tiny.json",
dev_filename="tiny.json",
n_epochs=1,
batch_size=1,
grad_acc_steps=1,
evaluate_every=0,
save_dir=save_dir,
max_seq_len=max_seq_len,
max_query_length=max_query_length,
)
assert reader.inferencer.model.training is False
@pytest.mark.integration
def test_reader_training(tmp_path, samples_path):
max_seq_len = 16
max_query_length = 8