mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-30 09:19:39 +00:00
fix: when using IVF* indexing, ensure the index is trained frist (#4311)
* add protection, in case we use IVF* indexing, we need to train the index first Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com> * fix formatting issue Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com> * just raising error, instead of silently training the index * fixed mypy issue * fixed error msg --------- Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com> Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
This commit is contained in:
parent
677fc8badf
commit
edf39edda0
@ -11,10 +11,9 @@ import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
try:
|
||||
# These deps are optional, but get installed with the `faiss` extra
|
||||
import faiss
|
||||
from haystack.document_stores.sql import (
|
||||
SQLDocumentStore,
|
||||
) # its deps are optional, but get installed with the `faiss` extra
|
||||
from haystack.document_stores.sql import SQLDocumentStore # type: ignore
|
||||
except (ImportError, ModuleNotFoundError) as ie:
|
||||
from haystack.utils.import_utils import _optional_component_not_installed
|
||||
|
||||
@ -276,6 +275,14 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
) as progress_bar:
|
||||
for i in range(0, len(document_objects), batch_size):
|
||||
if add_vectors:
|
||||
if not self.faiss_indexes[index].is_trained:
|
||||
raise ValueError(
|
||||
"FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
|
||||
"method before adding the vectors. For details, refer to the documentation: "
|
||||
"[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
|
||||
"".format(self.faiss_index_factory_str)
|
||||
)
|
||||
|
||||
embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
|
||||
embeddings_to_index = np.array(embeddings, dtype="float32")
|
||||
|
||||
@ -339,6 +346,14 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
if not self.faiss_indexes.get(index):
|
||||
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
||||
|
||||
if not self.faiss_indexes[index].is_trained:
|
||||
raise ValueError(
|
||||
"FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
|
||||
"method before adding the vectors. For details, refer to the documentation: "
|
||||
"[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
|
||||
"".format(self.faiss_index_factory_str)
|
||||
)
|
||||
|
||||
document_count = self.get_document_count(index=index)
|
||||
if document_count == 0:
|
||||
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
|
||||
|
||||
@ -129,6 +129,17 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
# Check that get_embedding_count works as expected
|
||||
assert document_store.get_embedding_count() == len(documents_with_embeddings)
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_write_docs_no_training(self, documents_with_embeddings, tmp_path, caplog):
|
||||
document_store = FAISSDocumentStore(
|
||||
sql_url=f"sqlite:///{tmp_path}/test_write_docs_no_training.db",
|
||||
faiss_index_factory_str="IVF1,Flat",
|
||||
isolation_level="AUTOCOMMIT",
|
||||
return_embedding=True,
|
||||
)
|
||||
with pytest.raises(ValueError, match="must be trained before adding vectors"):
|
||||
document_store.write_documents(documents_with_embeddings)
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_train_index_from_docs(self, documents_with_embeddings, tmp_path):
|
||||
document_store = FAISSDocumentStore(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user