feat: Add inference mode to ExtractiveReader (#7699)

* Add inference mode to ExtractiveReader

* Add release notes
Sebastian Husch Lee 2024-05-15 21:33:57 +02:00 committed by GitHub
parent c8d53b3ebf
commit af53e8430d
2 changed files with 6 additions and 1 deletion


@@ -602,7 +602,8 @@ class ExtractiveReader:
             cur_input_ids = input_ids[start_index:end_index]
             cur_attention_mask = attention_mask[start_index:end_index]
-            output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask)
+            with torch.inference_mode():
+                output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask)
             cur_start_logits = output.start_logits
             cur_end_logits = output.end_logits
             if num_batches != 1:
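
The change above wraps the reader's forward pass in torch.inference_mode(). As a minimal standalone sketch of the same pattern (not part of the commit; the checkpoint name and inputs below are illustrative, assuming a Hugging Face extractive QA model):

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Illustrative checkpoint; the ExtractiveReader loads whatever model it was configured with.
model_name = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

inputs = tokenizer("Who wrote the letter?", "The letter was written by Jane.", return_tensors="pt")

# Wrapping the forward pass in inference_mode() disables autograd tracking,
# so no computation graph is built for the start/end logits.
with torch.inference_mode():
    output = model(**inputs)

print(output.start_logits.requires_grad)  # False: the logits carry no gradient history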


@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Adds inference mode to the model call of the ExtractiveReader. This prevents gradients from being calculated during inference in PyTorch.
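
For context on the release note above, a small standalone sketch (not part of the commit) of what torch.inference_mode() changes: tensors produced inside the context record no autograd history, so PyTorch skips building the computation graph and does not retain intermediate activations.

import torch

linear = torch.nn.Linear(3, 1)
x = torch.randn(2, 3, requires_grad=True)

# Outside inference_mode, the output participates in autograd.
y_train = linear(x)
print(y_train.requires_grad)  # True

# Inside inference_mode, no graph is recorded and the result is an inference tensor.
with torch.inference_mode():
    y_infer = linear(x)

print(y_infer.requires_grad)        # False
print(torch.is_inference(y_infer))  # True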