Fix loading a saved FAISSDocumentStore (#1937)

* Remove faiss_index param from config

* Add Tests

* Add assertions to tests
This commit is contained in:
bogdankostic 2022-01-04 12:22:31 +01:00 committed by GitHub
parent 3e0ef1cc8a
commit 381fc302cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 1 deletions

View File

@ -106,7 +106,6 @@ class FAISSDocumentStore(SQLDocumentStore):
sql_url=sql_url,
vector_dim=vector_dim,
faiss_index_factory_str=faiss_index_factory_str,
faiss_index=faiss_index,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index=index,

View File

@ -51,6 +51,17 @@ def test_faiss_index_save_and_load(tmp_path):
# Check if the init parameters are kept
assert not new_document_store.progress_bar
# test saving and loading the loaded faiss index
new_document_store.save(tmp_path / "haystack_test_faiss")
reloaded_document_store = FAISSDocumentStore.load(tmp_path / "haystack_test_faiss")
# check faiss index is restored
assert reloaded_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
# check if documents are restored
assert len(reloaded_document_store.get_all_documents()) == len(DOCUMENTS)
# Check if the init parameters are kept
assert not reloaded_document_store.progress_bar
# test loading the index via init
new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss")
@ -90,6 +101,17 @@ def test_faiss_index_save_and_load_custom_path(tmp_path):
# Check if the init parameters are kept
assert not new_document_store.progress_bar
# test saving and loading the loaded faiss index
new_document_store.save(tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json")
reloaded_document_store = FAISSDocumentStore.load(tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json")
# check faiss index is restored
assert reloaded_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
# check if documents are restored
assert len(reloaded_document_store.get_all_documents()) == len(DOCUMENTS)
# Check if the init parameters are kept
assert not reloaded_document_store.progress_bar
# test loading the index via init
new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss", faiss_config_path=tmp_path / "custom_path.json")