mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
Move DPR embeddings from GPU to CPU straight away (#618)
* Start * Move embeddings from gpu to cpu
This commit is contained in:
parent
ae530c3a41
commit
09690b84b4
@ -159,7 +159,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
data_loader = NamedDataLoader(
|
||||
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
|
||||
)
|
||||
all_embeddings = {"query": torch.tensor([]).to(self.device), "passages": torch.tensor([]).to(self.device)}
|
||||
all_embeddings = {"query": [], "passages": []}
|
||||
self.model.eval()
|
||||
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
|
||||
batch = {key: batch[key].to(self.device) for key in batch}
|
||||
@ -167,14 +167,15 @@ class DensePassageRetriever(BaseRetriever):
|
||||
# get logits
|
||||
with torch.no_grad():
|
||||
query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
|
||||
all_embeddings["query"] = torch.cat((all_embeddings["query"], query_embeddings), dim=0) \
|
||||
if isinstance(query_embeddings, torch.Tensor) else None
|
||||
all_embeddings["passages"] = torch.cat((all_embeddings["passages"], passage_embeddings), dim=0) \
|
||||
if isinstance(passage_embeddings, torch.Tensor) else None
|
||||
if query_embeddings is not None:
|
||||
all_embeddings["query"].append(query_embeddings.cpu().numpy())
|
||||
if passage_embeddings is not None:
|
||||
all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
|
||||
|
||||
# convert embeddings to numpy array
|
||||
for k, v in all_embeddings.items():
|
||||
all_embeddings[k] = v.cpu().numpy() if v!=None else None
|
||||
if all_embeddings["passages"]:
|
||||
all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
|
||||
if all_embeddings["query"]:
|
||||
all_embeddings["query"] = np.concatenate(all_embeddings["query"])
|
||||
return all_embeddings
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.array]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user