Mirror of https://github.com/deepset-ai/haystack.git
Enable cosine similarity metric in FAISSDocumentStore (#1352)
* feat: normalize embeddings for cosine sim
* WIP add test case for faiss cosine
* input to faiss normalize needs to be an array of vectors
* fix: test should compare correct result embedding to original embedding
* add sanity check for cosine sim
* fix typo
* normalize cosine score
* Update docstring

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Parent: 5b1b875374
Commit: 9c4e67d9b6
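With this change, FAISSDocumentStore accepts similarity='cosine' at construction time. A minimal usage sketch; the import path and the sql_url value are my assumptions for a Haystack version of this era, not taken from the diff:

    from haystack.document_store.faiss import FAISSDocumentStore  # module path assumed

    # 'cosine' is now accepted alongside 'dot_product' and 'l2'
    document_store = FAISSDocumentStore(
        sql_url="sqlite:///haystack_test.db",  # example database URL, not from the commit
        similarity="cosine",
    )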
@@ -67,8 +67,11 @@ class FAISSDocumentStore(SQLDocumentStore):
                            or one with docs that you used in Haystack before and want to load again.
         :param return_embedding: To return document embedding
         :param index: Name of index in document store to use.
-        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default sine it is
-                           more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
+        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
+                           more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
+                           In both cases, the returned values in Document.score are normalized to be in range [0,1]:
+                           For `dot_product`: expit(np.asarray(raw_score / 100))
+                           For `cosine`: (raw_score + 1) / 2
         :param embedding_field: Name of field containing an embedding vector.
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
@@ -88,15 +91,15 @@ class FAISSDocumentStore(SQLDocumentStore):
             embedding_field=embedding_field, progress_bar=progress_bar
         )
 
-        if similarity == "dot_product":
+        if similarity == "dot_product" or similarity == 'cosine':
             self.similarity = similarity
             self.metric_type = faiss.METRIC_INNER_PRODUCT
         elif similarity == "l2":
             self.similarity = similarity
             self.metric_type = faiss.METRIC_L2
         else:
-            raise ValueError("The FAISS document store can currently only support dot_product similarity. "
-                             "Please set similarity=\"dot_product\"")
+            raise ValueError("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
+                             "Please set similarity to one of the above.")
 
         self.vector_dim = vector_dim
         self.faiss_index_factory_str = faiss_index_factory_str
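As a quick illustration of the validation branch above, any value outside the three supported metrics now raises the new ValueError. A hypothetical check, not part of the commit:

    import pytest

    # mirrors the else-branch in __init__ shown above
    with pytest.raises(ValueError):
        FAISSDocumentStore(sql_url="sqlite://", similarity="manhattan")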
@@ -184,6 +187,10 @@ class FAISSDocumentStore(SQLDocumentStore):
             if add_vectors:
                 embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
                 embeddings_to_index = np.array(embeddings, dtype="float32")
+
+                if self.similarity == 'cosine':
+                    faiss.normalize_L2(embeddings_to_index)
+
                 self.faiss_indexes[index].add(embeddings_to_index)
 
             docs_to_write_in_sql = []
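The reason 'cosine' can reuse faiss.METRIC_INNER_PRODUCT: after faiss.normalize_L2, the inner product of two vectors equals their cosine similarity. A standalone sanity check of that identity (not from the commit; the same logic also covers the identical normalization in update_embeddings below):

    import faiss
    import numpy as np

    vecs = np.random.rand(4, 8).astype("float32")
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    cosine = (vecs @ vecs.T) / (norms * norms.T)  # classic cosine-similarity matrix

    faiss.normalize_L2(vecs)                      # in-place L2 normalization, as in the hunk above
    assert np.allclose(vecs @ vecs.T, cosine, atol=1e-5)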
@@ -261,6 +268,10 @@ class FAISSDocumentStore(SQLDocumentStore):
                 assert len(document_batch) == len(embeddings)
 
                 embeddings_to_index = np.array(embeddings, dtype="float32")
+
+                if self.similarity == 'cosine':
+                    faiss.normalize_L2(embeddings_to_index)
+
                 self.faiss_indexes[index].add(embeddings_to_index)
 
                 vector_id_map = {}
@@ -417,6 +428,10 @@ class FAISSDocumentStore(SQLDocumentStore):
             return_embedding = self.return_embedding
 
         query_emb = query_emb.reshape(1, -1).astype(np.float32)
+
+        if self.similarity == 'cosine':
+            faiss.normalize_L2(query_emb)
+
         score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k)
         vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
 
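Query-side normalization matters too: if only the stored vectors were normalized, the raw inner-product score would still be scaled by the query's norm. A small end-to-end sketch with a plain FAISS index (illustrative only, variable names are mine):

    import faiss
    import numpy as np

    dim = 8
    doc_vecs = np.random.rand(16, dim).astype("float32")
    query = np.random.rand(1, dim).astype("float32")

    faiss.normalize_L2(doc_vecs)
    index = faiss.IndexFlatIP(dim)      # inner-product index, i.e. METRIC_INNER_PRODUCT
    index.add(doc_vecs)

    faiss.normalize_L2(query)           # mirrors the query_emb normalization above
    scores, ids = index.search(query, 3)
    # scores[0] now contains true cosine similarities in [-1, 1]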
@@ -426,7 +441,10 @@ class FAISSDocumentStore(SQLDocumentStore):
         scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
         for doc in documents:
             raw_score = scores_for_vector_ids[doc.meta["vector_id"]]
-            doc.score = float(expit(np.asarray(raw_score / 100)))
+            if self.similarity == 'cosine':
+                doc.score = (raw_score + 1) / 2
+            else:
+                doc.score = float(expit(np.asarray(raw_score / 100)))
             if return_embedding is True:
                 doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
 
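Worked numbers for the two score normalizations above: a raw cosine score of 0.9746 maps to (0.9746 + 1) / 2 ≈ 0.9873, while a raw dot-product score of 100 maps to expit(100 / 100) = expit(1) ≈ 0.7311. A quick check, computed by hand and not part of the commit:

    import numpy as np
    from scipy.special import expit

    raw_cosine = 0.9746
    assert abs((raw_cosine + 1) / 2 - 0.9873) < 1e-4                       # cosine branch: [-1, 1] -> [0, 1]

    raw_dot = 100.0
    assert abs(float(expit(np.asarray(raw_dot / 100))) - 0.7311) < 1e-4    # dot_product branch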
@@ -1,4 +1,5 @@
 import faiss
+import math
 import numpy as np
 import pytest
 from haystack import Document
@@ -67,7 +68,6 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
     # compare original input vec with stored one (ignore extra dim added by hnsw)
     assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)
 
-
 @pytest.mark.slow
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
@@ -205,3 +205,71 @@ def test_faiss_passing_index_from_outside(tmp_path):
         assert 0 <= int(doc.meta["vector_id"]) <= 7
 
 
+def test_faiss_cosine_similarity(tmp_path):
+    document_store = FAISSDocumentStore(
+        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine'
+    )
+
+    # below we will write documents to the store and then query it to see if vectors were normalized
+
+    document_store.write_documents(documents=DOCUMENTS)
+
+    # note that the same query will be used later when querying after updating the embeddings
+    query = np.random.rand(768).astype(np.float32)
+
+    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)
+
+    # check if search with cosine similarity returns the correct number of results
+    assert len(query_results) == len(DOCUMENTS)
+    indexed_docs = {}
+    for doc in DOCUMENTS:
+        indexed_docs[doc["text"]] = doc["embedding"]
+
+    for doc in query_results:
+        result_emb = doc.embedding
+        original_emb = np.array([indexed_docs[doc.text]], dtype="float32")
+        faiss.normalize_L2(original_emb)
+
+        # check if the stored embedding was normalized
+        assert np.allclose(original_emb[0], result_emb, rtol=0.01)
+
+        # check if the score is plausible for cosine similarity
+        assert 0 <= doc.score <= 1.0
+
+    # now check if vectors are normalized when updating embeddings
+    class MockRetriever():
+        def embed_passages(self, docs):
+            return [np.random.rand(768).astype(np.float32) for doc in docs]
+
+    retriever = MockRetriever()
+    document_store.update_embeddings(retriever=retriever)
+    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)
+
+    for doc in query_results:
+        original_emb = np.array([indexed_docs[doc.text]], dtype="float32")
+        faiss.normalize_L2(original_emb)
+        # check if the original embedding has changed after updating the embeddings
+        assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
+
+
+def test_faiss_cosine_sanity_check(tmp_path):
+    document_store = FAISSDocumentStore(
+        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine',
+        vector_dim=3
+    )
+
+    VEC_1 = np.array([.1, .2, .3], dtype="float32")
+    VEC_2 = np.array([.4, .5, .6], dtype="float32")
+
+    # This is the cosine similarity of VEC_1 and VEC_2 calculated using sklearn.metrics.pairwise.cosine_similarity
+    # The score is normalized to yield a value between 0 and 1.
+    KNOWN_COSINE = (0.9746317 + 1) / 2
+
+    docs = [{"name": "vec_1", "text": "vec_1", "embedding": VEC_1}]
+    document_store.write_documents(documents=docs)
+
+    query_results = document_store.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)
+
+    # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
+    assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.000001)
+
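The constant KNOWN_COSINE in the sanity check can be re-derived with plain numpy, which agrees with the sklearn value quoted in the test. A verification sketch, not part of the commit:

    import numpy as np

    v1 = np.array([.1, .2, .3], dtype="float32")
    v2 = np.array([.4, .5, .6], dtype="float32")

    cos = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
    assert abs(cos - 0.9746317) < 1e-5   # matches the sklearn-derived constant
    # after Haystack's rescaling: (cos + 1) / 2 ≈ 0.9873158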