Refactor DensePassageRetriever._get_predictions (#642)

Malte Pietsch 2020-12-01 09:22:15 +01:00 committed by GitHub
parent 5e62e54875
commit a9107d29eb


@@ -142,7 +142,7 @@ class DensePassageRetriever(BaseRetriever):
documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
return documents
-def _get_predictions(self, dicts, tokenizer):
+def _get_predictions(self, dicts):
"""
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -170,7 +170,7 @@ class DensePassageRetriever(BaseRetriever):
)
all_embeddings = {"query": [], "passages": []}
self.model.eval()
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=False)):
batch = {key: batch[key].to(self.device) for key in batch}
# get logits
@@ -195,7 +195,7 @@ class DensePassageRetriever(BaseRetriever):
:return: Embeddings, one per input queries
"""
queries = [{'query': q} for q in texts]
-result = self._get_predictions(queries, self.query_tokenizer)["query"]
+result = self._get_predictions(queries)["query"]
return result
def embed_passages(self, docs: List[Document]) -> List[np.array]:
@@ -211,7 +211,7 @@ class DensePassageRetriever(BaseRetriever):
"label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
"external_id": d.id}]
} for d in docs]
-embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]
+embeddings = self._get_predictions(passages)["passages"]
return embeddings
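
For context, a minimal usage sketch (not part of the diff) of how the embedding methods are called after this refactor: callers go through embed_queries / embed_passages only, and _get_predictions selects the query or passage encoder internally instead of receiving a tokenizer argument. Import paths, model names, and constructor arguments below are assumptions for the Haystack version around this commit, not part of the change itself.

    # Illustrative sketch, assuming Haystack ~0.6-era import paths.
    from haystack import Document
    from haystack.document_store.memory import InMemoryDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    document_store = InMemoryDocumentStore()
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    )

    # No tokenizer is passed here; the retriever handles it internally.
    query_embs = retriever.embed_queries(texts=["Who wrote Faust?"])
    passage_embs = retriever.embed_passages([Document(text="Faust was written by Goethe.")])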