Refactor DensePassageRetriever._get_predictions (#642)

This commit is contained in:
Malte Pietsch 2020-12-01 09:22:15 +01:00 committed by GitHub
parent 5e62e54875
commit a9107d29eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -142,7 +142,7 @@ class DensePassageRetriever(BaseRetriever):
         documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
         return documents

-    def _get_predictions(self, dicts, tokenizer):
+    def _get_predictions(self, dicts):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -170,7 +170,7 @@ class DensePassageRetriever(BaseRetriever):
        )
        all_embeddings = {"query": [], "passages": []}
        self.model.eval()
-       for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
+       for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=False)):
            batch = {key: batch[key].to(self.device) for key in batch}
            # get logits
@@ -195,7 +195,7 @@ class DensePassageRetriever(BaseRetriever):
        :return: Embeddings, one per input queries
        """
        queries = [{'query': q} for q in texts]
-       result = self._get_predictions(queries, self.query_tokenizer)["query"]
+       result = self._get_predictions(queries)["query"]
        return result

    def embed_passages(self, docs: List[Document]) -> List[np.array]:
@@ -211,7 +211,7 @@ class DensePassageRetriever(BaseRetriever):
                    "label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
                    "external_id": d.id}]
        } for d in docs]
-       embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]
+       embeddings = self._get_predictions(passages)["passages"]
        return embeddings