Move DPR embeddings from GPU to CPU straight away (#618)

* Start
* Move embeddings from gpu to cpu

parent ae530c3a41
commit 09690b84b4
@@ -159,7 +159,7 @@ class DensePassageRetriever(BaseRetriever):
         data_loader = NamedDataLoader(
             dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
         )
-        all_embeddings = {"query": torch.tensor([]).to(self.device), "passages": torch.tensor([]).to(self.device)}
+        all_embeddings = {"query": [], "passages": []}
         self.model.eval()
         for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
             batch = {key: batch[key].to(self.device) for key in batch}
@@ -167,14 +167,15 @@ class DensePassageRetriever(BaseRetriever):
             # get logits
             with torch.no_grad():
                 query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
-                all_embeddings["query"] = torch.cat((all_embeddings["query"], query_embeddings), dim=0) \
-                    if isinstance(query_embeddings, torch.Tensor) else None
-                all_embeddings["passages"] = torch.cat((all_embeddings["passages"], passage_embeddings), dim=0) \
-                    if isinstance(passage_embeddings, torch.Tensor) else None
+                if query_embeddings is not None:
+                    all_embeddings["query"].append(query_embeddings.cpu().numpy())
+                if passage_embeddings is not None:
+                    all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
 
-        # convert embeddings to numpy array
-        for k, v in all_embeddings.items():
-            all_embeddings[k] = v.cpu().numpy() if v!=None else None
+        if all_embeddings["passages"]:
+            all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
+        if all_embeddings["query"]:
+            all_embeddings["query"] = np.concatenate(all_embeddings["query"])
 
         return all_embeddings
 
     def embed_queries(self, texts: List[str]) -> List[np.array]:
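For context, here is a standalone sketch (not haystack code; `embed_batches` is a hypothetical helper) of the pattern this commit adopts: instead of growing a GPU-resident tensor with torch.cat after every batch, each batch's embeddings are moved to CPU as soon as they are computed, so GPU memory stays roughly flat at one batch of activations, and everything is concatenated once at the end with numpy.

# Standalone illustration only; `embed_batches` is made up for this sketch
# and is not part of the haystack API.
import numpy as np
import torch


def embed_batches(model, data_loader, device):
    """Collect per-batch embeddings on CPU so GPU memory stays flat."""
    chunks = []
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            emb = model(batch.to(device))      # embeddings live on `device`
            chunks.append(emb.cpu().numpy())   # move off the GPU right away
    # one concatenation at the end, instead of torch.cat per batch on GPU
    return np.concatenate(chunks) if chunks else np.empty((0,))


# toy usage: embed 32 vectors in batches of 10
model = torch.nn.Linear(8, 4)
loader = torch.utils.data.DataLoader(torch.randn(32, 8), batch_size=10)
print(embed_batches(model, loader, torch.device("cpu")).shape)  # (32, 4)

Because the forward pass runs under torch.no_grad(), the per-batch tensors carry no autograd graph, so the only GPU memory held before the .cpu() call is the embedding tensor for the current batch.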