mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-10 06:43:58 +00:00
Add switch for BiAdaptive and TriAdaptiveModel in Evaluator (#2908)
* Add switch for BiAdaptive and Triadaptive Model * fix import * black * padding -> attention
This commit is contained in:
parent
b78db1cbaf
commit
284c759346
@ -8,6 +8,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
|
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
|
||||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||||
|
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
|
||||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||||
from haystack.modeling.visual import BUSH_SEP
|
from haystack.modeling.visual import BUSH_SEP
|
||||||
|
|
||||||
@ -69,13 +70,26 @@ class Evaluator:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
logits = model.forward(
|
if isinstance(model, AdaptiveModel):
|
||||||
input_ids=batch.get("input_ids", None),
|
logits = model.forward(
|
||||||
segment_ids=batch.get("segment_ids", None),
|
input_ids=batch.get("input_ids", None),
|
||||||
padding_mask=batch.get("padding_mask", None),
|
segment_ids=batch.get("segment_ids", None),
|
||||||
output_hidden_states=batch.get("output_hidden_states", False),
|
padding_mask=batch.get("padding_mask", None),
|
||||||
output_attentions=batch.get("output_attentions", False),
|
output_hidden_states=batch.get("output_hidden_states", False),
|
||||||
)
|
output_attentions=batch.get("output_attentions", False),
|
||||||
|
)
|
||||||
|
elif isinstance(model, BiAdaptiveModel):
|
||||||
|
logits = model.forward(
|
||||||
|
query_input_ids=batch.get("query_input_ids", None),
|
||||||
|
query_segment_ids=batch.get("query_segment_ids", None),
|
||||||
|
query_attention_mask=batch.get("query_attention_mask", None),
|
||||||
|
passage_input_ids=batch.get("passage_input_ids", None),
|
||||||
|
passage_segment_ids=batch.get("passage_segment_ids", None),
|
||||||
|
passage_attention_mask=batch.get("passage_attention_mask", None),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = model.forward(**batch)
|
||||||
|
|
||||||
losses_per_head = model.logits_to_loss_per_head(logits=logits, **batch)
|
losses_per_head = model.logits_to_loss_per_head(logits=logits, **batch)
|
||||||
preds = model.logits_to_preds(logits=logits, **batch)
|
preds = model.logits_to_preds(logits=logits, **batch)
|
||||||
labels = model.prepare_labels(**batch)
|
labels = model.prepare_labels(**batch)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user