| 
									
										
										
										
											2022-11-25 16:25:21 +01:00
										 |  |  | import sys | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  | import pytest | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-29 08:43:27 +01:00
										 |  |  | from haystack.schema import Document, Label, Answer, Span | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  | from haystack.errors import DuplicateDocumentError | 
					
						
							|  |  |  | from haystack.document_stores import BaseDocumentStore | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.document_store | 
					
						
							|  |  |  | class DocumentStoreBaseTestAbstract: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     This is a base class to test abstract methods from DocumentStoreBase to be inherited by any Document Store | 
					
						
							|  |  |  |     testsuite. It doesn't have the `Test` prefix in the name so that its methods won't be collected for this | 
					
						
							|  |  |  |     class but only for its subclasses. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.fixture | 
					
						
							|  |  |  |     def documents(self): | 
					
						
							|  |  |  |         documents = [] | 
					
						
							|  |  |  |         for i in range(3): | 
					
						
							|  |  |  |             documents.append( | 
					
						
							|  |  |  |                 Document( | 
					
						
							|  |  |  |                     content=f"A Foo Document {i}", | 
					
						
							|  |  |  |                     meta={"name": f"name_{i}", "year": "2020", "month": "01", "numbers": [2, 4]}, | 
					
						
							|  |  |  |                     embedding=np.random.rand(768).astype(np.float32), | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             documents.append( | 
					
						
							|  |  |  |                 Document( | 
					
						
							|  |  |  |                     content=f"A Bar Document {i}", | 
					
						
							|  |  |  |                     meta={"name": f"name_{i}", "year": "2021", "month": "02", "numbers": [-2, -4]}, | 
					
						
							|  |  |  |                     embedding=np.random.rand(768).astype(np.float32), | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             documents.append( | 
					
						
							|  |  |  |                 Document( | 
					
						
							|  |  |  |                     content=f"Document {i} without embeddings", | 
					
						
							|  |  |  |                     meta={"name": f"name_{i}", "no_embedding": True, "month": "03"}, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return documents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.fixture | 
					
						
							|  |  |  |     def labels(self, documents): | 
					
						
							|  |  |  |         labels = [] | 
					
						
							|  |  |  |         for i, d in enumerate(documents): | 
					
						
							|  |  |  |             labels.append( | 
					
						
							|  |  |  |                 Label( | 
					
						
							|  |  |  |                     query=f"query_{i}", | 
					
						
							|  |  |  |                     document=d, | 
					
						
							|  |  |  |                     is_correct_document=True, | 
					
						
							|  |  |  |                     is_correct_answer=False, | 
					
						
							|  |  |  |                     # create a mix set of labels | 
					
						
							|  |  |  |                     origin="user-feedback" if i % 2 else "gold-label", | 
					
						
							| 
									
										
										
										
											2023-02-08 08:37:22 +01:00
										 |  |  |                     answer=None if not i else Answer(f"the answer is {i}", document_ids=[d.id]), | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |                     meta={"name": f"label_{i}", "year": f"{2020 + i}"}, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return labels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # | 
					
						
							|  |  |  |     # Integration tests | 
					
						
							|  |  |  |     # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_write_documents(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         docs = ds.get_all_documents() | 
					
						
							|  |  |  |         assert len(docs) == len(documents) | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |         expected_ids = set(doc.id for doc in documents) | 
					
						
							|  |  |  |         ids = set(doc.id for doc in docs) | 
					
						
							|  |  |  |         assert ids == expected_ids | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_write_labels(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  |         assert ds.get_all_labels() == labels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_write_with_duplicate_doc_ids(self, ds): | 
					
						
							|  |  |  |         duplicate_documents = [ | 
					
						
							| 
									
										
										
										
											2022-11-15 10:04:04 +01:00
										 |  |  |             Document(content="Doc1", id_hash_keys=["content"], meta={"key1": "value1"}), | 
					
						
							|  |  |  |             Document(content="Doc1", id_hash_keys=["content"], meta={"key1": "value1"}), | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         ] | 
					
						
							|  |  |  |         ds.write_documents(duplicate_documents, duplicate_documents="skip") | 
					
						
							| 
									
										
										
										
											2022-11-15 22:02:53 +01:00
										 |  |  |         results = ds.get_all_documents() | 
					
						
							|  |  |  |         assert len(results) == 1 | 
					
						
							|  |  |  |         assert results[0] == duplicate_documents[0] | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         with pytest.raises(Exception): | 
					
						
							|  |  |  |             ds.write_documents(duplicate_documents, duplicate_documents="fail") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-30 12:38:30 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_embedding_count(self, ds, documents): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         We expect 6 docs with embeddings because only 6 documents in the documents fixture for this class contain | 
					
						
							|  |  |  |         embeddings. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         assert ds.get_embedding_count() == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |     @pytest.mark.skip | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_documents_without_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         out = ds.get_all_documents() | 
					
						
							|  |  |  |         assert out == documents | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-29 08:43:27 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_documents_without_embeddings(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         out = ds.get_all_documents(return_embedding=False) | 
					
						
							|  |  |  |         for doc in out: | 
					
						
							|  |  |  |             assert doc.embedding is None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_document_filter_duplicate_text_value(self, ds): | 
					
						
							|  |  |  |         documents = [ | 
					
						
							|  |  |  |             Document(content="duplicated", meta={"meta_field": "0"}, id_hash_keys=["meta"]), | 
					
						
							|  |  |  |             Document(content="duplicated", meta={"meta_field": "1", "name": "file.txt"}, id_hash_keys=["meta"]), | 
					
						
							|  |  |  |             Document(content="Doc2", meta={"name": "file_2.txt"}, id_hash_keys=["meta"]), | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         documents = ds.get_all_documents(filters={"meta_field": ["1"]}) | 
					
						
							|  |  |  |         assert len(documents) == 1 | 
					
						
							|  |  |  |         assert documents[0].content == "duplicated" | 
					
						
							|  |  |  |         assert documents[0].meta["name"] == "file.txt" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         documents = ds.get_all_documents(filters={"meta_field": ["0"]}) | 
					
						
							|  |  |  |         assert len(documents) == 1 | 
					
						
							|  |  |  |         assert documents[0].content == "duplicated" | 
					
						
							|  |  |  |         assert documents[0].meta.get("name") is None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         documents = ds.get_all_documents(filters={"name": ["file_2.txt"]}) | 
					
						
							|  |  |  |         assert len(documents) == 1 | 
					
						
							|  |  |  |         assert documents[0].content == "Doc2" | 
					
						
							|  |  |  |         assert documents[0].meta.get("meta_field") is None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_documents_with_correct_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"year": ["2020"]}) | 
					
						
							|  |  |  |         assert len(result) == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         documents = ds.get_all_documents(filters={"year": ["2020", "2021"]}) | 
					
						
							|  |  |  |         assert len(documents) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_documents_with_incorrect_filter_name(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"non_existing_meta_field": ["whatever"]}) | 
					
						
							|  |  |  |         assert len(result) == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_documents_with_incorrect_filter_value(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"year": ["nope"]}) | 
					
						
							|  |  |  |         assert len(result) == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     def test_eq_filters(self, ds, documents): | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"year": {"$eq": "2020"}}) | 
					
						
							|  |  |  |         assert len(result) == 3 | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"year": "2020"}) | 
					
						
							|  |  |  |         assert len(result) == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_in_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         result = ds.get_all_documents(filters={"year": {"$in": ["2020", "2021", "n.a."]}}) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"year": ["2020", "2021", "n.a."]}) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_ne_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         result = ds.get_all_documents(filters={"year": {"$ne": "2020"}}) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_nin_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         result = ds.get_all_documents(filters={"year": {"$nin": ["2020", "2021", "n.a."]}}) | 
					
						
							|  |  |  |         assert len(result) == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_comparison_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-14 09:57:30 +01:00
										 |  |  |         result = ds.get_all_documents(filters={"numbers": {"$gt": 0.0}}) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         assert len(result) == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-14 09:57:30 +01:00
										 |  |  |         result = ds.get_all_documents(filters={"numbers": {"$gte": -2.0}}) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-14 09:57:30 +01:00
										 |  |  |         result = ds.get_all_documents(filters={"numbers": {"$lt": 0.0}}) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         assert len(result) == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"numbers": {"$lte": 2.0}}) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_compound_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters={"year": {"$lte": "2021", "$gte": "2020"}}) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_simplified_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         filters = {"$and": {"year": {"$lte": "2021", "$gte": "2020"}, "name": {"$in": ["name_0", "name_1"]}}} | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters) | 
					
						
							|  |  |  |         assert len(result) == 4 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         filters_simplified = {"year": {"$lte": "2021", "$gte": "2020"}, "name": ["name_0", "name_1"]} | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters_simplified) | 
					
						
							|  |  |  |         assert len(result) == 4 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_nested_condition_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         filters = { | 
					
						
							|  |  |  |             "$and": { | 
					
						
							|  |  |  |                 "year": {"$lte": "2021", "$gte": "2020"}, | 
					
						
							|  |  |  |                 "$or": {"name": {"$in": ["name_0", "name_1"]}, "numbers": {"$lt": 5.0}}, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         filters_simplified = { | 
					
						
							|  |  |  |             "year": {"$lte": "2021", "$gte": "2020"}, | 
					
						
							|  |  |  |             "$or": {"name": {"$in": ["name_0", "name_2"]}, "numbers": {"$lt": 5.0}}, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters_simplified) | 
					
						
							|  |  |  |         assert len(result) == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         filters = { | 
					
						
							|  |  |  |             "$and": { | 
					
						
							|  |  |  |                 "year": {"$lte": "2021", "$gte": "2020"}, | 
					
						
							|  |  |  |                 "$or": { | 
					
						
							|  |  |  |                     "name": {"$in": ["name_0", "name_1"]}, | 
					
						
							|  |  |  |                     "$and": {"numbers": {"$lt": 5.0}, "$not": {"month": {"$eq": "01"}}}, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters) | 
					
						
							|  |  |  |         assert len(result) == 5 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         filters_simplified = { | 
					
						
							|  |  |  |             "year": {"$lte": "2021", "$gte": "2020"}, | 
					
						
							|  |  |  |             "$or": {"name": ["name_0", "name_1"], "$and": {"numbers": {"$lt": 5.0}, "$not": {"month": {"$eq": "01"}}}}, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters_simplified) | 
					
						
							|  |  |  |         assert len(result) == 5 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_nested_condition_not_filters(self, ds, documents): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentstore | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         filters = { | 
					
						
							|  |  |  |             "$not": { | 
					
						
							|  |  |  |                 "$or": { | 
					
						
							|  |  |  |                     "$and": {"numbers": {"$lt": 5.0}, "month": {"$ne": "01"}}, | 
					
						
							|  |  |  |                     "$not": {"year": {"$lte": "2021", "$gte": "2020"}}, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters) | 
					
						
							|  |  |  |         assert len(result) == 3 | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |         docs_meta = result[0].meta["numbers"] | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         assert [2, 4] == docs_meta | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Test same logical operator twice on same level | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         filters = { | 
					
						
							|  |  |  |             "$or": [ | 
					
						
							|  |  |  |                 {"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$gte": "2020"}}}, | 
					
						
							|  |  |  |                 {"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$lt": "2021"}}}, | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         result = ds.get_all_documents(filters=filters) | 
					
						
							|  |  |  |         docs_meta = [doc.meta["name"] for doc in result] | 
					
						
							|  |  |  |         assert len(result) == 4 | 
					
						
							|  |  |  |         assert "name_0" in docs_meta | 
					
						
							|  |  |  |         assert "name_2" not in docs_meta | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_document_by_id(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         doc = ds.get_document_by_id(documents[0].id) | 
					
						
							|  |  |  |         assert doc.id == documents[0].id | 
					
						
							|  |  |  |         assert doc.content == documents[0].content | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_documents_by_id(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         ids = [doc.id for doc in documents] | 
					
						
							|  |  |  |         result = {doc.id for doc in ds.get_documents_by_id(ids, batch_size=2)} | 
					
						
							|  |  |  |         assert set(ids) == result | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_document_count(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							| 
									
										
										
										
											2022-11-14 09:57:30 +01:00
										 |  |  |         assert ds.get_document_count() == len(documents) | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |         assert ds.get_document_count(filters={"year": ["2020"]}) == 3 | 
					
						
							|  |  |  |         assert ds.get_document_count(filters={"month": ["02"]}) == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_all_documents_generator(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         assert len(list(ds.get_all_documents_generator(batch_size=2))) == 9 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_duplicate_documents_skip(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         updated_docs = [] | 
					
						
							|  |  |  |         for d in documents: | 
					
						
							|  |  |  |             updated_d = Document.from_dict(d.to_dict()) | 
					
						
							|  |  |  |             updated_d.meta["name"] = "Updated" | 
					
						
							|  |  |  |             updated_docs.append(updated_d) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ds.write_documents(updated_docs, duplicate_documents="skip") | 
					
						
							| 
									
										
										
										
											2022-11-04 09:24:19 +01:00
										 |  |  |         for d in ds.get_all_documents(): | 
					
						
							| 
									
										
										
										
											2022-11-14 15:19:15 +01:00
										 |  |  |             assert d.meta.get("name") != "Updated" | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_duplicate_documents_overwrite(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         updated_docs = [] | 
					
						
							|  |  |  |         for d in documents: | 
					
						
							|  |  |  |             updated_d = Document.from_dict(d.to_dict()) | 
					
						
							|  |  |  |             updated_d.meta["name"] = "Updated" | 
					
						
							|  |  |  |             updated_docs.append(updated_d) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ds.write_documents(updated_docs, duplicate_documents="overwrite") | 
					
						
							|  |  |  |         for doc in ds.get_all_documents(): | 
					
						
							|  |  |  |             assert doc.meta["name"] == "Updated" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_duplicate_documents_fail(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         updated_docs = [] | 
					
						
							|  |  |  |         for d in documents: | 
					
						
							|  |  |  |             updated_d = Document.from_dict(d.to_dict()) | 
					
						
							|  |  |  |             updated_d.meta["name"] = "Updated" | 
					
						
							|  |  |  |             updated_docs.append(updated_d) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with pytest.raises(DuplicateDocumentError): | 
					
						
							|  |  |  |             ds.write_documents(updated_docs, duplicate_documents="fail") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_write_document_meta(self, ds): | 
					
						
							|  |  |  |         ds.write_documents( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 {"content": "dict_without_meta", "id": "1"}, | 
					
						
							|  |  |  |                 {"content": "dict_with_meta", "meta_field": "test2", "id": "2"}, | 
					
						
							|  |  |  |                 Document(content="document_object_without_meta", id="3"), | 
					
						
							|  |  |  |                 Document(content="document_object_with_meta", meta={"meta_field": "test4"}, id="4"), | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         assert not ds.get_document_by_id("1").meta | 
					
						
							|  |  |  |         assert ds.get_document_by_id("2").meta["meta_field"] == "test2" | 
					
						
							|  |  |  |         assert not ds.get_document_by_id("3").meta | 
					
						
							|  |  |  |         assert ds.get_document_by_id("4").meta["meta_field"] == "test4" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_documents(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         ds.delete_documents() | 
					
						
							|  |  |  |         assert ds.get_document_count() == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_documents_with_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         ds.delete_documents(filters={"year": ["2020", "2021"]}) | 
					
						
							|  |  |  |         documents = ds.get_all_documents() | 
					
						
							|  |  |  |         assert ds.get_document_count() == 3 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_documents_by_id(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]}) | 
					
						
							|  |  |  |         ds.delete_documents(ids=[doc.id for doc in docs_to_delete]) | 
					
						
							|  |  |  |         assert ds.get_document_count() == 6 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-29 08:43:27 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_documents_by_id_with_filters(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]}) | 
					
						
							|  |  |  |         # this should delete only 1 document out of the 3 ids passed | 
					
						
							|  |  |  |         ds.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"name": ["name_0"]}) | 
					
						
							|  |  |  |         assert ds.get_document_count() == 8 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_write_get_all_labels(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  |         ds.write_labels(labels[:3], index="custom_index") | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == 9 | 
					
						
							|  |  |  |         assert len(ds.get_all_labels(index="custom_index")) == 3 | 
					
						
							|  |  |  |         # remove the index we created in this test | 
					
						
							|  |  |  |         ds.delete_index("custom_index") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_labels(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  |         ds.write_labels(labels[:3], index="custom_index") | 
					
						
							|  |  |  |         ds.delete_labels() | 
					
						
							|  |  |  |         ds.delete_labels(index="custom_index") | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == 0 | 
					
						
							|  |  |  |         assert len(ds.get_all_labels(index="custom_index")) == 0 | 
					
						
							|  |  |  |         # remove the index we created in this test | 
					
						
							|  |  |  |         ds.delete_index("custom_index") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_write_labels_duplicate(self, ds, labels): | 
					
						
							|  |  |  |         # create a duplicate | 
					
						
							|  |  |  |         dupe = Label.from_dict(labels[0].to_dict()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ds.write_labels(labels + [dupe]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # ensure the duplicate was discarded | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == len(labels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_labels_by_id(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  |         ds.delete_labels(ids=[labels[0].id]) | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == len(labels) - 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_labels_by_filter(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  |         ds.delete_labels(filters={"query": "query_1"}) | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == len(labels) - 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_labels_by_filter_id(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # ids and filters are ANDed, the following should have no effect | 
					
						
							|  |  |  |         ds.delete_labels(ids=[labels[0].id], filters={"query": "query_9"}) | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == len(labels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # | 
					
						
							|  |  |  |         ds.delete_labels(ids=[labels[0].id], filters={"query": "query_0"}) | 
					
						
							|  |  |  |         assert len(ds.get_all_labels()) == len(labels) - 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_get_label_count(self, ds, labels): | 
					
						
							|  |  |  |         ds.write_labels(labels) | 
					
						
							|  |  |  |         assert ds.get_label_count() == len(labels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_delete_index(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents, index="custom_index") | 
					
						
							|  |  |  |         assert ds.get_document_count(index="custom_index") == len(documents) | 
					
						
							|  |  |  |         ds.delete_index(index="custom_index") | 
					
						
							|  |  |  |         with pytest.raises(Exception): | 
					
						
							|  |  |  |             ds.get_document_count(index="custom_index") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_update_meta(self, ds, documents): | 
					
						
							|  |  |  |         ds.write_documents(documents) | 
					
						
							|  |  |  |         doc = documents[0] | 
					
						
							|  |  |  |         ds.update_document_meta(doc.id, meta={"year": "2099", "month": "12"}) | 
					
						
							|  |  |  |         doc = ds.get_document_by_id(doc.id) | 
					
						
							|  |  |  |         assert doc.meta["year"] == "2099" | 
					
						
							|  |  |  |         assert doc.meta["month"] == "12" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-29 08:43:27 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_labels_with_long_texts(self, ds, documents): | 
					
						
							|  |  |  |         label = Label( | 
					
						
							|  |  |  |             query="question1", | 
					
						
							|  |  |  |             answer=Answer( | 
					
						
							|  |  |  |                 answer="answer", | 
					
						
							|  |  |  |                 type="extractive", | 
					
						
							|  |  |  |                 score=0.0, | 
					
						
							|  |  |  |                 context="something " * 10_000, | 
					
						
							|  |  |  |                 offsets_in_document=[Span(start=12, end=14)], | 
					
						
							|  |  |  |                 offsets_in_context=[Span(start=12, end=14)], | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |             is_correct_answer=True, | 
					
						
							|  |  |  |             is_correct_document=True, | 
					
						
							|  |  |  |             document=Document(content="something " * 10_000, id="123"), | 
					
						
							|  |  |  |             origin="gold-label", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         ds.write_labels(labels=[label]) | 
					
						
							|  |  |  |         labels = ds.get_all_labels() | 
					
						
							|  |  |  |         assert len(labels) == 1 | 
					
						
							|  |  |  |         assert label == labels[0] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-25 16:25:21 +01:00
										 |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     @pytest.mark.skipif(sys.platform == "win32", reason="_get_documents_meta() fails with 'too many SQL variables'") | 
					
						
							|  |  |  |     def test_get_all_documents_large_quantities(self, ds): | 
					
						
							|  |  |  |         # Test to exclude situations like Weaviate not returning more than 100 docs by default | 
					
						
							|  |  |  |         #   https://github.com/deepset-ai/haystack/issues/1893 | 
					
						
							|  |  |  |         docs_to_write = [ | 
					
						
							|  |  |  |             {"meta": {"name": f"name_{i}"}, "content": f"text_{i}", "embedding": np.random.rand(768).astype(np.float32)} | 
					
						
							|  |  |  |             for i in range(1000) | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  |         ds.write_documents(docs_to_write) | 
					
						
							|  |  |  |         documents = ds.get_all_documents() | 
					
						
							|  |  |  |         assert all(isinstance(d, Document) for d in documents) | 
					
						
							|  |  |  |         assert len(documents) == len(docs_to_write) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 15:30:14 +01:00
										 |  |  |     # | 
					
						
							|  |  |  |     # Unit tests | 
					
						
							|  |  |  |     # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.unit | 
					
						
							|  |  |  |     def test_normalize_embeddings_diff_shapes(self): | 
					
						
							|  |  |  |         VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32") | 
					
						
							|  |  |  |         BaseDocumentStore.normalize_embedding(VEC_1) | 
					
						
							|  |  |  |         assert np.linalg.norm(VEC_1) - 1 < 0.01 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1) | 
					
						
							|  |  |  |         BaseDocumentStore.normalize_embedding(VEC_1) | 
					
						
							|  |  |  |         assert np.linalg.norm(VEC_1) - 1 < 0.01 |