Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-16 10:03:44 +00:00)
refactor: replace torch.no_grad with torch.inference_mode (where possible) (#3601)
* try to replace torch.no_grad
* revert erroneous change
* revert other module breaking
* revert training/base
commit f43bc562d3 (parent 3040e59c63)
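For context, torch.inference_mode() is a stricter sibling of torch.no_grad(): on top of disabling gradient tracking it also skips autograd's view and version-counter bookkeeping, so pure inference runs slightly faster, but tensors created inside the block can never be pulled back into autograd. That restriction is why the replacement is only done "where possible" and why some changes (e.g. training/base) were reverted. A minimal sketch of the difference, using an illustrative toy model rather than any Haystack component:

import torch

# Illustrative toy model and input; not taken from Haystack.
model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

# no_grad: gradients are not tracked, but the outputs are ordinary tensors.
with torch.no_grad():
    out_no_grad = model(x)

# inference_mode: additionally skips version-counter/view tracking,
# making inference-only code a bit cheaper.
with torch.inference_mode():
    out_inference = model(x)

print(out_no_grad.requires_grad, out_inference.requires_grad)  # False False

# The catch: inference tensors cannot re-enter autograd afterwards.
# Uncommenting the next line raises a RuntimeError, which is why call sites
# whose outputs feed back into training keep torch.no_grad().
# out_inference.requires_grad_(True)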
@@ -305,7 +305,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         while curr_pos < len(doc_embeds):
             doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
             doc_embeds_slice = doc_embeds_slice.to(self.main_device)
-            with torch.no_grad():
+            with torch.inference_mode():
                 slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()
                 slice_scores = slice_scores.squeeze(dim=1)
                 slice_scores = slice_scores.numpy().tolist()
@@ -771,7 +771,7 @@ class DistillationDataSilo(DataSilo):
         teacher_outputs: List[List[Tuple[torch.Tensor, ...]]],
         tensor_names: List[str],
     ):
-        with torch.no_grad():
+        with torch.inference_mode():
             batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
             batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
             batch_dict = {
@@ -77,7 +77,7 @@ class Evaluator:
             else:
                 module = model

-            with torch.no_grad():
+            with torch.inference_mode():
                 if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
@@ -365,7 +365,7 @@ class Inferencer:
             batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]

             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 logits = self.model.forward(**batch)
                 preds = self.model.formatted_preds(
                     logits=logits, samples=batch_samples, padding_mask=batch.get("padding_mask", None)
@@ -402,7 +402,7 @@ class Inferencer:
             batch = {key: batch[key].to(self.devices[0]) for key in batch}

             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 # Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
                 # So we transform logits to preds here as well
                 logits = self.model.forward(
@@ -778,7 +778,7 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
         :param kwargs: All arguments that need to be passed on to the model.
         :return: All logits as torch.tensor or multiple tensors.
         """
-        with torch.no_grad():
+        with torch.inference_mode():
             if self.language_model_class == "Bert":
                 input_to_onnx = {
                     "input_ids": numpy.ascontiguousarray(kwargs["input_ids"].cpu().numpy()),
@@ -129,7 +129,7 @@ class SentenceTransformersRanker(BaseRanker):
         # 1. the logit as similarity score/answerable classification
         # 2. the logits as answerable classification (no_answer / has_answer)
         # https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers
-        with torch.no_grad():
+        with torch.inference_mode():
             similarity_scores = self.transformer_model(**features).logits

         logits_dim = similarity_scores.shape[1] # [batch_size, logits_dim]
@@ -216,7 +216,7 @@ class SentenceTransformersRanker(BaseRanker):
                     cur_queries, [doc.content for doc in cur_docs], padding=True, truncation=True, return_tensors="pt"
                 ).to(self.devices[0])

-                with torch.no_grad():
+                with torch.inference_mode():
                     similarity_scores = self.transformer_model(**features).logits
                     preds.extend(similarity_scores)
                 pb.update(len(cur_docs))
@@ -304,7 +304,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):

         for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
             batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-            with torch.no_grad():
+            with torch.inference_mode():
                 q_reps = (
                     self.embedding_model.embed_questions(
                         input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -331,7 +331,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):

         for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
             batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
-            with torch.no_grad():
+            with torch.inference_mode():
                 q_reps = (
                     self.embedding_model.embed_answers(
                         input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -528,7 +528,7 @@ class DensePassageRetriever(DenseRetriever):
             batch = {key: raw_batch[key].to(self.devices[0]) for key in raw_batch}

             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 query_embeddings, passage_embeddings = self.model.forward(
                     query_input_ids=batch.get("query_input_ids", None),
                     query_segment_ids=batch.get("query_segment_ids", None),
@@ -1171,7 +1171,7 @@ class TableTextRetriever(DenseRetriever):
             batch = {key: batch[key].to(self.devices[0]) for key in batch}

             # get logits
-            with torch.no_grad():
+            with torch.inference_mode():
                 query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
                 if query_embeddings is not None:
                     query_embeddings_batched.append(query_embeddings.cpu().numpy())
@@ -77,7 +77,7 @@ def load_glove(
             id_word_mapping[i] = split[0]
             vector_list.append(torch.tensor([float(x) for x in split[1:]]))
     vectors = torch.stack(vector_list)
-    with torch.no_grad():
+    with torch.inference_mode():
         vectors = vectors.to(device)
         vectors = F.normalize(vectors, dim=1)
     return word_id_mapping, id_word_mapping, vectors
@@ -132,7 +132,7 @@ def get_replacements(
         inputs.append((input_ids_, subword_index))

     # doing batched forward pass
-    with torch.no_grad():
+    with torch.inference_mode():
         prediction_list = []
         while len(inputs) != 0:
             batch_list, token_indices = tuple(zip(*inputs[:batch_size]))
@@ -165,7 +165,7 @@ def get_replacements(
         elif word in glove_word_id_mapping: # word was split into subwords so we use glove instead
             word_id = glove_word_id_mapping[word]
             glove_vector = glove_vectors[word_id]
-            with torch.no_grad():
+            with torch.inference_mode():
                 word_similarities = torch.mm(glove_vectors, glove_vector.unsqueeze(1)).squeeze(1)
                 ranking = torch.argsort(word_similarities, descending=True)[: word_possibilities + 1]
                 possible_words.append([glove_id_word_mapping[int(id_)] for id_ in ranking.cpu()])
@@ -829,7 +829,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}

         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
@@ -863,7 +863,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}

         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = loaded_model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),
@@ -952,7 +952,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
         batch = {key: batch[key].to(device) for key in batch}

         # get logits
-        with torch.no_grad():
+        with torch.inference_mode():
             query_embeddings, passage_embeddings = loaded_model.forward(
                 query_input_ids=batch.get("query_input_ids", None),
                 query_segment_ids=batch.get("query_segment_ids", None),