mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 01:39:45 +00:00 
			
		
		
		
	fix: when using IVF* indexing, ensure the index is trained frist (#4311)
* add protection, in case we use IVF* indexing, we need to train the index first Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com> * fix formatting issue Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com> * just raising error, instead of silently training the index * fixed mypy issue * fixed error msg --------- Signed-off-by: Liu,Kaixuan <kaixuan.liu@intel.com> Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
This commit is contained in:
		
							parent
							
								
									677fc8badf
								
							
						
					
					
						commit
						edf39edda0
					
				| @ -11,10 +11,9 @@ import numpy as np | ||||
| from tqdm.auto import tqdm | ||||
| 
 | ||||
| try: | ||||
|     # These deps are optional, but get installed with the `faiss` extra | ||||
|     import faiss | ||||
|     from haystack.document_stores.sql import ( | ||||
|         SQLDocumentStore, | ||||
|     )  # its deps are optional, but get installed with the `faiss` extra | ||||
|     from haystack.document_stores.sql import SQLDocumentStore  # type: ignore | ||||
| except (ImportError, ModuleNotFoundError) as ie: | ||||
|     from haystack.utils.import_utils import _optional_component_not_installed | ||||
| 
 | ||||
| @ -276,6 +275,14 @@ class FAISSDocumentStore(SQLDocumentStore): | ||||
|             ) as progress_bar: | ||||
|                 for i in range(0, len(document_objects), batch_size): | ||||
|                     if add_vectors: | ||||
|                         if not self.faiss_indexes[index].is_trained: | ||||
|                             raise ValueError( | ||||
|                                 "FAISS index of type {} must be trained before adding vectors. Call `train_index()` " | ||||
|                                 "method before adding the vectors. For details, refer to the documentation: " | ||||
|                                 "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)." | ||||
|                                 "".format(self.faiss_index_factory_str) | ||||
|                             ) | ||||
| 
 | ||||
|                         embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]] | ||||
|                         embeddings_to_index = np.array(embeddings, dtype="float32") | ||||
| 
 | ||||
| @ -339,6 +346,14 @@ class FAISSDocumentStore(SQLDocumentStore): | ||||
|         if not self.faiss_indexes.get(index): | ||||
|             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...") | ||||
| 
 | ||||
|         if not self.faiss_indexes[index].is_trained: | ||||
|             raise ValueError( | ||||
|                 "FAISS index of type {} must be trained before adding vectors. Call `train_index()` " | ||||
|                 "method before adding the vectors. For details, refer to the documentation: " | ||||
|                 "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)." | ||||
|                 "".format(self.faiss_index_factory_str) | ||||
|             ) | ||||
| 
 | ||||
|         document_count = self.get_document_count(index=index) | ||||
|         if document_count == 0: | ||||
|             logger.warning("Calling DocumentStore.update_embeddings() on an empty index") | ||||
|  | ||||
| @ -129,6 +129,17 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract): | ||||
|         # Check that get_embedding_count works as expected | ||||
|         assert document_store.get_embedding_count() == len(documents_with_embeddings) | ||||
| 
 | ||||
|     @pytest.mark.integration | ||||
|     def test_write_docs_no_training(self, documents_with_embeddings, tmp_path, caplog): | ||||
|         document_store = FAISSDocumentStore( | ||||
|             sql_url=f"sqlite:///{tmp_path}/test_write_docs_no_training.db", | ||||
|             faiss_index_factory_str="IVF1,Flat", | ||||
|             isolation_level="AUTOCOMMIT", | ||||
|             return_embedding=True, | ||||
|         ) | ||||
|         with pytest.raises(ValueError, match="must be trained before adding vectors"): | ||||
|             document_store.write_documents(documents_with_embeddings) | ||||
| 
 | ||||
|     @pytest.mark.integration | ||||
|     def test_train_index_from_docs(self, documents_with_embeddings, tmp_path): | ||||
|         document_store = FAISSDocumentStore( | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 kaixuanliu
						kaixuanliu