diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index 213e8f7e3..bc4d4f8e2 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -11,10 +11,9 @@ import numpy as np from tqdm.auto import tqdm try: + # These deps are optional, but get installed with the `faiss` extra import faiss - from haystack.document_stores.sql import ( - SQLDocumentStore, - ) # its deps are optional, but get installed with the `faiss` extra + from haystack.document_stores.sql import SQLDocumentStore # type: ignore except (ImportError, ModuleNotFoundError) as ie: from haystack.utils.import_utils import _optional_component_not_installed @@ -276,6 +275,14 @@ class FAISSDocumentStore(SQLDocumentStore): ) as progress_bar: for i in range(0, len(document_objects), batch_size): if add_vectors: + if not self.faiss_indexes[index].is_trained: + raise ValueError( + "FAISS index of type {} must be trained before adding vectors. Call `train_index()` " + "method before adding the vectors. For details, refer to the documentation: " + "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)." + "".format(self.faiss_index_factory_str) + ) + embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]] embeddings_to_index = np.array(embeddings, dtype="float32") @@ -339,6 +346,14 @@ class FAISSDocumentStore(SQLDocumentStore): if not self.faiss_indexes.get(index): raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...") + if not self.faiss_indexes[index].is_trained: + raise ValueError( + "FAISS index of type {} must be trained before adding vectors. Call `train_index()` " + "method before adding the vectors. For details, refer to the documentation: " + "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)." + "".format(self.faiss_index_factory_str) + ) + document_count = self.get_document_count(index=index) if document_count == 0: logger.warning("Calling DocumentStore.update_embeddings() on an empty index") diff --git a/test/document_stores/test_faiss.py b/test/document_stores/test_faiss.py index 512c1f995..a72d3abec 100644 --- a/test/document_stores/test_faiss.py +++ b/test/document_stores/test_faiss.py @@ -129,6 +129,17 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract): # Check that get_embedding_count works as expected assert document_store.get_embedding_count() == len(documents_with_embeddings) + @pytest.mark.integration + def test_write_docs_no_training(self, documents_with_embeddings, tmp_path, caplog): + document_store = FAISSDocumentStore( + sql_url=f"sqlite:///{tmp_path}/test_write_docs_no_training.db", + faiss_index_factory_str="IVF1,Flat", + isolation_level="AUTOCOMMIT", + return_embedding=True, + ) + with pytest.raises(ValueError, match="must be trained before adding vectors"): + document_store.write_documents(documents_with_embeddings) + @pytest.mark.integration def test_train_index_from_docs(self, documents_with_embeddings, tmp_path): document_store = FAISSDocumentStore(