mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-03 21:33:40 +00:00
Add FAISS query scores (#368)
This commit is contained in:
parent
9d93ffbe54
commit
c0c2865e58
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, List, Optional
|
from typing import Union, List, Optional, Dict
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -149,13 +149,18 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
|
|
||||||
aux_dim = np.zeros(len(query_emb), dtype="float32")
|
aux_dim = np.zeros(len(query_emb), dtype="float32")
|
||||||
hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1)))
|
hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1)))
|
||||||
_, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
|
score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
|
||||||
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
|
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
|
||||||
|
|
||||||
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
|
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
|
||||||
# sort the documents as per query results
|
# sort the documents as per query results
|
||||||
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
|
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
|
||||||
|
|
||||||
|
# assign query score to each document
|
||||||
|
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
|
||||||
|
for doc in documents:
|
||||||
|
doc.query_score = scores_for_vector_ids[doc.meta["vector_id"]] # type: ignore
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def save(self, file_path: Union[str, Path]):
|
def save(self, file_path: Union[str, Path]):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user