mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 01:39:45 +00:00 
			
		
		
		
	Add tests on MultiLabel's meta and filter aggregation (#2169)
This commit is contained in:
		
							parent
							
								
									fdc36292f1
								
							
						
					
					
						commit
						db4d6f43ba
					
				| @ -968,6 +968,194 @@ def test_multilabel_no_answer(document_store): | |||||||
|     assert len(multi_labels[0].answers) == 1 |     assert len(multi_labels[0].answers) == 1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_filter_aggregations(document_store):
    """Check how labels sharing one query are aggregated when their filters differ.

    Five labels with the same query are written: two positive answers and one
    negative answer with filter ``{"name": ["123"]}``, one positive answer with
    ``{"name": ["333"]}`` and one no-answer label with ``{"name": ["777"]}``.
    Aggregation is requested with ``drop_negative_labels=True`` in both
    open-domain and closed-domain mode, and three groups are expected each time.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["333"]},
        ),
        # 'no answer' label with its own document and filter value.
        # NOTE(review): an earlier comment said this label is excluded from the
        # MultiLabel, but the expected group count of 3 below only adds up if
        # its group is kept - confirm the intended behaviour.
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=True,
            origin="gold-label",
            filters={"name": ["777"]},
        ),
        # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
        Label(
            id="5-negative",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=False,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
    ]
    document_store.write_labels(labels, index="haystack_test_multilabel")
    # regular labels - not aggregated
    list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
    assert list_labels == labels
    assert len(list_labels) == 5

    # Multi labels (open domain)
    multi_labels_open = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True
    )

    # for open-domain we group all together as long as they have the same question and filters
    assert len(multi_labels_open) == 3
    # Compare a sorted list rather than a set: a set would collapse the two
    # single-label groups and could not tell [1, 1, 2] apart from [1, 2, 2].
    label_counts = sorted(len(ml.labels) for ml in multi_labels_open)
    assert label_counts == [1, 1, 2]
    # the negative label must have been dropped from every group
    aggregated_ids = {label.id for multi_label in multi_labels_open for label in multi_label.labels}
    assert "5-negative" not in aggregated_ids

    assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)

    # for closed domain we group by document so we expect the same as with filters
    multi_labels = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=False, drop_negative_labels=True
    )
    assert len(multi_labels) == 3
    label_counts = sorted(len(ml.labels) for ml in multi_labels)
    assert label_counts == [1, 1, 2]

    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
|  | 
 | ||||||
|  | 
 | ||||||
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_meta_aggregations(document_store):
    """Check that ``aggregate_by_meta`` splits labels sharing one query by a meta field.

    Five positive labels with the same query and no filters are written; their
    ``file_id`` meta values are 123 (twice), 333, 777 and 888. Without
    ``aggregate_by_meta`` a single group of five is expected; with
    ``aggregate_by_meta="file_id"`` four groups are expected, and every label's
    ``filters`` is expected to equal its ``meta``.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["123"]},
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["123"]},
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["333"]},
        ),
        # 'no answer' label; it is counted among the 5 aggregated labels asserted below
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=True,
            origin="gold-label",
            meta={"file_id": ["777"]},
        ),
        # positive answer in doc 123 but tagged with a different file_id, so it
        # forms its own group once labels are aggregated by "file_id"
        Label(
            id="5-888",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["888"]},
        ),
    ]
    document_store.write_labels(labels, index="haystack_test_multilabel")
    # regular labels - not aggregated
    list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
    assert list_labels == labels
    assert len(list_labels) == 5

    # Multi labels (open domain)
    multi_labels_open = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True
    )

    # for open-domain we group all together as long as they have the same question and filters
    assert len(multi_labels_open) == 1
    assert len(multi_labels_open[0].labels) == 5

    multi_labels = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
    )
    assert len(multi_labels) == 4
    # Compare a sorted list rather than a set: a set would collapse the three
    # single-label groups and hide a wrong distribution such as [1, 1, 2, 2].
    label_counts = sorted(len(ml.labels) for ml in multi_labels)
    assert label_counts == [1, 1, 1, 2]
    for multi_label in multi_labels:
        for label in multi_label.labels:
            assert label.filters == label.meta
            assert multi_label.filters == label.filters
|  | 
 | ||||||
|  | 
 | ||||||
| @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True) | @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True) | ||||||
| # Currently update_document_meta() is not implemented for Memory doc store | # Currently update_document_meta() is not implemented for Memory doc store | ||||||
| def test_update_meta(document_store): | def test_update_meta(document_store): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 tstadel
						tstadel