mirror of https://github.com/deepset-ai/haystack.git
synced 2025-11-03 19:29:32 +00:00
refactor: replace torch.no_grad with torch.inference_mode (where possible) (#3601)

* try to replace torch.no_grad
* revert erroneous change
* revert other module breaking
* revert training/base
This commit is contained in:
parent 3040e59c63
commit f43bc562d3
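For context (not part of the diff): torch.inference_mode() disables gradient recording like torch.no_grad(), but additionally skips version-counter and view-tracking bookkeeping, so tensors created inside it cannot later re-enter autograd-recorded computation. That restriction is why the replacement is only made "where possible" and why some modules were reverted. A minimal sketch of the difference, assuming PyTorch >= 1.9 (where inference_mode was introduced):

    import torch

    x = torch.ones(3, requires_grad=True)

    with torch.no_grad():
        a = x * 2  # gradients not recorded, but `a` is an ordinary tensor

    with torch.inference_mode():
        b = x * 2  # cheaper: no version counter / view tracking; `b` is an inference tensor

    (a * x).sum().backward()    # fine: `a` re-enters autograd as a constant
    # (b * x).sum().backward()  # RuntimeError: inference tensors cannot be saved for backward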
					
@@ -305,7 +305,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         while curr_pos < len(doc_embeds):
             doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
             doc_embeds_slice = doc_embeds_slice.to(self.main_device)
-            with torch.no_grad():
+            with torch.inference_mode():
                 slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()
                 slice_scores = slice_scores.squeeze(dim=1)
                 slice_scores = slice_scores.numpy().tolist()
@@ -771,7 +771,7 @@ class DistillationDataSilo(DataSilo):
         teacher_outputs: List[List[Tuple[torch.Tensor, ...]]],
         tensor_names: List[str],
     ):
-        with torch.no_grad():
+        with torch.inference_mode():
            batch_transposed = zip(*batch)  # transpose dimensions (from batch, features, ... to features, batch, ...)
             batch_transposed_list = [torch.stack(b) for b in batch_transposed]  # create tensors for each feature
             batch_dict = {
@@ -77,7 +77,7 @@ class Evaluator:
             else:
                 module = model

-            with torch.no_grad():
+            with torch.inference_mode():
                 if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
@@ -365,7 +365,7 @@ class Inferencer:
             batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]

             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 logits = self.model.forward(**batch)
                 preds = self.model.formatted_preds(
                     logits=logits, samples=batch_samples, padding_mask=batch.get("padding_mask", None)
@@ -402,7 +402,7 @@ class Inferencer:
             batch = {key: batch[key].to(self.devices[0]) for key in batch}

             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 # Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
                 # So we transform logits to preds here as well
                 logits = self.model.forward(
@@ -778,7 +778,7 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
         :param kwargs: All arguments that need to be passed on to the model.
         :return: All logits as torch.tensor or multiple tensors.
         """
-        with torch.no_grad():
+        with torch.inference_mode():
             if self.language_model_class == "Bert":
                 input_to_onnx = {
                     "input_ids": numpy.ascontiguousarray(kwargs["input_ids"].cpu().numpy()),
@@ -129,7 +129,7 @@ class SentenceTransformersRanker(BaseRanker):
         # 1. the logit as similarity score/answerable classification
         # 2. the logits as answerable classification  (no_answer / has_answer)
         # https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers
-        with torch.no_grad():
+        with torch.inference_mode():
             similarity_scores = self.transformer_model(**features).logits

         logits_dim = similarity_scores.shape[1]  # [batch_size, logits_dim]
@@ -216,7 +216,7 @@ class SentenceTransformersRanker(BaseRanker):
                 cur_queries, [doc.content for doc in cur_docs], padding=True, truncation=True, return_tensors="pt"
             ).to(self.devices[0])

-            with torch.no_grad():
+            with torch.inference_mode():
                 similarity_scores = self.transformer_model(**features).logits
                 preds.extend(similarity_scores)
             pb.update(len(cur_docs))
@@ -304,7 +304,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):

         for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
             batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-            with torch.no_grad():
+            with torch.inference_mode():
                 q_reps = (
                     self.embedding_model.embed_questions(
                         input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -331,7 +331,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):

         for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
             batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-            with torch.no_grad():
+            with torch.inference_mode():
                 q_reps = (
                     self.embedding_model.embed_answers(
                         input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -528,7 +528,7 @@ class DensePassageRetriever(DenseRetriever):
                 batch = {key: raw_batch[key].to(self.devices[0]) for key in raw_batch}

                 # get logits
-                with torch.no_grad():
+                with torch.inference_mode():
                     query_embeddings, passage_embeddings = self.model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
@@ -1171,7 +1171,7 @@ class TableTextRetriever(DenseRetriever):
                 batch = {key: batch[key].to(self.devices[0]) for key in batch}

                 # get logits
-                with torch.no_grad():
+                with torch.inference_mode():
                     query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
                     if query_embeddings is not None:
                         query_embeddings_batched.append(query_embeddings.cpu().numpy())
@@ -77,7 +77,7 @@ def load_glove(
             id_word_mapping[i] = split[0]
             vector_list.append(torch.tensor([float(x) for x in split[1:]]))
     vectors = torch.stack(vector_list)
-    with torch.no_grad():
+    with torch.inference_mode():
         vectors = vectors.to(device)
         vectors = F.normalize(vectors, dim=1)
     return word_id_mapping, id_word_mapping, vectors
@@ -132,7 +132,7 @@ def get_replacements(
         inputs.append((input_ids_, subword_index))

     # doing batched forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         prediction_list = []
         while len(inputs) != 0:
             batch_list, token_indices = tuple(zip(*inputs[:batch_size]))
@@ -165,7 +165,7 @@ def get_replacements(
         elif word in glove_word_id_mapping:  # word was split into subwords so we use glove instead
             word_id = glove_word_id_mapping[word]
             glove_vector = glove_vectors[word_id]
-            with torch.no_grad():
+            with torch.inference_mode():
                 word_similarities = torch.mm(glove_vectors, glove_vector.unsqueeze(1)).squeeze(1)
                 ranking = torch.argsort(word_similarities, descending=True)[: word_possibilities + 1]
                 possible_words.append([glove_id_word_mapping[int(id_)] for id_ in ranking.cpu()])
@@ -829,7 +829,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}

         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
@@ -863,7 +863,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}

         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = loaded_model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
@@ -952,7 +952,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}

         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = loaded_model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
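Where an inference-mode result does need to feed later autograd-tracked code (the failure mode behind the reverts listed in the commit message), the documented PyTorch workaround is to clone the inference tensor into an ordinary one. A hedged sketch; `encode` is a hypothetical stand-in, not a Haystack API:

    import torch

    def encode(batch: torch.Tensor) -> torch.Tensor:
        # hypothetical stand-in for a model forward pass
        with torch.inference_mode():
            return batch * 2

    emb = encode(torch.ones(3))
    emb = emb.clone()  # clone() outside inference_mode yields a regular
                       # tensor that later autograd-recorded ops may save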