Allow multiple write calls to existing FAISS index. (#422)

- Fixes an issue where update_embeddings always created a new FAISS index instead of clearing the existing one. Creating a new index does not immediately free the memory still held by the old one, which can cause a memory leak (see the sketch after the change summary below).

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Lalit Pagaria, 2020-10-05 12:01:20 +02:00 (committed by GitHub)
parent 072e32b38a
commit 465ccbc12e
2 changed files with 68 additions and 43 deletions
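
Why reset() instead of a fresh index: FAISS HNSW indexes do not support removing individual vectors, so the old code built a brand-new index on every update_embeddings call. The filled old index is only freed once Python's garbage collector releases the underlying C++ object, so memory usage can grow as if it were leaking. Below is a minimal sketch of the difference, independent of Haystack (the HNSW parameter 128 is an assumed value, not taken from this diff):

import numpy as np
import faiss

d = 768
index = faiss.IndexHNSWFlat(d, 128)  # HNSW indexes do not implement remove_ids()
index.add(np.random.rand(1_000, d).astype(np.float32))

# Old behavior: rebinding the name leaves the filled index alive until the
# garbage collector frees the underlying C++ object.
# index = faiss.IndexHNSWFlat(d, 128)

# New behavior: reset() empties the same index object and frees its memory
# immediately; it can then be refilled in place.
index.reset()
assert index.ntotal == 0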


@@ -50,10 +50,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         return index
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
-        if self.faiss_index is not None:
-            raise Exception("Addition of more data in an existing index is not supported.")
-
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
 
         index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
@@ -63,12 +61,12 @@ class FAISSDocumentStore(SQLDocumentStore):
         phi = self._get_phi(document_objects)
 
         for i in range(0, len(document_objects), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             if add_vectors:
                 embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
                 hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
                 hnsw_vectors = hnsw_vectors.astype(np.float32)
-                faiss_index.add(hnsw_vectors)
+                self.faiss_index.add(hnsw_vectors)
 
             docs_to_write_in_sql = []
             for doc in document_objects[i : i + self.index_buffer_size]:
@@ -80,8 +78,6 @@ class FAISSDocumentStore(SQLDocumentStore):
             super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)
 
-        self.faiss_index = faiss_index
-
     def _get_hnsw_vectors(self, embeddings: List[np.array], phi: int) -> np.array:
         """
         HNSW indices in FAISS only support L2 distance. This transformation adds an additional dimension to obtain
@@ -113,8 +109,10 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param index: Index name to update
         :return: None
        """
-        # Some FAISS indexes (like the default HNSWx) do not support removing vectors, so a new index is created.
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
+
+        # Clear out the FAISS index contents and immediately free all memory in use by the index
+        self.faiss_index.reset()
 
         index = index or self.index
         documents = self.get_all_documents(index=index)
@@ -128,18 +126,17 @@ class FAISSDocumentStore(SQLDocumentStore):
         vector_id_map = {}
         for i in range(0, len(documents), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
             hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
             hnsw_vectors = hnsw_vectors.astype(np.float32)
-            faiss_index.add(hnsw_vectors)
+            self.faiss_index.add(hnsw_vectors)
 
             for doc in documents[i: i + self.index_buffer_size]:
                 vector_id_map[doc.id] = vector_id
                 vector_id += 1
         self.update_vector_ids(vector_id_map, index=index)
-        self.faiss_index = faiss_index
 
     def query_by_embedding(
         self, query_emb: np.array, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None

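With write_documents now reusing self.faiss_index, repeated calls append to the same index instead of raising. A hedged usage sketch (the import path and constructor arguments are assumptions about this revision, not part of the diff):

from haystack.document_store.faiss import FAISSDocumentStore  # import path assumed

document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db")

# Before this commit, the second call raised:
#   Exception("Addition of more data in an existing index is not supported.")
document_store.write_documents([{"name": "doc_1", "text": "text_1"}])
document_store.write_documents([{"name": "doc_2", "text": "text_2"}])  # now appends

The second file of the commit rewrites the FAISS tests around this behavior: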

@@ -10,19 +10,59 @@ DOCUMENTS = [
     {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+    {"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_5", "text": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_6", "text": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
 ]
 
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
-def test_faiss_write_docs(document_store, index_buffer_size):
-    document_store.index_buffer_size = index_buffer_size
-    document_store.write_documents(DOCUMENTS)
-    documents_indexed = document_store.get_all_documents()
+def check_data_correctness(documents_indexed, documents_inserted):
+    # test if correct vector_ids are assigned
+    for i, doc in enumerate(documents_indexed):
+        assert doc.meta["vector_id"] == str(i)
+
+    # test if the number of documents is correct
+    assert len(documents_indexed) == len(documents_inserted)
+
+    # test that no two docs have the same vector_id assigned
+    vector_ids = set()
+    for i, doc in enumerate(documents_indexed):
+        vector_ids.add(doc.meta["vector_id"])
+    assert len(vector_ids) == len(documents_inserted)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_index_save_and_load(document_store):
+    document_store.write_documents(DOCUMENTS)
+
+    # test saving the index
+    document_store.save("haystack_test_faiss")
+
+    # clear the existing faiss_index
+    document_store.faiss_index.reset()
+
+    # test that the faiss index is cleared
+    assert document_store.faiss_index.ntotal == 0
+
+    # test loading the index
+    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+
+    # check that the faiss index is restored
+    assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
+@pytest.mark.parametrize("batch_size", [2])
+def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
+    document_store.index_buffer_size = index_buffer_size
+
+    # write in small batches to exercise appending to an existing index
+    for i in range(0, len(DOCUMENTS), batch_size):
+        document_store.write_documents(DOCUMENTS[i: i + batch_size])
+
+    documents_indexed = document_store.get_all_documents()
+
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         # we currently don't get the embeddings back when we call document_store.get_all_documents()
@@ -31,20 +71,13 @@ def test_faiss_write_docs(document_store, index_buffer_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(original_doc["embedding"], stored_emb[:-1], rtol=0.01)
 
-    # test insertion of documents in an existing index fails
-    with pytest.raises(Exception):
-        document_store.write_documents(DOCUMENTS)
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
 
-    # test saving the index
-    document_store.save("haystack_test_faiss")
-
-    # test loading the index
-    document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
 def test_faiss_update_docs(document_store, index_buffer_size):
     # adjust buffer size
     document_store.index_buffer_size = index_buffer_size
@@ -61,15 +94,6 @@ def test_faiss_update_docs(document_store, index_buffer_size):
     document_store.update_embeddings(retriever=retriever)
     documents_indexed = document_store.get_all_documents()
 
-    # test if number of documents is correct
-    assert len(documents_indexed) == len(DOCUMENTS)
-
-    # test if two docs have same vector_is assigned
-    vector_ids = set()
-    for i, doc in enumerate(documents_indexed):
-        vector_ids.add(doc.meta["vector_id"])
-    assert len(vector_ids) == len(DOCUMENTS)
-
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
@@ -78,25 +102,29 @@ def test_faiss_update_docs(document_store, index_buffer_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(updated_embedding, stored_emb[:-1], rtol=0.01)
 
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
+
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_retrieving(document_store):
     document_store.write_documents(DOCUMENTS)
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     result = retriever.retrieve(query="How to test this?")
-    assert len(result) == 3
+    assert len(result) == len(DOCUMENTS)
     assert type(result[0]) == Document
 
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_finding(document_store):
     document_store.write_documents(DOCUMENTS)
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     finder = Finder(reader=None, retriever=retriever)
     prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
     assert len(prediction.get('answers', [])) == 1
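
Condensed from test_faiss_index_save_and_load above, the save/load round trip now composes cleanly with reset() (document_store is assumed to be the same fixture-provided FAISSDocumentStore the tests use):

document_store.write_documents(DOCUMENTS)
document_store.save("haystack_test_faiss")  # serialize the FAISS index to disk
document_store.faiss_index.reset()          # ntotal drops to 0, memory is freed
new_store = document_store.load(
    sql_url="sqlite:///haystack_test.db",
    faiss_file_path="haystack_test_faiss",
)
assert new_store.faiss_index.ntotal == len(DOCUMENTS)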