Move DPR embeddings from GPU to CPU straight away (#618)

* Start

* Move embeddings from gpu to cpu
This commit is contained in:
Branden Chan 2020-11-25 14:22:43 +01:00 committed by GitHub
parent ae530c3a41
commit 09690b84b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -159,7 +159,7 @@ class DensePassageRetriever(BaseRetriever):
         data_loader = NamedDataLoader(
             dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
         )
-        all_embeddings = {"query": torch.tensor([]).to(self.device), "passages": torch.tensor([]).to(self.device)}
+        all_embeddings = {"query": [], "passages": []}
         self.model.eval()
         for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
             batch = {key: batch[key].to(self.device) for key in batch}
@@ -167,14 +167,15 @@ class DensePassageRetriever(BaseRetriever):
             # get logits
             with torch.no_grad():
                 query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
-                all_embeddings["query"] = torch.cat((all_embeddings["query"], query_embeddings), dim=0) \
-                    if isinstance(query_embeddings, torch.Tensor) else None
-                all_embeddings["passages"] = torch.cat((all_embeddings["passages"], passage_embeddings), dim=0) \
-                    if isinstance(passage_embeddings, torch.Tensor) else None
-        # convert embeddings to numpy array
-        for k, v in all_embeddings.items():
-            all_embeddings[k] = v.cpu().numpy() if v!=None else None
+                if query_embeddings is not None:
+                    all_embeddings["query"].append(query_embeddings.cpu().numpy())
+                if passage_embeddings is not None:
+                    all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+        if all_embeddings["passages"]:
+            all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
+        if all_embeddings["query"]:
+            all_embeddings["query"] = np.concatenate(all_embeddings["query"])

         return all_embeddings

     def embed_queries(self, texts: List[str]) -> List[np.array]: