Add FAISS query scores (#368)

This commit is contained in:
Tanay Soni 2020-09-11 13:59:38 +02:00 committed by GitHub
parent 9d93ffbe54
commit c0c2865e58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,6 +1,6 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional from typing import Union, List, Optional, Dict
import faiss import faiss
import numpy as np import numpy as np
@@ -149,13 +149,18 @@ class FAISSDocumentStore(SQLDocumentStore):
aux_dim = np.zeros(len(query_emb), dtype="float32") aux_dim = np.zeros(len(query_emb), dtype="float32")
hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1))) hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1)))
_, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k) score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1] vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index) documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
# sort the documents as per query results # sort the documents as per query results
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
# assign query score to each document
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
for doc in documents:
doc.query_score = scores_for_vector_ids[doc.meta["vector_id"]] # type: ignore
return documents return documents
def save(self, file_path: Union[str, Path]): def save(self, file_path: Union[str, Path]):