mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 20:03:38 +00:00
fix: Ensure eval mode for farm and transformer models for predictions (#3791)
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
parent
97d5db3b9c
commit
1777b22fcb
@ -290,6 +290,7 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
self.test_result = evaluator_test.eval(self.model)
|
self.test_result = evaluator_test.eval(self.model)
|
||||||
evaluator_test.log_results(self.test_result, "Test", self.global_step)
|
evaluator_test.log_results(self.test_result, "Test", self.global_step)
|
||||||
|
self.model.eval()
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
|
def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
|
||||||
|
@ -431,6 +431,36 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader):
|
|||||||
assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual
|
assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_reader_training_returns_eval(tmp_path, samples_path):
|
||||||
|
max_seq_len = 16
|
||||||
|
max_query_length = 8
|
||||||
|
reader = FARMReader(
|
||||||
|
model_name_or_path="deepset/tinyroberta-squad2",
|
||||||
|
use_gpu=False,
|
||||||
|
num_processes=0,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
doc_stride=2,
|
||||||
|
max_query_length=max_query_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_dir = f"{tmp_path}/test_dpr_training"
|
||||||
|
reader.train(
|
||||||
|
data_dir=str(samples_path / "squad"),
|
||||||
|
train_filename="tiny.json",
|
||||||
|
dev_filename="tiny.json",
|
||||||
|
n_epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
grad_acc_steps=1,
|
||||||
|
evaluate_every=0,
|
||||||
|
save_dir=save_dir,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_query_length=max_query_length,
|
||||||
|
)
|
||||||
|
assert reader.inferencer.model.training is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
def test_reader_training(tmp_path, samples_path):
|
def test_reader_training(tmp_path, samples_path):
|
||||||
max_seq_len = 16
|
max_seq_len = 16
|
||||||
max_query_length = 8
|
max_query_length = 8
|
||||||
|
Loading…
x
Reference in New Issue
Block a user