mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 02:39:51 +00:00
Refactor DensePassageRetriever._get_predictions (#642)
This commit is contained in:
parent
5e62e54875
commit
a9107d29eb
@ -142,7 +142,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
|
||||
return documents
|
||||
|
||||
def _get_predictions(self, dicts, tokenizer):
|
||||
def _get_predictions(self, dicts):
|
||||
"""
|
||||
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
|
||||
|
||||
@ -170,7 +170,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
)
|
||||
all_embeddings = {"query": [], "passages": []}
|
||||
self.model.eval()
|
||||
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
|
||||
for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=False)):
|
||||
batch = {key: batch[key].to(self.device) for key in batch}
|
||||
|
||||
# get logits
|
||||
@ -195,7 +195,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
:return: Embeddings, one per input queries
|
||||
"""
|
||||
queries = [{'query': q} for q in texts]
|
||||
result = self._get_predictions(queries, self.query_tokenizer)["query"]
|
||||
result = self._get_predictions(queries)["query"]
|
||||
return result
|
||||
|
||||
def embed_passages(self, docs: List[Document]) -> List[np.array]:
|
||||
@ -211,7 +211,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
"label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
|
||||
"external_id": d.id}]
|
||||
} for d in docs]
|
||||
embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]
|
||||
embeddings = self._get_predictions(passages)["passages"]
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user