mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-03 11:19:57 +00:00
Add switch for BiAdaptive and TriAdaptiveModel in Evaluator (#2908)
* Add switch for BiAdaptive and Triadaptive Model * fix import * black * padding -> attention
This commit is contained in:
parent
b78db1cbaf
commit
284c759346
@ -8,6 +8,7 @@ from tqdm import tqdm
|
||||
|
||||
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
from haystack.modeling.visual import BUSH_SEP
|
||||
|
||||
@ -69,13 +70,26 @@ class Evaluator:
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
logits = model.forward(
|
||||
input_ids=batch.get("input_ids", None),
|
||||
segment_ids=batch.get("segment_ids", None),
|
||||
padding_mask=batch.get("padding_mask", None),
|
||||
output_hidden_states=batch.get("output_hidden_states", False),
|
||||
output_attentions=batch.get("output_attentions", False),
|
||||
)
|
||||
if isinstance(model, AdaptiveModel):
|
||||
logits = model.forward(
|
||||
input_ids=batch.get("input_ids", None),
|
||||
segment_ids=batch.get("segment_ids", None),
|
||||
padding_mask=batch.get("padding_mask", None),
|
||||
output_hidden_states=batch.get("output_hidden_states", False),
|
||||
output_attentions=batch.get("output_attentions", False),
|
||||
)
|
||||
elif isinstance(model, BiAdaptiveModel):
|
||||
logits = model.forward(
|
||||
query_input_ids=batch.get("query_input_ids", None),
|
||||
query_segment_ids=batch.get("query_segment_ids", None),
|
||||
query_attention_mask=batch.get("query_attention_mask", None),
|
||||
passage_input_ids=batch.get("passage_input_ids", None),
|
||||
passage_segment_ids=batch.get("passage_segment_ids", None),
|
||||
passage_attention_mask=batch.get("passage_attention_mask", None),
|
||||
)
|
||||
else:
|
||||
logits = model.forward(**batch)
|
||||
|
||||
losses_per_head = model.logits_to_loss_per_head(logits=logits, **batch)
|
||||
preds = model.logits_to_preds(logits=logits, **batch)
|
||||
labels = model.prepare_labels(**batch)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user