mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-30 17:29:29 +00:00
fix: when using IVF* indexing, ensure the index is trained first (#4311)
* add protection, in case we use IVF* indexing, we need to train the index first
  Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>
* fix formatting issue
  Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>
* just raising error, instead of silently training the index
* fixed mypy issue
* fixed error msg
---------
Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
This commit is contained in:
parent
677fc8badf
commit
edf39edda0
@ -11,10 +11,9 @@ import numpy as np
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# These deps are optional, but get installed with the `faiss` extra
|
||||||
import faiss
|
import faiss
|
||||||
from haystack.document_stores.sql import (
|
from haystack.document_stores.sql import SQLDocumentStore # type: ignore
|
||||||
SQLDocumentStore,
|
|
||||||
) # its deps are optional, but get installed with the `faiss` extra
|
|
||||||
except (ImportError, ModuleNotFoundError) as ie:
|
except (ImportError, ModuleNotFoundError) as ie:
|
||||||
from haystack.utils.import_utils import _optional_component_not_installed
|
from haystack.utils.import_utils import _optional_component_not_installed
|
||||||
|
|
||||||
@ -276,6 +275,14 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
) as progress_bar:
|
) as progress_bar:
|
||||||
for i in range(0, len(document_objects), batch_size):
|
for i in range(0, len(document_objects), batch_size):
|
||||||
if add_vectors:
|
if add_vectors:
|
||||||
|
if not self.faiss_indexes[index].is_trained:
|
||||||
|
raise ValueError(
|
||||||
|
"FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
|
||||||
|
"method before adding the vectors. For details, refer to the documentation: "
|
||||||
|
"[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
|
||||||
|
"".format(self.faiss_index_factory_str)
|
||||||
|
)
|
||||||
|
|
||||||
embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
|
embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
|
||||||
embeddings_to_index = np.array(embeddings, dtype="float32")
|
embeddings_to_index = np.array(embeddings, dtype="float32")
|
||||||
|
|
||||||
@ -339,6 +346,14 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
if not self.faiss_indexes.get(index):
|
if not self.faiss_indexes.get(index):
|
||||||
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
||||||
|
|
||||||
|
if not self.faiss_indexes[index].is_trained:
|
||||||
|
raise ValueError(
|
||||||
|
"FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
|
||||||
|
"method before adding the vectors. For details, refer to the documentation: "
|
||||||
|
"[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
|
||||||
|
"".format(self.faiss_index_factory_str)
|
||||||
|
)
|
||||||
|
|
||||||
document_count = self.get_document_count(index=index)
|
document_count = self.get_document_count(index=index)
|
||||||
if document_count == 0:
|
if document_count == 0:
|
||||||
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
|
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
|
||||||
|
|||||||
@ -129,6 +129,17 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
|
|||||||
# Check that get_embedding_count works as expected
|
# Check that get_embedding_count works as expected
|
||||||
assert document_store.get_embedding_count() == len(documents_with_embeddings)
|
assert document_store.get_embedding_count() == len(documents_with_embeddings)
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_write_docs_no_training(self, documents_with_embeddings, tmp_path, caplog):
|
||||||
|
document_store = FAISSDocumentStore(
|
||||||
|
sql_url=f"sqlite:///{tmp_path}/test_write_docs_no_training.db",
|
||||||
|
faiss_index_factory_str="IVF1,Flat",
|
||||||
|
isolation_level="AUTOCOMMIT",
|
||||||
|
return_embedding=True,
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="must be trained before adding vectors"):
|
||||||
|
document_store.write_documents(documents_with_embeddings)
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_train_index_from_docs(self, documents_with_embeddings, tmp_path):
|
def test_train_index_from_docs(self, documents_with_embeddings, tmp_path):
|
||||||
document_store = FAISSDocumentStore(
|
document_store = FAISSDocumentStore(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user