import os
import sys
import math
from pathlib import Path

import yaml
import faiss
import pytest
import numpy as np

from haystack.schema import Document
from haystack.pipelines import DocumentSearchPipeline, Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.retriever.dense import EmbeddingRetriever

from .conftest import document_classifier, ensure_ids_are_correct_uuids, SAMPLES_PATH, MockDenseRetriever
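
# Note: `document_store`, `retriever`, and `sql_url` used as test arguments below
# are pytest fixtures (supplied via .conftest and the `indirect=True`
# parametrizations), not module-level names.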
					
						
# Two of the six embeddings are float64 rather than float32, presumably to
# exercise dtype handling in the document stores.
DOCUMENTS = [
    {
        "meta": {"name": "name_1", "year": "2020", "month": "01"},
        "content": "text_1",
        "embedding": np.random.rand(768).astype(np.float32),
    },
    {
        "meta": {"name": "name_2", "year": "2020", "month": "02"},
        "content": "text_2",
        "embedding": np.random.rand(768).astype(np.float32),
    },
    {
        "meta": {"name": "name_3", "year": "2020", "month": "03"},
        "content": "text_3",
        "embedding": np.random.rand(768).astype(np.float64),
    },
    {
        "meta": {"name": "name_4", "year": "2021", "month": "01"},
        "content": "text_4",
        "embedding": np.random.rand(768).astype(np.float32),
    },
    {
        "meta": {"name": "name_5", "year": "2021", "month": "02"},
        "content": "text_5",
        "embedding": np.random.rand(768).astype(np.float32),
    },
    {
        "meta": {"name": "name_6", "year": "2021", "month": "03"},
        "content": "text_6",
        "embedding": np.random.rand(768).astype(np.float64),
    },
]


@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_save_and_load(tmp_path, sql_url):
    document_store = FAISSDocumentStore(
        sql_url=sql_url,
        index="haystack_test",
        progress_bar=False,  # Just to check if the init parameters are kept
        isolation_level="AUTOCOMMIT",
    )
    document_store.write_documents(DOCUMENTS)

    # test saving the index
    document_store.save(tmp_path / "haystack_test_faiss")

    # clear existing faiss_index
    document_store.faiss_indexes[document_store.index].reset()

    # test faiss index is cleared
    assert document_store.faiss_indexes[document_store.index].ntotal == 0

    # test loading the index
    new_document_store = FAISSDocumentStore.load(tmp_path / "haystack_test_faiss")

    # check faiss index is restored
    assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
    # check if the init parameters are kept
    assert not new_document_store.progress_bar

    # test saving and loading the loaded faiss index
    new_document_store.save(tmp_path / "haystack_test_faiss")
    reloaded_document_store = FAISSDocumentStore.load(tmp_path / "haystack_test_faiss")

    # check faiss index is restored
    assert reloaded_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(reloaded_document_store.get_all_documents()) == len(DOCUMENTS)
    # check if the init parameters are kept
    assert not reloaded_document_store.progress_bar

    # test loading the index via init
    new_document_store = FAISSDocumentStore(faiss_index_path=tmp_path / "haystack_test_faiss")

    # check faiss index is restored
    assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
    # check if the init parameters are kept
    assert not new_document_store.progress_bar


@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url):
    document_store = FAISSDocumentStore(
        sql_url=sql_url,
        index="haystack_test",
        progress_bar=False,  # Just to check if the init parameters are kept
        isolation_level="AUTOCOMMIT",
    )
    document_store.write_documents(DOCUMENTS)

    # test saving the index
    document_store.save(index_path=tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json")

    # clear existing faiss_index
    document_store.faiss_indexes[document_store.index].reset()

    # test faiss index is cleared
    assert document_store.faiss_indexes[document_store.index].ntotal == 0

    # test loading the index
    new_document_store = FAISSDocumentStore.load(
        index_path=tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json"
    )

    # check faiss index is restored
    assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
    # check if the init parameters are kept
    assert not new_document_store.progress_bar

    # test saving and loading the loaded faiss index
    new_document_store.save(tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json")
    reloaded_document_store = FAISSDocumentStore.load(
        tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json"
    )

    # check faiss index is restored
    assert reloaded_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(reloaded_document_store.get_all_documents()) == len(DOCUMENTS)
    # check if the init parameters are kept
    assert not reloaded_document_store.progress_bar

    # test loading the index via init
    new_document_store = FAISSDocumentStore(
        faiss_index_path=tmp_path / "haystack_test_faiss", faiss_config_path=tmp_path / "custom_path.json"
    )

    # check faiss index is restored
    assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
    # check if the init parameters are kept
    assert not new_document_store.progress_bar


@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_mutual_exclusive_args(tmp_path):
    with pytest.raises(ValueError):
        FAISSDocumentStore(
            sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
            faiss_index_path=f"{tmp_path/'haystack_test'}",
            isolation_level="AUTOCOMMIT",
        )

    with pytest.raises(ValueError):
        FAISSDocumentStore(
            f"sqlite:////{tmp_path/'haystack_test.db'}",
            faiss_index_path=f"{tmp_path/'haystack_test'}",
            isolation_level="AUTOCOMMIT",
        )


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
@pytest.mark.parametrize("batch_size", [2])
def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
    document_store.index_buffer_size = index_buffer_size

    # Write in small batches
    for i in range(0, len(DOCUMENTS), batch_size):
        document_store.write_documents(DOCUMENTS[i : i + batch_size])

    documents_indexed = document_store.get_all_documents()
    assert len(documents_indexed) == len(DOCUMENTS)

    # test if correct vectors are associated with docs
    for doc in documents_indexed:
        # we currently don't get the embeddings back when we call document_store.get_all_documents()
        original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0]
        stored_emb = document_store.faiss_indexes[document_store.index].reconstruct(int(doc.meta["vector_id"]))
        # compare original input vec with stored one (ignore extra dim added by hnsw)
        # original input vec is normalized as faiss only stores normalized vectors
        assert np.allclose(original_doc["embedding"] / np.linalg.norm(original_doc["embedding"]), stored_emb, rtol=0.01)
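
# For reference: FAISS's `reconstruct(i)` (used above) returns the raw stored
# vector for internal id `i`. A minimal faiss-only sketch of that behavior,
# independent of Haystack (illustrative, not part of the test suite):
#
#     flat = faiss.IndexFlat(4)
#     flat.add(np.ones((1, 4), dtype=np.float32))
#     assert np.allclose(flat.reconstruct(0), np.ones(4, dtype=np.float32))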
					
						
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
@pytest.mark.parametrize("batch_size", [4, 6])
def test_update_docs(document_store, retriever, batch_size):
    # initial write
    document_store.write_documents(DOCUMENTS)

    document_store.update_embeddings(retriever=retriever, batch_size=batch_size)
    documents_indexed = document_store.get_all_documents()
    assert len(documents_indexed) == len(DOCUMENTS)

    # test if correct vectors are associated with docs
    for doc in documents_indexed:
        original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0]
        updated_embedding = retriever.embed_documents([Document.from_dict(original_doc)])
        stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0]
        # compare original input vec with stored one (ignore extra dim added by hnsw)
        # original input vec is normalized as faiss only stores normalized vectors
        a = updated_embedding / np.linalg.norm(updated_embedding)
        assert np.allclose(a[0], stored_doc.embedding, rtol=0.2)  # high tolerance necessary for Milvus 2


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["milvus1", "milvus", "faiss"], indirect=True)
def test_update_existing_docs(document_store, retriever):
    document_store.duplicate_documents = "overwrite"
    old_document = Document(content="text_1")
    # initial write
    document_store.write_documents([old_document])
    document_store.update_embeddings(retriever=retriever)
    old_documents_indexed = document_store.get_all_documents(return_embedding=True)
    assert len(old_documents_indexed) == 1

    # Update document data
    new_document = Document(content="text_2")
    new_document.id = old_document.id
    document_store.write_documents([new_document])
    document_store.update_embeddings(retriever=retriever)
    new_documents_indexed = document_store.get_all_documents(return_embedding=True)
    assert len(new_documents_indexed) == 1

    assert old_documents_indexed[0].id == new_documents_indexed[0].id
    assert old_documents_indexed[0].content == "text_1"
    assert new_documents_indexed[0].content == "text_2"
    assert not np.allclose(old_documents_indexed[0].embedding, new_documents_indexed[0].embedding, rtol=0.01)


@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_update_with_empty_store(document_store, retriever):
    # Call update with empty doc store
    document_store.update_embeddings(retriever=retriever)

    # initial write
    document_store.write_documents(DOCUMENTS)

    documents_indexed = document_store.get_all_documents()

    assert len(documents_indexed) == len(DOCUMENTS)


@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner")
@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
def test_faiss_retrieving(index_factory, tmp_path):
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}",
        faiss_index_factory_str=index_factory,
        isolation_level="AUTOCOMMIT",
    )

    document_store.delete_all_documents(index="document")
    if "ivf" in index_factory.lower():
        document_store.train_index(DOCUMENTS)
    document_store.write_documents(DOCUMENTS)

    retriever = EmbeddingRetriever(
        document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
    )
    result = retriever.retrieve(query="How to test this?")

    assert len(result) == len(DOCUMENTS)
    assert type(result[0]) == Document

    # Cleanup
    document_store.faiss_indexes[document_store.index].reset()


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_finding(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    pipe = DocumentSearchPipeline(retriever=retriever)

    prediction = pipe.run(query="How to test this?", params={"Retriever": {"top_k": 1}})

    assert len(prediction.get("documents", [])) == 1


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_delete_docs_with_filters(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    document_store.delete_documents(filters={"name": ["name_1", "name_2", "name_3", "name_4"]})

    documents = document_store.get_all_documents()
    assert len(documents) == 2
    assert document_store.get_embedding_count() == 2
    assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_delete_docs_with_filters_by_year(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    document_store.delete_documents(filters={"year": ["2020"]})

    documents = document_store.get_all_documents()
    assert len(documents) == 3
    assert document_store.get_embedding_count() == 3
    assert all("2021" == doc.meta["year"] for doc in documents)


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_delete_docs_with_many_filters(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    document_store.delete_documents(filters={"month": ["01"], "year": ["2020"]})

    documents = document_store.get_all_documents()
    assert len(documents) == 5
    assert document_store.get_embedding_count() == 5
    assert "name_1" not in {doc.meta["name"] for doc in documents}
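
# Filter semantics as exercised by the delete/get tests in this module: values
# listed under a single key are OR-ed ({"name": ["name_1", "name_2"]} matches
# either name), while separate keys are AND-ed ({"month": ["01"], "year": ["2020"]}
# matches only documents satisfying both), which is why exactly one of the six
# DOCUMENTS is removed above.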
					
						
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_delete_docs_by_id(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6
    doc_ids = [doc.id for doc in document_store.get_all_documents()]
    ids_to_delete = doc_ids[0:3]

    document_store.delete_documents(ids=ids_to_delete)

    documents = document_store.get_all_documents()
    assert len(documents) == len(doc_ids) - len(ids_to_delete)
    assert document_store.get_embedding_count() == len(doc_ids) - len(ids_to_delete)

    remaining_ids = [doc.id for doc in documents]
    assert all(doc_id not in remaining_ids for doc_id in ids_to_delete)


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_delete_docs_by_id_with_filters(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    ids_to_delete = [doc.id for doc in document_store.get_all_documents(filters={"name": ["name_1", "name_2"]})]
    ids_not_to_delete = [
        doc.id for doc in document_store.get_all_documents(filters={"name": ["name_3", "name_4", "name_5", "name_6"]})
    ]

    document_store.delete_documents(ids=ids_to_delete, filters={"name": ["name_1", "name_2", "name_3", "name_4"]})

    documents = document_store.get_all_documents()
    assert len(documents) == len(DOCUMENTS) - len(ids_to_delete)
    assert document_store.get_embedding_count() == len(DOCUMENTS) - len(ids_to_delete)

    assert all(doc.meta["name"] != "name_1" for doc in documents)
    assert all(doc.meta["name"] != "name_2" for doc in documents)

    all_ids_left = [doc.id for doc in documents]
    assert all(doc_id in all_ids_left for doc_id in ids_not_to_delete)


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_get_docs_with_filters_one_value(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    documents = document_store.get_all_documents(filters={"year": ["2020"]})

    assert len(documents) == 3
    assert all("2020" == doc.meta["year"] for doc in documents)


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_get_docs_with_filters_many_values(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    documents = document_store.get_all_documents(filters={"name": ["name_5", "name_6"]})

    assert len(documents) == 2
    assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_get_docs_with_many_filters(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]})

    assert len(documents) == 1
    assert "name_1" == documents[0].meta["name"]
    assert "01" == documents[0].meta["month"]
    assert "2020" == documents[0].meta["year"]


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus"], indirect=True)
def test_pipeline(document_store, retriever):
    documents = [
        {"name": "name_1", "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
        {"name": "name_2", "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
        {"name": "name_3", "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
        {"name": "name_4", "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
    ]
    document_store.write_documents(documents)
    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name="FAISS", inputs=["Query"])
    output = pipeline.run(query="How to test this?", params={"FAISS": {"top_k": 3}})
    assert len(output["documents"]) == 3


@pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner")
def test_faiss_passing_index_from_outside(tmp_path):
    d = 768
    nlist = 2
    quantizer = faiss.IndexFlatIP(d)
    index = "haystack_test_1"
    faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
    faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
    faiss_index.nprobe = 2
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}",
        faiss_index=faiss_index,
        index=index,
        isolation_level="AUTOCOMMIT",
    )

    document_store.delete_documents()
    # as it is an IVF index, we need to train it before adding docs
    document_store.train_index(DOCUMENTS)

    document_store.write_documents(documents=DOCUMENTS)
    documents_indexed = document_store.get_all_documents()

    # test if vector ids are associated with docs
    for doc in documents_indexed:
        assert 0 <= int(doc.meta["vector_id"]) <= 7
					
						
							| 
									
										
										
										
											2022-02-24 17:43:38 +01:00
										 |  |  | @pytest.mark.parametrize("document_store", ["faiss", "milvus1", "milvus", "weaviate"], indirect=True) | 
					
						
							| 
									
										
										
										
											2022-01-12 19:28:20 +01:00
										 |  |  | def test_cosine_similarity(document_store): | 
					
						
							| 
									
										
										
										
											2021-09-20 07:54:26 +02:00
										 |  |  |     # below we will write documents to the store and then query it to see if vectors were normalized | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     ensure_ids_are_correct_uuids(docs=DOCUMENTS, document_store=document_store) | 
					
						
							| 
									
										
										
										
											2022-01-12 19:28:20 +01:00
										 |  |  |     document_store.write_documents(documents=DOCUMENTS) | 
					
						
							| 
									
										
										
										
											2021-09-20 07:54:26 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # note that the same query will be used later when querying after updating the embeddings | 
					
						
							|  |  |  |     query = np.random.rand(768).astype(np.float32) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-12 19:28:20 +01:00
    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    # check if search with cosine similarity returns the correct number of results
    assert len(query_results) == len(DOCUMENTS)
    indexed_docs = {}
    for doc in DOCUMENTS:
        indexed_docs[doc["content"]] = doc["embedding"]
        indexed_docs[doc["content"]] /= np.linalg.norm(doc["embedding"])
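
    # with cosine similarity the store L2-normalizes embeddings at write time,
    # so the stored vectors should match these manually normalized copies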
    for doc in query_results:
        result_emb = doc.embedding
        original_emb = indexed_docs[doc.content].astype("float32")

        # check if the stored embedding was normalized
        np.testing.assert_allclose(
            original_emb, result_emb, rtol=0.2, atol=5e-07
        )  # high tolerance necessary for Milvus 2

        # check if the score is plausible for cosine similarity
        assert 0 <= doc.score <= 1.0

    # now check if vectors are normalized when updating embeddings
    class MockRetriever:
        def embed_documents(self, docs):
            return [np.random.rand(768).astype(np.float32) for doc in docs]
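    # the store's update_embeddings only calls the retriever's embed_documents,
    # so this minimal mock is enough to exercise the normalization path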
					
						
    retriever = MockRetriever()
    document_store.update_embeddings(retriever=retriever)
    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    for doc in query_results:
        original_emb = np.array([indexed_docs[doc.content]], dtype="float32")
        document_store.normalize_embedding(original_emb[0])
        # check if the original embedding has changed after updating the embeddings
        assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
					
						
@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus1", "milvus"], indirect=True)
def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
    VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32")
    document_store_dot_product_small.normalize_embedding(VEC_1)
    assert abs(np.linalg.norm(VEC_1) - 1) < 0.01

    VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1)
    document_store_dot_product_small.normalize_embedding(VEC_1)
    assert abs(np.linalg.norm(VEC_1) - 1) < 0.01
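    # i.e. normalization happens in place for both the flat (3,) shape and the
    # row-vector (1, 3) shape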
					
						
@pytest.mark.parametrize("document_store_small", ["faiss", "milvus1", "milvus", "weaviate"], indirect=True)
def test_cosine_sanity_check(document_store_small):
    VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32")
    VEC_2 = np.array([0.4, 0.5, 0.6], dtype="float32")

    # This is the cosine similarity of VEC_1 and VEC_2 calculated using sklearn.metrics.pairwise.cosine_similarity;
    # the scaled score maps it into [0, 1] via (cosine + 1) / 2.
    KNOWN_COSINE = 0.9746317
    KNOWN_SCALED_COSINE = (KNOWN_COSINE + 1) / 2
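
    # Sanity arithmetic: dot(VEC_1, VEC_2) = 0.32, |VEC_1| = sqrt(0.14),
    # |VEC_2| = sqrt(0.77), and 0.32 / sqrt(0.14 * 0.77) ~= 0.9746317.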

    docs = [{"name": "vec_1", "text": "vec_1", "content": "vec_1", "embedding": VEC_1}]
    ensure_ids_are_correct_uuids(docs=docs, document_store=document_store_small)
    document_store_small.write_documents(documents=docs)

    query_results = document_store_small.query_by_embedding(
        query_emb=VEC_2, top_k=1, return_embedding=True, scale_score=True
    )

    # check if the store returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
    assert math.isclose(query_results[0].score, KNOWN_SCALED_COSINE, abs_tol=0.00002)
					
						
    query_results = document_store_small.query_by_embedding(
        query_emb=VEC_2, top_k=1, return_embedding=True, scale_score=False
    )

    # with scale_score=False the raw cosine similarity is returned instead
    assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
					
						
@pytest.mark.integration
def test_pipeline_with_existing_faiss_docstore(tmp_path):

    document_store: FAISSDocumentStore = FAISSDocumentStore(
        sql_url=f'sqlite:///{(tmp_path / "faiss_document_store.db").absolute()}'
    )
    retriever = MockDenseRetriever(document_store=document_store)
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, update_existing_embeddings=True)

    document_store.save(tmp_path / "existing_faiss_document_store")
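    # save() persists the FAISS index (plus a companion config) to this path,
    # so a new store can later be constructed via faiss_index_path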
					
						

    query_config = f"""
version: ignore
components:
  - name: DPRRetriever
    type: MockDenseRetriever
    params:
      document_store: ExistingFAISSDocumentStore
  - name: ExistingFAISSDocumentStore
    type: FAISSDocumentStore
    params:
      faiss_index_path: '{tmp_path / "existing_faiss_document_store"}'
pipelines:
  - name: query_pipeline
    nodes:
      - name: DPRRetriever
        inputs: [Query]
    """
    pipeline = Pipeline.load_from_config(yaml.safe_load(query_config))
    existing_document_store = pipeline.get_document_store()
    faiss_index = existing_document_store.faiss_indexes["document"]
    assert faiss_index.ntotal == len(DOCUMENTS)
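
    # A possible end-to-end follow-up (a sketch, not part of the original test;
    # the node name matches the YAML above, the top_k value is an assumption):
    # result = pipeline.run(query="text_1", params={"DPRRetriever": {"top_k": 3}})
    # assert len(result["documents"]) <= 3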