From c0c2865e58b09570d525e5a778cf0b39e7cc25b9 Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Fri, 11 Sep 2020 13:59:38 +0200 Subject: [PATCH] Add FAISS query scores (#368) --- haystack/database/faiss.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/haystack/database/faiss.py b/haystack/database/faiss.py index 2376b4484..d54e3b897 100644 --- a/haystack/database/faiss.py +++ b/haystack/database/faiss.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict import faiss import numpy as np @@ -149,13 +149,18 @@ class FAISSDocumentStore(SQLDocumentStore): aux_dim = np.zeros(len(query_emb), dtype="float32") hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1))) - _, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k) + score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k) vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1] documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index) # sort the documents as per query results documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore + # assign query score to each document + scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])} + for doc in documents: + doc.query_score = scores_for_vector_ids[doc.meta["vector_id"]] # type: ignore + return documents def save(self, file_path: Union[str, Path]):