Allow multiple write calls to existing FAISS index. (#422)

- Fixes an issue where update_embeddings always created a new FAISS index instead of clearing the existing one. Creating a new index may not free the memory still held by the old one, which can lead to a memory leak.
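A minimal sketch of the FAISS behaviour behind the fix (assuming faiss is installed; the dimension, HNSW parameter, and vector count below are illustrative, not the document store's exact defaults):

    import numpy as np
    import faiss

    dim = 769                              # illustrative: vector_size plus the extra HNSW dimension
    index = faiss.IndexHNSWFlat(dim, 128)  # an HNSW index, like the document store's default
    index.add(np.random.rand(1_000, dim).astype(np.float32))
    assert index.ntotal == 1_000

    # Before this patch, update_embeddings dropped the old index and assigned a brand-new one:
    #     index = faiss.IndexHNSWFlat(dim, 128)
    # The memory held by the old index is only released once Python garbage-collects it,
    # so repeated re-indexing could keep stale indexes alive and look like a memory leak.

    # After this patch, the existing index is cleared in place instead:
    index.reset()                          # removes all stored vectors, keeps the index usable
    assert index.ntotal == 0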

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Lalit Pagaria 2020-10-05 12:01:20 +02:00 committed by GitHub
parent 072e32b38a
commit 465ccbc12e
2 changed files with 68 additions and 43 deletions


@@ -50,10 +50,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         return index

     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
-        if self.faiss_index is not None:
-            raise Exception("Addition of more data in an existing index is not supported.")
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
         index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
@@ -63,12 +61,12 @@ class FAISSDocumentStore(SQLDocumentStore):
         phi = self._get_phi(document_objects)

         for i in range(0, len(document_objects), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             if add_vectors:
                 embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
                 hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
                 hnsw_vectors = hnsw_vectors.astype(np.float32)
-                faiss_index.add(hnsw_vectors)
+                self.faiss_index.add(hnsw_vectors)

             docs_to_write_in_sql = []
             for doc in document_objects[i : i + self.index_buffer_size]:
@@ -80,8 +78,6 @@ class FAISSDocumentStore(SQLDocumentStore):
             super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)

-        self.faiss_index = faiss_index
-
     def _get_hnsw_vectors(self, embeddings: List[np.array], phi: int) -> np.array:
         """
         HNSW indices in FAISS only support L2 distance. This transformation adds an additional dimension to obtain
@@ -113,8 +109,10 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param index: Index name to update
         :return: None
         """
-        # Some FAISS indexes (like the default HNSWx) do not support removing vectors, so a new index is created.
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
+        # Clear out the FAISS index contents to immediately free all memory currently used by the index
+        self.faiss_index.reset()
         index = index or self.index
         documents = self.get_all_documents(index=index)
@@ -128,18 +126,17 @@ class FAISSDocumentStore(SQLDocumentStore):
         vector_id_map = {}
         for i in range(0, len(documents), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
             hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
             hnsw_vectors = hnsw_vectors.astype(np.float32)
-            faiss_index.add(hnsw_vectors)
+            self.faiss_index.add(hnsw_vectors)

             for doc in documents[i: i + self.index_buffer_size]:
                 vector_id_map[doc.id] = vector_id
                 vector_id += 1
             self.update_vector_ids(vector_id_map, index=index)

-        self.faiss_index = faiss_index
-
     def query_by_embedding(
         self, query_emb: np.array, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None
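With the write_documents and update_embeddings changes above, repeated write calls append to one shared FAISS index instead of raising. A hedged usage sketch (the import path and constructor arguments reflect Haystack around this commit and are assumptions, not part of the patch):

    from haystack.document_store.faiss import FAISSDocumentStore

    document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db")

    batch_1 = [{"name": "doc_1", "text": "text_1"}]
    batch_2 = [{"name": "doc_2", "text": "text_2"}]

    # The first call lazily creates the FAISS index; the second call now appends to it.
    # Before this patch the second call raised
    # "Addition of more data in an existing index is not supported."
    document_store.write_documents(batch_1)
    document_store.write_documents(batch_2)

    assert len(document_store.get_all_documents()) == 2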


@@ -10,19 +10,59 @@ DOCUMENTS = [
     {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+    {"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_5", "text": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_6", "text": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
 ]


-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
-def test_faiss_write_docs(document_store, index_buffer_size):
-    document_store.index_buffer_size = index_buffer_size
-    document_store.write_documents(DOCUMENTS)
-    documents_indexed = document_store.get_all_documents()
-
+def check_data_correctness(documents_indexed, documents_inserted):
     # test if correct vector_ids are assigned
     for i, doc in enumerate(documents_indexed):
         assert doc.meta["vector_id"] == str(i)

+    # test if number of documents is correct
+    assert len(documents_indexed) == len(documents_inserted)
+
+    # test that no two docs have the same vector_id assigned
+    vector_ids = set()
+    for i, doc in enumerate(documents_indexed):
+        vector_ids.add(doc.meta["vector_id"])
+    assert len(vector_ids) == len(documents_inserted)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_index_save_and_load(document_store):
+    document_store.write_documents(DOCUMENTS)
+
+    # test saving the index
+    document_store.save("haystack_test_faiss")
+
+    # clear existing faiss_index
+    document_store.faiss_index.reset()
+
+    # test faiss index is cleared
+    assert document_store.faiss_index.ntotal == 0
+
+    # test loading the index
+    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+
+    # check faiss index is restored
+    assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
+@pytest.mark.parametrize("batch_size", [2])
+def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
+    document_store.index_buffer_size = index_buffer_size
+
+    # write in small batches
+    for i in range(0, len(DOCUMENTS), batch_size):
+        document_store.write_documents(DOCUMENTS[i: i + batch_size])
+
+    documents_indexed = document_store.get_all_documents()
+
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         # we currently don't get the embeddings back when we call document_store.get_all_documents()
@@ -31,20 +71,13 @@ def test_faiss_write_docs(document_store, index_buffer_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(original_doc["embedding"], stored_emb[:-1], rtol=0.01)

-    # test insertion of documents in an existing index fails
-    with pytest.raises(Exception):
-        document_store.write_documents(DOCUMENTS)
-
-    # test saving the index
-    document_store.save("haystack_test_faiss")
-
-    # test loading the index
-    document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)


 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
 def test_faiss_update_docs(document_store, index_buffer_size):
     # adjust buffer size
     document_store.index_buffer_size = index_buffer_size
@@ -61,15 +94,6 @@ def test_faiss_update_docs(document_store, index_buffer_size):
     document_store.update_embeddings(retriever=retriever)
     documents_indexed = document_store.get_all_documents()

-    # test if number of documents is correct
-    assert len(documents_indexed) == len(DOCUMENTS)
-
-    # test if two docs have same vector_is assigned
-    vector_ids = set()
-    for i, doc in enumerate(documents_indexed):
-        vector_ids.add(doc.meta["vector_id"])
-    assert len(vector_ids) == len(DOCUMENTS)
-
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
@@ -78,25 +102,29 @@
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(updated_embedding, stored_emb[:-1], rtol=0.01)

+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
+

 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_retrieving(document_store):
     document_store.write_documents(DOCUMENTS)
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     result = retriever.retrieve(query="How to test this?")
-    assert len(result) == 3
+    assert len(result) == len(DOCUMENTS)
     assert type(result[0]) == Document


 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_finding(document_store):
     document_store.write_documents(DOCUMENTS)
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     finder = Finder(reader=None, retriever=retriever)
     prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
     assert len(prediction.get('answers', [])) == 1
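The test_faiss_update_docs changes above exercise the second half of the fix: update_embeddings() can now be called repeatedly against the same store. A minimal sketch of that flow, under the same assumptions as the earlier example (import paths, the sqlite URL, and the embedding model are illustrative):

    from haystack.document_store.faiss import FAISSDocumentStore
    from haystack.retriever.dense import EmbeddingRetriever

    document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db")
    document_store.write_documents([{"name": "doc_1", "text": "text_1"},
                                    {"name": "doc_2", "text": "text_2"}])

    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
                                   use_gpu=False)

    # Each call reuses the existing FAISS index: it is reset() in place rather than
    # replaced by a freshly allocated index, so re-indexing the same corpus no longer
    # leaves abandoned indexes behind in memory.
    document_store.update_embeddings(retriever=retriever)
    document_store.update_embeddings(retriever=retriever)

    assert document_store.faiss_index.ntotal == 2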