diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 1c65854f8..c2b63db0c 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -142,7 +142,7 @@ class DensePassageRetriever(BaseRetriever):
         documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
         return documents
 
-    def _get_predictions(self, dicts, tokenizer):
+    def _get_predictions(self, dicts):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
 
@@ -170,7 +170,7 @@ class DensePassageRetriever(BaseRetriever):
         )
         all_embeddings = {"query": [], "passages": []}
         self.model.eval()
-        for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
+        for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=False)):
             batch = {key: batch[key].to(self.device) for key in batch}
 
             # get logits
@@ -195,7 +195,7 @@ class DensePassageRetriever(BaseRetriever):
         :return: Embeddings, one per input queries
         """
         queries = [{'query': q} for q in texts]
-        result = self._get_predictions(queries, self.query_tokenizer)["query"]
+        result = self._get_predictions(queries)["query"]
         return result
 
     def embed_passages(self, docs: List[Document]) -> List[np.array]:
@@ -211,7 +211,7 @@ class DensePassageRetriever(BaseRetriever):
                                       "label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
                                       "external_id": d.id}]
         } for d in docs]
-        embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]
+        embeddings = self._get_predictions(passages)["passages"]
         return embeddings
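
The change drops the now-unused tokenizer argument from _get_predictions: both call sites passed a different tokenizer (self.query_tokenizer vs. self.passage_tokenizer), but the method only consumed preprocessed dicts, so tokenization presumably happens earlier in the dataset preprocessing step. Below is a minimal usage sketch of the public methods touched by this diff. Only embed_queries (with its texts parameter), embed_passages (with its docs parameter), and Document appear in the diff itself; the import paths, constructor arguments, and model names are assumptions based on typical Haystack 0.x usage at the time and may differ in your version.

# Minimal sketch, assuming Haystack 0.x-era import paths and constructor defaults.
from haystack import Document                                     # assumed import path
from haystack.document_store.memory import InMemoryDocumentStore  # assumed import path
from haystack.retriever.dense import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),  # assumed store; any document store should work
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",  # assumed default
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",     # assumed default
)

# Both methods build the input dicts themselves ({'query': ...} / {'passages': [...]})
# and call _get_predictions(dicts) without a tokenizer argument after this change.
query_embs = retriever.embed_queries(texts=["Who wrote 'Dense Passage Retrieval'?"])
passage_embs = retriever.embed_passages(docs=[Document(text="DPR encodes passages into dense vectors.")])

Each call returns one embedding per input, matching the ":return: Embeddings, one per input queries" contract in the docstring.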