Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-22 15:38:01 +00:00)
Allow multiple write calls to existing FAISS index. (#422)
- Fixes an issue where update_embeddings always created a new FAISS index instead of clearing the existing one. Creating a new index may not free the memory used by the old one and can cause a memory leak.

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent 072e32b38a
commit 465ccbc12e
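For context, a minimal sketch (not part of this commit; the dimension and HNSW parameters are illustrative, not the store's actual defaults) of the FAISS behaviour the fix relies on: Index.reset() empties an index in place, whereas rebinding an attribute to a freshly constructed index leaves the old object and its memory around until it is garbage collected.

import numpy as np
import faiss  # assumes faiss-cpu is installed

dim, n_links = 769, 128                        # illustrative values only
index = faiss.IndexHNSWFlat(dim, n_links)

index.add(np.random.rand(10, dim).astype(np.float32))
assert index.ntotal == 10

# reset() clears the stored vectors but keeps the same index object alive,
# so existing references remain valid and no second index is allocated.
index.reset()
assert index.ntotal == 0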
@@ -50,10 +50,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         return index

     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
-        if self.faiss_index is not None:
-            raise Exception("Addition of more data in an existing index is not supported.")

-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
         index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]

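With the guard removed, write_documents can now be called repeatedly and appends to a lazily created index. A rough usage sketch, not taken from the commit (the constructor arguments are assumptions; DOCUMENTS is the list used in the tests further down):

document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db")

document_store.write_documents(DOCUMENTS[:3])   # first call creates the FAISS index lazily
document_store.write_documents(DOCUMENTS[3:])   # previously raised an Exception, now appends
assert document_store.faiss_index.ntotal == len(DOCUMENTS)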
@@ -63,12 +61,12 @@ class FAISSDocumentStore(SQLDocumentStore):
         phi = self._get_phi(document_objects)

         for i in range(0, len(document_objects), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             if add_vectors:
                 embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
                 hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
                 hnsw_vectors = hnsw_vectors.astype(np.float32)
-                faiss_index.add(hnsw_vectors)
+                self.faiss_index.add(hnsw_vectors)

             docs_to_write_in_sql = []
             for doc in document_objects[i : i + self.index_buffer_size]:
@@ -80,8 +78,6 @@ class FAISSDocumentStore(SQLDocumentStore):

             super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)

-        self.faiss_index = faiss_index
-
     def _get_hnsw_vectors(self, embeddings: List[np.array], phi: int) -> np.array:
         """
         HNSW indices in FAISS only support L2 distance. This transformation adds an additional dimension to obtain
@@ -113,8 +109,10 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param index: Index name to update
         :return: None
         """
-        # Some FAISS indexes(like the default HNSWx) do not support removing vectors, so a new index is created.
-        faiss_index = self._create_new_index(vector_size=self.vector_size)
+        self.faiss_index = self.faiss_index or self._create_new_index(vector_size=self.vector_size)
+        # Clear out the FAISS index contents and immediately free all memory currently used by the index
+        self.faiss_index.reset()

         index = index or self.index

         documents = self.get_all_documents(index=index)
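The effect in update_embeddings, sketched outside the commit (the retriever arguments are copied from the tests below; the identity assertion assumes documents were already written so the index exists): repeated calls clear and refill the same FAISS index instead of allocating a new one each time.

retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
                               use_gpu=False)

index_before = document_store.faiss_index       # assumes write_documents was called earlier
document_store.update_embeddings(retriever=retriever)
document_store.update_embeddings(retriever=retriever)

# The index object is reused: it is emptied by reset() and refilled on each call.
assert document_store.faiss_index is index_before
assert document_store.faiss_index.ntotal == len(DOCUMENTS)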
@@ -128,18 +126,17 @@ class FAISSDocumentStore(SQLDocumentStore):

         vector_id_map = {}
         for i in range(0, len(documents), self.index_buffer_size):
-            vector_id = faiss_index.ntotal
+            vector_id = self.faiss_index.ntotal
             embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
             hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
             hnsw_vectors = hnsw_vectors.astype(np.float32)
-            faiss_index.add(hnsw_vectors)
+            self.faiss_index.add(hnsw_vectors)

             for doc in documents[i: i + self.index_buffer_size]:
                 vector_id_map[doc.id] = vector_id
                 vector_id += 1

         self.update_vector_ids(vector_id_map, index=index)
-        self.faiss_index = faiss_index

     def query_by_embedding(
         self, query_emb: np.array, filters: Optional[dict] = None, top_k: int = 10, index: Optional[str] = None
@@ -10,19 +10,59 @@ DOCUMENTS = [
     {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
     {"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_5", "text": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
     {"name": "name_6", "text": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
 ]

-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
-def test_faiss_write_docs(document_store, index_buffer_size):
-    document_store.index_buffer_size = index_buffer_size
-
-    document_store.write_documents(DOCUMENTS)
-    documents_indexed = document_store.get_all_documents()
-
+def check_data_correctness(documents_indexed, documents_inserted):
+    # test if correct vector_ids are assigned
+    for i, doc in enumerate(documents_indexed):
+        assert doc.meta["vector_id"] == str(i)
+
+    # test if number of documents is correct
+    assert len(documents_indexed) == len(documents_inserted)
+
+    # test that no two docs have the same vector_id assigned
+    vector_ids = set()
+    for i, doc in enumerate(documents_indexed):
+        vector_ids.add(doc.meta["vector_id"])
+    assert len(vector_ids) == len(documents_inserted)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_index_save_and_load(document_store):
+    document_store.write_documents(DOCUMENTS)
+
+    # test saving the index
+    document_store.save("haystack_test_faiss")
+
+    # clear existing faiss_index
+    document_store.faiss_index.reset()
+
+    # test that the faiss index is cleared
+    assert document_store.faiss_index.ntotal == 0
+
+    # test loading the index
+    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+
+    # check that the faiss index is restored
+    assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
+@pytest.mark.parametrize("batch_size", [2])
+def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
+    document_store.index_buffer_size = index_buffer_size
+
+    # Write in small batches
+    for i in range(0, len(DOCUMENTS), batch_size):
+        document_store.write_documents(DOCUMENTS[i: i + batch_size])
+
+    documents_indexed = document_store.get_all_documents()
+
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         # we currently don't get the embeddings back when we call document_store.get_all_documents()
@@ -31,20 +71,13 @@ def test_faiss_write_docs(document_store, index_buffer_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(original_doc["embedding"], stored_emb[:-1], rtol=0.01)

-    # test insertion of documents in an existing index fails
-    with pytest.raises(Exception):
-        document_store.write_documents(DOCUMENTS)
-
-    # test saving the index
-    document_store.save("haystack_test_faiss")
-
-    # test loading the index
-    document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)


 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
 def test_faiss_update_docs(document_store, index_buffer_size):

     # adjust buffer size
     document_store.index_buffer_size = index_buffer_size

@@ -61,15 +94,6 @@ def test_faiss_update_docs(document_store, index_buffer_size):
     document_store.update_embeddings(retriever=retriever)
     documents_indexed = document_store.get_all_documents()

-    # test if number of documents is correct
-    assert len(documents_indexed) == len(DOCUMENTS)
-
-    # test if two docs have same vector_is assigned
-    vector_ids = set()
-    for i, doc in enumerate(documents_indexed):
-        vector_ids.add(doc.meta["vector_id"])
-    assert len(vector_ids) == len(DOCUMENTS)
-
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
@@ -78,25 +102,29 @@ def test_faiss_update_docs(document_store, index_buffer_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(updated_embedding, stored_emb[:-1], rtol=0.01)

+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
+

 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_retrieving(document_store):

     document_store.write_documents(DOCUMENTS)

-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     result = retriever.retrieve(query="How to test this?")
-    assert len(result) == 3
+    assert len(result) == len(DOCUMENTS)
     assert type(result[0]) == Document


 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_finding(document_store):

     document_store.write_documents(DOCUMENTS)

-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
+                                   use_gpu=False)
     finder = Finder(reader=None, retriever=retriever)

     prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)

     assert len(prediction.get('answers', [])) == 1