diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py index 6fcd0d23b..aad082c88 100644 --- a/haystack/document_store/faiss.py +++ b/haystack/document_store/faiss.py @@ -67,8 +67,11 @@ class FAISSDocumentStore(SQLDocumentStore): or one with docs that you used in Haystack before and want to load again. :param return_embedding: To return document embedding :param index: Name of index in document store to use. - :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default sine it is - more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model. + :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is + more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model. + In both cases, the returned values in Document.score are normalized to be in range [0,1]: + For `dot_product`: expit(np.asarray(raw_score / 100)) + FOr `cosine`: (raw_score + 1) / 2 :param embedding_field: Name of field containing an embedding vector. :param progress_bar: Whether to show a tqdm progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. @@ -88,15 +91,15 @@ class FAISSDocumentStore(SQLDocumentStore): embedding_field=embedding_field, progress_bar=progress_bar ) - if similarity == "dot_product": + if similarity == "dot_product" or similarity == 'cosine': self.similarity = similarity self.metric_type = faiss.METRIC_INNER_PRODUCT elif similarity == "l2": self.similarity = similarity self.metric_type = faiss.METRIC_L2 else: - raise ValueError("The FAISS document store can currently only support dot_product similarity. " - "Please set similarity=\"dot_product\"") + raise ValueError("The FAISS document store can currently only support dot_product, cosine and l2 similarity. " + "Please set similarity to one of the above.") self.vector_dim = vector_dim self.faiss_index_factory_str = faiss_index_factory_str @@ -184,6 +187,10 @@ class FAISSDocumentStore(SQLDocumentStore): if add_vectors: embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]] embeddings_to_index = np.array(embeddings, dtype="float32") + + if self.similarity == 'cosine': + faiss.normalize_L2(embeddings_to_index) + self.faiss_indexes[index].add(embeddings_to_index) docs_to_write_in_sql = [] @@ -261,6 +268,10 @@ class FAISSDocumentStore(SQLDocumentStore): assert len(document_batch) == len(embeddings) embeddings_to_index = np.array(embeddings, dtype="float32") + + if self.similarity == 'cosine': + faiss.normalize_L2(embeddings_to_index) + self.faiss_indexes[index].add(embeddings_to_index) vector_id_map = {} @@ -417,6 +428,10 @@ class FAISSDocumentStore(SQLDocumentStore): return_embedding = self.return_embedding query_emb = query_emb.reshape(1, -1).astype(np.float32) + + if self.similarity == 'cosine': + faiss.normalize_L2(query_emb) + score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k) vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1] @@ -426,7 +441,10 @@ class FAISSDocumentStore(SQLDocumentStore): scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])} for doc in documents: raw_score = scores_for_vector_ids[doc.meta["vector_id"]] - doc.score = float(expit(np.asarray(raw_score / 100))) + if self.similarity == 'cosine': + doc.score = (raw_score + 1) / 2 + else: + doc.score = float(expit(np.asarray(raw_score / 100))) if return_embedding is True: doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"])) diff --git a/test/test_faiss_and_milvus.py b/test/test_faiss_and_milvus.py index 0b197b868..20c71bb67 100644 --- a/test/test_faiss_and_milvus.py +++ b/test/test_faiss_and_milvus.py @@ -1,4 +1,5 @@ import faiss +import math import numpy as np import pytest from haystack import Document @@ -67,7 +68,6 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size): # compare original input vec with stored one (ignore extra dim added by hnsw) assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01) - @pytest.mark.slow @pytest.mark.parametrize("retriever", ["dpr"], indirect=True) @pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) @@ -205,3 +205,71 @@ def test_faiss_passing_index_from_outside(tmp_path): assert 0 <= int(doc.meta["vector_id"]) <= 7 +def test_faiss_cosine_similarity(tmp_path): + document_store = FAISSDocumentStore( + sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine' + ) + + # below we will write documents to the store and then query it to see if vectors were normalized + + document_store.write_documents(documents=DOCUMENTS) + + # note that the same query will be used later when querying after updating the embeddings + query = np.random.rand(768).astype(np.float32) + + query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True) + + # check if search with cosine similarity returns the correct number of results + assert len(query_results) == len(DOCUMENTS) + indexed_docs = {} + for doc in DOCUMENTS: + indexed_docs[doc["text"]] = doc["embedding"] + + for doc in query_results: + result_emb = doc.embedding + original_emb = np.array([indexed_docs[doc.text]], dtype="float32") + faiss.normalize_L2(original_emb) + + # check if the stored embedding was normalized + assert np.allclose(original_emb[0], result_emb, rtol=0.01) + + # check if the score is plausible for cosine similarity + assert 0 <= doc.score <= 1.0 + + # now check if vectors are normalized when updating embeddings + class MockRetriever(): + def embed_passages(self, docs): + return [np.random.rand(768).astype(np.float32) for doc in docs] + + retriever = MockRetriever() + document_store.update_embeddings(retriever=retriever) + query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True) + + for doc in query_results: + original_emb = np.array([indexed_docs[doc.text]], dtype="float32") + faiss.normalize_L2(original_emb) + # check if the original embedding has changed after updating the embeddings + assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01) + + + +def test_faiss_cosine_sanity_check(tmp_path): + document_store = FAISSDocumentStore( + sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine', + vector_dim=3 + ) + + VEC_1 = np.array([.1, .2, .3], dtype="float32") + VEC_2 = np.array([.4, .5, .6], dtype="float32") + + # This is the cosine similarity of VEC_1 and VEC_2 calculated using sklearn.metrics.pairwise.cosine_similarity + # The score is normalized to yield a value between 0 and 1. + KNOWN_COSINE = (0.9746317 + 1) / 2 + + docs = [{"name": "vec_1", "text": "vec_1", "embedding": VEC_1}] + document_store.write_documents(documents=docs) + + query_results = document_store.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True) + + # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318 + assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.000001)