refactor: replace torch.no_grad with torch.inference_mode (where possible) (#3601)

* try to replace torch.no_grad

* revert erroneous change

* revert other module breaking

* revert training/base
Stefano Fiorucci authored on 2022-11-23 09:26:11 +01:00 (committed by GitHub)
parent 3040e59c63
commit f43bc562d3
10 changed files with 18 additions and 18 deletions
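
For context on why the title says "where possible" (and why several changes were reverted), here is a minimal sketch of how the two context managers differ. The tensor names and shapes below are invented for illustration and are not part of the commit:

import torch

x = torch.randn(4, 8)
w = torch.randn(8, 2, requires_grad=True)

# torch.no_grad() only disables gradient tracking; the result is an ordinary
# tensor that may still be fed into later computations recorded by autograd.
with torch.no_grad():
    scores_no_grad = x @ w            # requires_grad == False

# torch.inference_mode() additionally skips view tracking and version counting,
# so it is a bit faster, but tensors created inside it can never take part in
# autograd afterwards.
with torch.inference_mode():
    scores_inference = x @ w          # an "inference tensor"

(scores_no_grad * w.sum()).sum().backward()      # fine: no_grad output reused under autograd
# (scores_inference * w.sum()).sum().backward()  # RuntimeError: inference tensors
#                                                # cannot be saved for backward

This is why the swap is safe in pure-inference paths (document store scoring, rankers, retrievers, the Inferencer) but not wherever the produced tensors might later reach a backward pass, which is presumably what the "revert erroneous change" and "revert training/base" commits address.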

View File

@@ -305,7 +305,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
while curr_pos < len(doc_embeds):
doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
doc_embeds_slice = doc_embeds_slice.to(self.main_device)
-with torch.no_grad():
+with torch.inference_mode():
slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()
slice_scores = slice_scores.squeeze(dim=1)
slice_scores = slice_scores.numpy().tolist()

View File

@@ -771,7 +771,7 @@ class DistillationDataSilo(DataSilo):
teacher_outputs: List[List[Tuple[torch.Tensor, ...]]],
tensor_names: List[str],
):
-with torch.no_grad():
+with torch.inference_mode():
batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
batch_dict = {

View File

@@ -77,7 +77,7 @@ class Evaluator:
else:
module = model
-with torch.no_grad():
+with torch.inference_mode():
if isinstance(module, AdaptiveModel):
logits = model.forward(
input_ids=batch.get("input_ids", None),

View File

@@ -365,7 +365,7 @@ class Inferencer:
batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]
# get logits
-with torch.no_grad():
+with torch.inference_mode():
logits = self.model.forward(**batch)
preds = self.model.formatted_preds(
logits=logits, samples=batch_samples, padding_mask=batch.get("padding_mask", None)
@@ -402,7 +402,7 @@ class Inferencer:
batch = {key: batch[key].to(self.devices[0]) for key in batch}
# get logits
-with torch.no_grad():
+with torch.inference_mode():
# Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
# So we transform logits to preds here as well
logits = self.model.forward(

View File

@@ -778,7 +778,7 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
:param kwargs: All arguments that need to be passed on to the model.
:return: All logits as torch.tensor or multiple tensors.
"""
-with torch.no_grad():
+with torch.inference_mode():
if self.language_model_class == "Bert":
input_to_onnx = {
"input_ids": numpy.ascontiguousarray(kwargs["input_ids"].cpu().numpy()),

View File

@@ -129,7 +129,7 @@ class SentenceTransformersRanker(BaseRanker):
# 1. the logit as similarity score/answerable classification
# 2. the logits as answerable classification (no_answer / has_answer)
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers
-with torch.no_grad():
+with torch.inference_mode():
similarity_scores = self.transformer_model(**features).logits
logits_dim = similarity_scores.shape[1] # [batch_size, logits_dim]
@@ -216,7 +216,7 @@ class SentenceTransformersRanker(BaseRanker):
cur_queries, [doc.content for doc in cur_docs], padding=True, truncation=True, return_tensors="pt"
).to(self.devices[0])
-with torch.no_grad():
+with torch.inference_mode():
similarity_scores = self.transformer_model(**features).logits
preds.extend(similarity_scores)
pb.update(len(cur_docs))

View File

@@ -304,7 +304,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-with torch.no_grad():
+with torch.inference_mode():
q_reps = (
self.embedding_model.embed_questions(
input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -331,7 +331,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-with torch.no_grad():
+with torch.inference_mode():
q_reps = (
self.embedding_model.embed_answers(
input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]

View File

@@ -528,7 +528,7 @@ class DensePassageRetriever(DenseRetriever):
batch = {key: raw_batch[key].to(self.devices[0]) for key in raw_batch}
# get logits
-with torch.no_grad():
+with torch.inference_mode():
query_embeddings, passage_embeddings = self.model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
@@ -1171,7 +1171,7 @@ class TableTextRetriever(DenseRetriever):
batch = {key: batch[key].to(self.devices[0]) for key in batch}
# get logits
-with torch.no_grad():
+with torch.inference_mode():
query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
if query_embeddings is not None:
query_embeddings_batched.append(query_embeddings.cpu().numpy())

View File

@@ -77,7 +77,7 @@ def load_glove(
id_word_mapping[i] = split[0]
vector_list.append(torch.tensor([float(x) for x in split[1:]]))
vectors = torch.stack(vector_list)
-with torch.no_grad():
+with torch.inference_mode():
vectors = vectors.to(device)
vectors = F.normalize(vectors, dim=1)
return word_id_mapping, id_word_mapping, vectors
@@ -132,7 +132,7 @@ def get_replacements(
inputs.append((input_ids_, subword_index))
# doing batched forward pass
-with torch.no_grad():
+with torch.inference_mode():
prediction_list = []
while len(inputs) != 0:
batch_list, token_indices = tuple(zip(*inputs[:batch_size]))
@@ -165,7 +165,7 @@ def get_replacements(
elif word in glove_word_id_mapping: # word was split into subwords so we use glove instead
word_id = glove_word_id_mapping[word]
glove_vector = glove_vectors[word_id]
-with torch.no_grad():
+with torch.inference_mode():
word_similarities = torch.mm(glove_vectors, glove_vector.unsqueeze(1)).squeeze(1)
ranking = torch.argsort(word_similarities, descending=True)[: word_possibilities + 1]
possible_words.append([glove_id_word_mapping[int(id_)] for id_ in ranking.cpu()])

View File

@@ -829,7 +829,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
batch = {key: batch[key].to(device) for key in batch}
# get logits
-with torch.no_grad():
+with torch.inference_mode():
query_embeddings, passage_embeddings = model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
@@ -863,7 +863,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
batch = {key: batch[key].to(device) for key in batch}
# get logits
-with torch.no_grad():
+with torch.inference_mode():
query_embeddings, passage_embeddings = loaded_model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
@@ -952,7 +952,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
batch = {key: batch[key].to(device) for key in batch}
# get logits
-with torch.no_grad():
+with torch.inference_mode():
query_embeddings, passage_embeddings = loaded_model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),