Enable cosine similarity metric in FAISSDocumentStore (#1352)

* feat: normalize embeddings for cosine sim

* WIP add test case for faiss cosine

* input to faiss normalize needs to be an array of vectors

* fix: test should compare correct result embedding to original embedding

* add sanity check for cosine sim

* fix typo

* normalize cosine score

* Update docstring

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
mathislucka 2021-09-20 07:54:26 +02:00 committed by GitHub
parent 5b1b875374
commit 9c4e67d9b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 7 deletions

View File

@ -67,8 +67,11 @@ class FAISSDocumentStore(SQLDocumentStore):
or one with docs that you used in Haystack before and want to load again.
:param return_embedding: To return document embedding
:param index: Name of index in document store to use.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default sine it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
For `dot_product`: expit(np.asarray(raw_score / 100))
For `cosine`: (raw_score + 1) / 2
:param embedding_field: Name of field containing an embedding vector.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
@ -88,15 +91,15 @@ class FAISSDocumentStore(SQLDocumentStore):
embedding_field=embedding_field, progress_bar=progress_bar
)
if similarity == "dot_product":
if similarity == "dot_product" or similarity == 'cosine':
self.similarity = similarity
self.metric_type = faiss.METRIC_INNER_PRODUCT
elif similarity == "l2":
self.similarity = similarity
self.metric_type = faiss.METRIC_L2
else:
raise ValueError("The FAISS document store can currently only support dot_product similarity. "
"Please set similarity=\"dot_product\"")
raise ValueError("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
"Please set similarity to one of the above.")
self.vector_dim = vector_dim
self.faiss_index_factory_str = faiss_index_factory_str
@ -184,6 +187,10 @@ class FAISSDocumentStore(SQLDocumentStore):
if add_vectors:
embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
embeddings_to_index = np.array(embeddings, dtype="float32")
if self.similarity == 'cosine':
faiss.normalize_L2(embeddings_to_index)
self.faiss_indexes[index].add(embeddings_to_index)
docs_to_write_in_sql = []
@ -261,6 +268,10 @@ class FAISSDocumentStore(SQLDocumentStore):
assert len(document_batch) == len(embeddings)
embeddings_to_index = np.array(embeddings, dtype="float32")
if self.similarity == 'cosine':
faiss.normalize_L2(embeddings_to_index)
self.faiss_indexes[index].add(embeddings_to_index)
vector_id_map = {}
@ -417,6 +428,10 @@ class FAISSDocumentStore(SQLDocumentStore):
return_embedding = self.return_embedding
query_emb = query_emb.reshape(1, -1).astype(np.float32)
if self.similarity == 'cosine':
faiss.normalize_L2(query_emb)
score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
@ -426,7 +441,10 @@ class FAISSDocumentStore(SQLDocumentStore):
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
for doc in documents:
raw_score = scores_for_vector_ids[doc.meta["vector_id"]]
doc.score = float(expit(np.asarray(raw_score / 100)))
if self.similarity == 'cosine':
doc.score = (raw_score + 1) / 2
else:
doc.score = float(expit(np.asarray(raw_score / 100)))
if return_embedding is True:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))

View File

@ -1,4 +1,5 @@
import faiss
import math
import numpy as np
import pytest
from haystack import Document
@ -67,7 +68,6 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
# compare original input vec with stored one (ignore extra dim added by hnsw)
assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
@ -205,3 +205,71 @@ def test_faiss_passing_index_from_outside(tmp_path):
assert 0 <= int(doc.meta["vector_id"]) <= 7
def test_faiss_cosine_similarity(tmp_path):
    """Verify that a FAISSDocumentStore created with ``similarity='cosine'`` stores L2-normalized vectors.

    The test covers both code paths that add vectors to the index:

    1. ``write_documents``: every embedding returned by a query must equal the
       L2-normalized version of the embedding that was written, and every score
       must lie in ``[0, 1]``.
    2. ``update_embeddings``: after re-embedding with a mock retriever that
       returns random vectors, the stored embeddings must differ from the
       originals *and* must have unit length.

    :param tmp_path: pytest fixture providing a temporary directory for the SQLite file.
    """
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine'
    )

    # below we will write documents to the store and then query it to see if vectors were normalized
    document_store.write_documents(documents=DOCUMENTS)

    # note that the same query will be used later when querying after updating the embeddings
    # (768 dimensions; assumes this matches the store's default vector_dim and the DOCUMENTS embeddings)
    query = np.random.rand(768).astype(np.float32)

    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    # check if search with cosine similarity returns the correct number of results
    assert len(query_results) == len(DOCUMENTS)

    # map each document text to the raw (un-normalized) embedding that was written
    indexed_docs = {doc["text"]: doc["embedding"] for doc in DOCUMENTS}

    for doc in query_results:
        result_emb = doc.embedding
        # faiss.normalize_L2 works in place on a 2D array of vectors, hence the extra list wrapper
        original_emb = np.array([indexed_docs[doc.text]], dtype="float32")
        faiss.normalize_L2(original_emb)

        # check if the stored embedding was normalized
        assert np.allclose(original_emb[0], result_emb, rtol=0.01)
        # check if the score is plausible for cosine similarity
        assert 0 <= doc.score <= 1.0

    # now check if vectors are normalized when updating embeddings
    class MockRetriever():
        def embed_passages(self, docs):
            # one fresh random vector per passage; the passage content is irrelevant here
            return [np.random.rand(768).astype(np.float32) for _ in docs]

    retriever = MockRetriever()
    document_store.update_embeddings(retriever=retriever)
    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    for doc in query_results:
        original_emb = np.array([indexed_docs[doc.text]], dtype="float32")
        faiss.normalize_L2(original_emb)
        # check if the original embedding has changed after updating the embeddings
        assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
        # check that the updated embedding was actually normalized: a "changed" vector alone
        # would also pass if update_embeddings skipped normalization entirely
        assert np.isclose(np.linalg.norm(doc.embedding), 1.0, atol=1e-3)
def test_faiss_cosine_sanity_check(tmp_path):
    """Compare the cosine score returned by the store against a known reference value.

    A single 3-dimensional vector is indexed and queried with a second vector;
    the score of the only hit must match the independently computed cosine
    similarity of the two vectors after rescaling to the range 0..1.

    :param tmp_path: pytest fixture providing a temporary directory for the SQLite file.
    """
    # Cosine similarity of the two vectors below, calculated using
    # sklearn.metrics.pairwise.cosine_similarity, then normalized to yield a value between 0 and 1.
    expected_score = (0.9746317 + 1) / 2

    stored_vec = np.array([.1, .2, .3], dtype="float32")
    query_vec = np.array([.4, .5, .6], dtype="float32")

    store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine',
        vector_dim=3
    )
    store.write_documents(documents=[{"name": "vec_1", "text": "vec_1", "embedding": stored_vec}])

    hits = store.query_by_embedding(query_emb=query_vec, top_k=1, return_embedding=True)
    top_score = hits[0].score

    # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
    assert math.isclose(top_score, expected_score, abs_tol=0.000001)