Add FAISS query scores (#368)

This commit is contained in:
Tanay Soni 2020-09-11 13:59:38 +02:00 committed by GitHub
parent 9d93ffbe54
commit c0c2865e58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
import faiss
import numpy as np
@ -149,13 +149,18 @@ class FAISSDocumentStore(SQLDocumentStore):
aux_dim = np.zeros(len(query_emb), dtype="float32")
hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1)))
_, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
# sort the documents as per query results
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
# assign query score to each document
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
for doc in documents:
doc.query_score = scores_for_vector_ids[doc.meta["vector_id"]] # type: ignore
return documents
def save(self, file_path: Union[str, Path]):