Fix bug when loading FAISS from supplied config file path (#1506)

* Fix the bug found in issue 135

* Add a test for the custom path
This commit is contained in:
Sara Zan 2021-09-27 11:25:05 +02:00 committed by GitHub
parent 183fd5ae5a
commit 1cd17022af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 3 deletions

View File

@ -498,14 +498,14 @@ class FAISSDocumentStore(SQLDocumentStore):
"""
if not config_path:
index_path = Path(index_path)
faiss_init_params_path = index_path.with_suffix(".json")
config_path = index_path.with_suffix(".json")
init_params: dict = {}
try:
with open(faiss_init_params_path, 'r') as ipp:
with open(config_path, 'r') as ipp:
init_params = json.load(ipp)
except OSError as e:
raise ValueError(f"Can't open FAISS configuration file `{faiss_init_params_path}`. "
raise ValueError(f"Can't open FAISS configuration file `{config_path}`. "
"Make sure the file exists and the you have the correct permissions "
"to access it.") from e

View File

@ -46,6 +46,34 @@ def test_faiss_index_save_and_load(tmp_path):
assert not new_document_store.progress_bar
def test_faiss_index_save_and_load_custom_path(tmp_path):
    """Save a FAISS store with an explicit config path and load it back.

    The index file and the JSON config file are given unrelated names, so
    loading only succeeds if ``load`` actually uses the supplied
    ``config_path`` instead of deriving one from ``index_path``.
    """
    index_file = tmp_path / "haystack_test_faiss"
    config_file = tmp_path / "custom_path.json"

    store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
        index="haystack_test",
        # Set explicitly so we can verify init parameters survive the round trip.
        progress_bar=False,
    )
    store.write_documents(DOCUMENTS)

    # Persist the index together with its config at the custom location.
    store.save(index_path=index_file, config_path=config_file)

    # Wipe the in-memory FAISS index and confirm that it is really empty.
    in_memory_index = store.faiss_indexes[store.index]
    in_memory_index.reset()
    assert in_memory_index.ntotal == 0

    # Rebuild a store from the saved index and the custom config file.
    restored_store = FAISSDocumentStore.load(index_path=index_file, config_path=config_file)

    expected_count = len(DOCUMENTS)
    # The FAISS vectors are back ...
    assert restored_store.faiss_indexes[store.index].ntotal == expected_count
    # ... and so are the documents themselves ...
    assert len(restored_store.get_all_documents()) == expected_count
    # ... and the init parameters stored in the config were applied.
    assert not restored_store.progress_bar
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
@pytest.mark.parametrize("batch_size", [2])