From 284c759346dd6adcdb8d0de3f59b745b04d2ac37 Mon Sep 17 00:00:00 2001 From: Sara Zan Date: Fri, 29 Jul 2022 11:31:52 +0200 Subject: [PATCH] Add switch for `BiAdaptive` and `TriAdaptiveModel` in `Evaluator` (#2908) * Add switch for BiAdaptive and Triadaptive Model * fix import * black * padding -> attention --- haystack/modeling/evaluation/eval.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/haystack/modeling/evaluation/eval.py b/haystack/modeling/evaluation/eval.py index d73d77213..028227462 100644 --- a/haystack/modeling/evaluation/eval.py +++ b/haystack/modeling/evaluation/eval.py @@ -8,6 +8,7 @@ from tqdm import tqdm from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics from haystack.modeling.model.adaptive_model import AdaptiveModel +from haystack.modeling.model.biadaptive_model import BiAdaptiveModel from haystack.utils.experiment_tracking import Tracker as tracker from haystack.modeling.visual import BUSH_SEP @@ -69,13 +70,26 @@ class Evaluator: with torch.no_grad(): - logits = model.forward( - input_ids=batch.get("input_ids", None), - segment_ids=batch.get("segment_ids", None), - padding_mask=batch.get("padding_mask", None), - output_hidden_states=batch.get("output_hidden_states", False), - output_attentions=batch.get("output_attentions", False), - ) + if isinstance(model, AdaptiveModel): + logits = model.forward( + input_ids=batch.get("input_ids", None), + segment_ids=batch.get("segment_ids", None), + padding_mask=batch.get("padding_mask", None), + output_hidden_states=batch.get("output_hidden_states", False), + output_attentions=batch.get("output_attentions", False), + ) + elif isinstance(model, BiAdaptiveModel): + logits = model.forward( + query_input_ids=batch.get("query_input_ids", None), + query_segment_ids=batch.get("query_segment_ids", None), + query_attention_mask=batch.get("query_attention_mask", None), + passage_input_ids=batch.get("passage_input_ids", None), + passage_segment_ids=batch.get("passage_segment_ids", None), + passage_attention_mask=batch.get("passage_attention_mask", None), + ) + else: + logits = model.forward(**batch) + losses_per_head = model.logits_to_loss_per_head(logits=logits, **batch) preds = model.logits_to_preds(logits=logits, **batch) labels = model.prepare_labels(**batch)