fix: when using IVF* indexing, ensure the index is trained frist (#4311)

* add protection, in case we use IVF* indexing, we need to train the index first

Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>

* fix formatting issue

Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>

* just raising error, instead of silently training the index

* fixed mypy issue

* fixed error msg

---------

Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
This commit is contained in:
kaixuanliu 2023-03-15 15:55:37 +08:00 committed by GitHub
parent 677fc8badf
commit edf39edda0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 3 deletions

View File

@ -11,10 +11,9 @@ import numpy as np
from tqdm.auto import tqdm
try:
# These deps are optional, but get installed with the `faiss` extra
import faiss
from haystack.document_stores.sql import (
SQLDocumentStore,
) # its deps are optional, but get installed with the `faiss` extra
from haystack.document_stores.sql import SQLDocumentStore # type: ignore
except (ImportError, ModuleNotFoundError) as ie:
from haystack.utils.import_utils import _optional_component_not_installed
@ -276,6 +275,14 @@ class FAISSDocumentStore(SQLDocumentStore):
) as progress_bar:
for i in range(0, len(document_objects), batch_size):
if add_vectors:
if not self.faiss_indexes[index].is_trained:
raise ValueError(
"FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
"method before adding the vectors. For details, refer to the documentation: "
"[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
"".format(self.faiss_index_factory_str)
)
embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
embeddings_to_index = np.array(embeddings, dtype="float32")
@ -339,6 +346,14 @@ class FAISSDocumentStore(SQLDocumentStore):
if not self.faiss_indexes.get(index):
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
if not self.faiss_indexes[index].is_trained:
raise ValueError(
"FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
"method before adding the vectors. For details, refer to the documentation: "
"[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
"".format(self.faiss_index_factory_str)
)
document_count = self.get_document_count(index=index)
if document_count == 0:
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")

View File

@ -129,6 +129,17 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
# Check that get_embedding_count works as expected
assert document_store.get_embedding_count() == len(documents_with_embeddings)
@pytest.mark.integration
def test_write_docs_no_training(self, documents_with_embeddings, tmp_path, caplog):
document_store = FAISSDocumentStore(
sql_url=f"sqlite:///{tmp_path}/test_write_docs_no_training.db",
faiss_index_factory_str="IVF1,Flat",
isolation_level="AUTOCOMMIT",
return_embedding=True,
)
with pytest.raises(ValueError, match="must be trained before adding vectors"):
document_store.write_documents(documents_with_embeddings)
@pytest.mark.integration
def test_train_index_from_docs(self, documents_with_embeddings, tmp_path):
document_store = FAISSDocumentStore(