diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 573aa5303..62e1ce002 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -602,7 +602,8 @@ class ExtractiveReader: cur_input_ids = input_ids[start_index:end_index] cur_attention_mask = attention_mask[start_index:end_index] - output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask) + with torch.inference_mode(): + output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask) cur_start_logits = output.start_logits cur_end_logits = output.end_logits if num_batches != 1: diff --git a/releasenotes/notes/add-inf-mode-reader-e6eb79920e73c956.yaml b/releasenotes/notes/add-inf-mode-reader-e6eb79920e73c956.yaml new file mode 100644 index 000000000..4dbe549f1 --- /dev/null +++ b/releasenotes/notes/add-inf-mode-reader-e6eb79920e73c956.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Adds inference mode to the model call of the ExtractiveReader. This prevents gradients from being calculated during inference in PyTorch.