mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-25 06:26:05 +00:00
Refactor DensePassageRetriever._get_predictions (#642)
This commit is contained in:
parent
5e62e54875
commit
a9107d29eb
@@ -142,7 +142,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
|
documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def _get_predictions(self, dicts, tokenizer):
|
def _get_predictions(self, dicts):
|
||||||
"""
|
"""
|
||||||
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
|
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
)
|
)
|
||||||
all_embeddings = {"query": [], "passages": []}
|
all_embeddings = {"query": [], "passages": []}
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
|
for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=False)):
|
||||||
batch = {key: batch[key].to(self.device) for key in batch}
|
batch = {key: batch[key].to(self.device) for key in batch}
|
||||||
|
|
||||||
# get logits
|
# get logits
|
||||||
@@ -195,7 +195,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
:return: Embeddings, one per input query
|
:return: Embeddings, one per input query
|
||||||
"""
|
"""
|
||||||
queries = [{'query': q} for q in texts]
|
queries = [{'query': q} for q in texts]
|
||||||
result = self._get_predictions(queries, self.query_tokenizer)["query"]
|
result = self._get_predictions(queries)["query"]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def embed_passages(self, docs: List[Document]) -> List[np.array]:
|
def embed_passages(self, docs: List[Document]) -> List[np.array]:
|
||||||
@@ -211,7 +211,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
"label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
|
"label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
|
||||||
"external_id": d.id}]
|
"external_id": d.id}]
|
||||||
} for d in docs]
|
} for d in docs]
|
||||||
embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]
|
embeddings = self._get_predictions(passages)["passages"]
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user