Add switch for BiAdaptive and TriAdaptiveModel in Evaluator (#2908)

* Add switch for BiAdaptive and Triadaptive Model

* fix import

* black

* padding -> attention
This commit is contained in:
Sara Zan 2022-07-29 11:31:52 +02:00 committed by GitHub
parent b78db1cbaf
commit 284c759346
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,6 +8,7 @@ from tqdm import tqdm
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
from haystack.utils.experiment_tracking import Tracker as tracker
from haystack.modeling.visual import BUSH_SEP
@ -69,13 +70,26 @@ class Evaluator:
with torch.no_grad():
logits = model.forward(
input_ids=batch.get("input_ids", None),
segment_ids=batch.get("segment_ids", None),
padding_mask=batch.get("padding_mask", None),
output_hidden_states=batch.get("output_hidden_states", False),
output_attentions=batch.get("output_attentions", False),
)
if isinstance(model, AdaptiveModel):
logits = model.forward(
input_ids=batch.get("input_ids", None),
segment_ids=batch.get("segment_ids", None),
padding_mask=batch.get("padding_mask", None),
output_hidden_states=batch.get("output_hidden_states", False),
output_attentions=batch.get("output_attentions", False),
)
elif isinstance(model, BiAdaptiveModel):
logits = model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
query_attention_mask=batch.get("query_attention_mask", None),
passage_input_ids=batch.get("passage_input_ids", None),
passage_segment_ids=batch.get("passage_segment_ids", None),
passage_attention_mask=batch.get("passage_attention_mask", None),
)
else:
logits = model.forward(**batch)
losses_per_head = model.logits_to_loss_per_head(logits=logits, **batch)
preds = model.logits_to_preds(logits=logits, **batch)
labels = model.prepare_labels(**batch)