diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py
index 5d893cd00..f8eb1ab5a 100644
--- a/haystack/document_stores/memory.py
+++ b/haystack/document_stores/memory.py
@@ -305,7 +305,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         while curr_pos < len(doc_embeds):
             doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
             doc_embeds_slice = doc_embeds_slice.to(self.main_device)
-            with torch.no_grad():
+            with torch.inference_mode():
                 slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()
                 slice_scores = slice_scores.squeeze(dim=1)
                 slice_scores = slice_scores.numpy().tolist()
diff --git a/haystack/modeling/data_handler/data_silo.py b/haystack/modeling/data_handler/data_silo.py
index cdb50a8ce..14b0a86de 100644
--- a/haystack/modeling/data_handler/data_silo.py
+++ b/haystack/modeling/data_handler/data_silo.py
@@ -771,7 +771,7 @@ class DistillationDataSilo(DataSilo):
         teacher_outputs: List[List[Tuple[torch.Tensor, ...]]],
         tensor_names: List[str],
     ):
-        with torch.no_grad():
+        with torch.inference_mode():
             batch_transposed = zip(*batch)  # transpose dimensions (from batch, features, ... to features, batch, ...)
             batch_transposed_list = [torch.stack(b) for b in batch_transposed]  # create tensors for each feature
             batch_dict = {
diff --git a/haystack/modeling/evaluation/eval.py b/haystack/modeling/evaluation/eval.py
index b12593dd6..f09bb88e8 100644
--- a/haystack/modeling/evaluation/eval.py
+++ b/haystack/modeling/evaluation/eval.py
@@ -77,7 +77,7 @@ class Evaluator:
             else:
                 module = model
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py
index dfc44aa8d..4abe42a4c 100644
--- a/haystack/modeling/infer.py
+++ b/haystack/modeling/infer.py
@@ -365,7 +365,7 @@ class Inferencer:
             batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]
             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 logits = self.model.forward(**batch)
                 preds = self.model.formatted_preds(
                     logits=logits, samples=batch_samples, padding_mask=batch.get("padding_mask", None)
                 )
@@ -402,7 +402,7 @@ class Inferencer:
             batch = {key: batch[key].to(self.devices[0]) for key in batch}
 
             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 # Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
                 # So we transform logits to preds here as well
                 logits = self.model.forward(
diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py
index f6a962f71..3e564a8ae 100644
--- a/haystack/modeling/model/adaptive_model.py
+++ b/haystack/modeling/model/adaptive_model.py
@@ -778,7 +778,7 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
         :param kwargs: All arguments that need to be passed on to the model.
         :return: All logits as torch.tensor or multiple tensors.
         """
-        with torch.no_grad():
+        with torch.inference_mode():
             if self.language_model_class == "Bert":
                 input_to_onnx = {
                     "input_ids": numpy.ascontiguousarray(kwargs["input_ids"].cpu().numpy()),
diff --git a/haystack/nodes/ranker/sentence_transformers.py b/haystack/nodes/ranker/sentence_transformers.py
index f86cd93b5..bc3402675 100644
--- a/haystack/nodes/ranker/sentence_transformers.py
+++ b/haystack/nodes/ranker/sentence_transformers.py
@@ -129,7 +129,7 @@ class SentenceTransformersRanker(BaseRanker):
         # 1. the logit as similarity score/answerable classification
         # 2. the logits as answerable classification (no_answer / has_answer)
         # https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers
-        with torch.no_grad():
+        with torch.inference_mode():
             similarity_scores = self.transformer_model(**features).logits
 
         logits_dim = similarity_scores.shape[1]  # [batch_size, logits_dim]
@@ -216,7 +216,7 @@ class SentenceTransformersRanker(BaseRanker):
                 cur_queries, [doc.content for doc in cur_docs], padding=True, truncation=True, return_tensors="pt"
             ).to(self.devices[0])
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 similarity_scores = self.transformer_model(**features).logits
                 preds.extend(similarity_scores)
             pb.update(len(cur_docs))
diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 9e1bfe1d4..4ed663c4a 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -304,7 +304,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
 
         for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
             batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-            with torch.no_grad():
+            with torch.inference_mode():
                 q_reps = (
                     self.embedding_model.embed_questions(
                         input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -331,7 +331,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
 
         for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
             batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-            with torch.no_grad():
+            with torch.inference_mode():
                 q_reps = (
                     self.embedding_model.embed_answers(
                         input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index d5ca6ef3e..c89834668 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -528,7 +528,7 @@ class DensePassageRetriever(DenseRetriever):
                 batch = {key: raw_batch[key].to(self.devices[0]) for key in raw_batch}
 
                 # get logits
-                with torch.no_grad():
+                with torch.inference_mode():
                     query_embeddings, passage_embeddings = self.model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
@@ -1171,7 +1171,7 @@ class TableTextRetriever(DenseRetriever):
                 batch = {key: batch[key].to(self.devices[0]) for key in batch}
 
                 # get logits
-                with torch.no_grad():
+                with torch.inference_mode():
                     query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
                     if query_embeddings is not None:
                         query_embeddings_batched.append(query_embeddings.cpu().numpy())
diff --git a/haystack/utils/augment_squad.py b/haystack/utils/augment_squad.py
index 7635b23df..d8bbf81a9 100644
--- a/haystack/utils/augment_squad.py
+++ b/haystack/utils/augment_squad.py
@@ -77,7 +77,7 @@ def load_glove(
             id_word_mapping[i] = split[0]
             vector_list.append(torch.tensor([float(x) for x in split[1:]]))
     vectors = torch.stack(vector_list)
-    with torch.no_grad():
+    with torch.inference_mode():
         vectors = vectors.to(device)
         vectors = F.normalize(vectors, dim=1)
     return word_id_mapping, id_word_mapping, vectors
@@ -132,7 +132,7 @@ def get_replacements(
         inputs.append((input_ids_, subword_index))
 
     # doing batched forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         prediction_list = []
         while len(inputs) != 0:
             batch_list, token_indices = tuple(zip(*inputs[:batch_size]))
@@ -165,7 +165,7 @@ def get_replacements(
         elif word in glove_word_id_mapping:  # word was split into subwords so we use glove instead
             word_id = glove_word_id_mapping[word]
             glove_vector = glove_vectors[word_id]
-            with torch.no_grad():
+            with torch.inference_mode():
                 word_similarities = torch.mm(glove_vectors, glove_vector.unsqueeze(1)).squeeze(1)
                 ranking = torch.argsort(word_similarities, descending=True)[: word_possibilities + 1]
                 possible_words.append([glove_id_word_mapping[int(id_)] for id_ in ranking.cpu()])
diff --git a/test/modeling/test_dpr.py b/test/modeling/test_dpr.py
index 405b979ee..b3ea5518b 100644
--- a/test/modeling/test_dpr.py
+++ b/test/modeling/test_dpr.py
@@ -829,7 +829,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}
 
         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
@@ -863,7 +863,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}
 
         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = loaded_model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
@@ -952,7 +952,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}
 
         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = loaded_model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
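
Note (not part of the patch): every hunk above swaps `torch.no_grad()` for `torch.inference_mode()`, which disables gradient tracking like `no_grad` but additionally creates inference tensors that can never re-enter autograd, letting PyTorch skip view and version-counter bookkeeping. A minimal sketch of that difference, using a hypothetical toy model rather than any Haystack class:

```python
import torch

# Toy model and input, purely illustrative.
model = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

with torch.no_grad():
    y_no_grad = model(x)

with torch.inference_mode():
    y_inference = model(x)

print(y_no_grad.requires_grad)     # False
print(y_inference.requires_grad)   # False
print(y_no_grad.is_inference())    # False: could still be fed into autograd-tracked ops
print(y_inference.is_inference())  # True: excluded from autograd for good, so the
                                   # forward pass avoids extra bookkeeping and is
                                   # slightly cheaper for pure inference paths
```

Because inference tensors cannot be used later in operations recorded by autograd, the swap is only safe in code paths that never feed these outputs back into training, which is the case for the inference, evaluation, and embedding paths touched here.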