From 465ccbc12e4a04e679ec8fca59db16c0499f95ce Mon Sep 17 00:00:00 2001
From: Lalit Pagaria
Date: Mon, 5 Oct 2020 12:01:20 +0200
Subject: [PATCH] Allow multiple write calls to existing FAISS index. (#422)

- Fixes an issue where update_embeddings always created a new FAISS index
  instead of clearing the existing one. Creating a new index may not free the
  memory held by the old index and can cause a memory leak.

Co-authored-by: Malte Pietsch
---
 haystack/document_store/faiss.py | 23 ++++-----
 test/test_faiss.py               | 88 +++++++++++++++++++++-----------
 2 files changed, 68 insertions(+), 43 deletions(-)

diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py
index d1e4b5c20..ac8e90fb9 100644
--- a/haystack/document_store/faiss.py
+++ b/haystack/document_store/faiss.py
@@ -50,10 +50,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         return index
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
-        if self.faiss_index is not None:
-            raise Exception("Addition of more data in an existing index is not supported.")
 
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
 
         index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
@@ -63,12 +61,12 @@ class FAISSDocumentStore(SQLDocumentStore):
             phi = self._get_phi(document_objects)
 
         for i in range(0, len(document_objects), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             if add_vectors:
                 embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
                 hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
                 hnsw_vectors = hnsw_vectors.astype(np.float32)
-                faiss_index.add(hnsw_vectors)
+                self.faiss_index.add(hnsw_vectors)
 
             docs_to_write_in_sql = []
             for doc in document_objects[i : i + self.index_buffer_size]:
@@ -80,8 +78,6 @@ class FAISSDocumentStore(SQLDocumentStore):
 
             super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)
 
-        self.faiss_index = faiss_index
-
     def _get_hnsw_vectors(self, embeddings: List[np.array], phi: int) -> np.array:
         """
         HNSW indices in FAISS only support L2 distance. This transformation adds an additional dimension to obtain
@@ -113,8 +109,10 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param index: Index name to update
         :return: None
         """
-        # Some FAISS indexes(like the default HNSWx) do not support removing vectors, so a new index is created.
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
+        # Clear the FAISS index contents to immediately free all memory the index is using
+        self.faiss_index.reset()
+
         index = index or self.index
         documents = self.get_all_documents(index=index)
 
@@ -128,18 +126,17 @@ class FAISSDocumentStore(SQLDocumentStore):
 
         vector_id_map = {}
         for i in range(0, len(documents), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
             hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
             hnsw_vectors = hnsw_vectors.astype(np.float32)
-            faiss_index.add(hnsw_vectors)
+            self.faiss_index.add(hnsw_vectors)
 
             for doc in documents[i: i + self.index_buffer_size]:
                 vector_id_map[doc.id] = vector_id
                 vector_id += 1
-
+
             self.update_vector_ids(vector_id_map, index=index)
-        self.faiss_index = faiss_index
 
     def query_by_embedding(
         self, query_emb: np.array, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None
diff --git a/test/test_faiss.py b/test/test_faiss.py
index 8771656d5..4aaaf1101 100644
--- a/test/test_faiss.py
+++ b/test/test_faiss.py
@@ -10,19 +10,59 @@ DOCUMENTS = [
     {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+    {"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_5", "text": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_6", "text": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
 ]
 
 
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
-def test_faiss_write_docs(document_store, index_buffer_size):
-    document_store.index_buffer_size = index_buffer_size
-    document_store.write_documents(DOCUMENTS)
-    documents_indexed = document_store.get_all_documents()
-
+def check_data_correctness(documents_indexed, documents_inserted):
     # test if correct vector_ids are assigned
     for i, doc in enumerate(documents_indexed):
         assert doc.meta["vector_id"] == str(i)
 
+    # test if number of documents is correct
+    assert len(documents_indexed) == len(documents_inserted)
+
+    # test that no two docs have the same vector_id assigned
+    vector_ids = set()
+    for i, doc in enumerate(documents_indexed):
+        vector_ids.add(doc.meta["vector_id"])
+    assert len(vector_ids) == len(documents_inserted)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_index_save_and_load(document_store):
+    document_store.write_documents(DOCUMENTS)
+
+    # test saving the index
+    document_store.save("haystack_test_faiss")
+
+    # clear existing faiss_index
+    document_store.faiss_index.reset()
+
+    # test faiss index is cleared
+    assert document_store.faiss_index.ntotal == 0
+
+    # test loading the index
+    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+
+    # check faiss index is restored
+    assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
+@pytest.mark.parametrize("batch_size", [2])
+def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
+    document_store.index_buffer_size = index_buffer_size
+
+    # Write in small batches
+    for i in range(0, len(DOCUMENTS), batch_size):
+        document_store.write_documents(DOCUMENTS[i: i + batch_size])
+
+    documents_indexed = document_store.get_all_documents()
+
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         # we currently don't get the embeddings back when we call document_store.get_all_documents()
@@ -31,20 +71,13 @@ def test_faiss_write_docs(document_store, index_buffer_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(original_doc["embedding"], stored_emb[:-1], rtol=0.01)
 
-    # test insertion of documents in an existing index fails
-    with pytest.raises(Exception):
-        document_store.write_documents(DOCUMENTS)
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
 
-    # test saving the index
-    document_store.save("haystack_test_faiss")
-
-    # test loading the index
-    document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
 def test_faiss_update_docs(document_store, index_buffer_size):
-
     # adjust buffer size
     document_store.index_buffer_size = index_buffer_size
 
@@ -61,15 +94,6 @@
     document_store.update_embeddings(retriever=retriever)
     documents_indexed = document_store.get_all_documents()
 
-    # test if number of documents is correct
-    assert len(documents_indexed) == len(DOCUMENTS)
-
-    # test if two docs have same vector_is assigned
-    vector_ids = set()
-    for i, doc in enumerate(documents_indexed):
-        vector_ids.add(doc.meta["vector_id"])
-    assert len(vector_ids) == len(DOCUMENTS)
-
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
@@ -78,25 +102,29 @@
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(updated_embedding, stored_emb[:-1], rtol=0.01)
 
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
+
+
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_retrieving(document_store):
-
     document_store.write_documents(DOCUMENTS)
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     result = retriever.retrieve(query="How to test this?")
-    assert len(result) == 3
+    assert len(result) == len(DOCUMENTS)
     assert type(result[0]) == Document
 
+
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_finding(document_store):
-
     document_store.write_documents(DOCUMENTS)
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     finder = Finder(reader=None, retriever=retriever)
 
     prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
 
     assert len(prediction.get('answers', [])) == 1
-
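
Note (not part of the patch): a minimal usage sketch of the behaviour this change enables. write_documents() can now be called repeatedly against the same FAISSDocumentStore, and update_embeddings() resets the existing FAISS index in place instead of building a new one. The API calls (write_documents, update_embeddings, save, load, faiss_index.ntotal) are the ones exercised in the tests above; the import paths, SQL URL, and document contents are assumptions modelled on the test file, not verified against the repository.

import numpy as np

from haystack.document_store.faiss import FAISSDocumentStore   # import path assumed from the diff's file layout
from haystack.retriever.dense import EmbeddingRetriever        # import path assumed

# Store backed by SQLite for document data and FAISS for vectors (768-dim embeddings, as in the tests)
document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db")

# Repeated write calls now append to the same FAISS index instead of raising an Exception
first_batch = [{"name": "doc_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)}]
second_batch = [{"name": "doc_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)}]
document_store.write_documents(first_batch)
document_store.write_documents(second_batch)
assert document_store.faiss_index.ntotal == 2

# update_embeddings() now reset()s the existing index (freeing its memory) before re-adding vectors
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
                               use_gpu=False)
document_store.update_embeddings(retriever=retriever)

# Save the FAISS index to disk and load it back next to the SQL data, as in test_faiss_index_save_and_load
document_store.save("haystack_test_faiss")
reloaded_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
assert reloaded_store.faiss_index.ntotal == document_store.faiss_index.ntotal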