mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-04 03:39:31 +00:00 
			
		
		
		
	feat: extract label aggregation (#3363)
* extract label aggregation * refactoring * reformat * add missing param docstrings * fix comment
This commit is contained in:
		
							parent
							
								
									3b0f00a615
								
							
						
					
					
						commit
						ba30971d8d
					
				@ -15,6 +15,7 @@ from haystack.nodes.base import BaseComponent
 | 
			
		||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
 | 
			
		||||
from haystack.nodes.preprocessor import PreProcessor
 | 
			
		||||
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
 | 
			
		||||
from haystack.utils.labels import aggregate_labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
@ -274,62 +275,18 @@ class BaseDocumentStore(BaseComponent):
 | 
			
		||||
                            might return multiple MultiLabel objects with the same question string.
 | 
			
		||||
        :param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
 | 
			
		||||
        :param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"]
 | 
			
		||||
        TODO drop params
 | 
			
		||||
        :param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
 | 
			
		||||
        :param drop_no_answers: When True, labels with no answers are dropped.
 | 
			
		||||
        """
 | 
			
		||||
        if aggregate_by_meta:
 | 
			
		||||
            if type(aggregate_by_meta) == str:
 | 
			
		||||
                aggregate_by_meta = [aggregate_by_meta]
 | 
			
		||||
        else:
 | 
			
		||||
            aggregate_by_meta = []
 | 
			
		||||
 | 
			
		||||
        all_labels = self.get_all_labels(index=index, filters=filters, headers=headers)
 | 
			
		||||
 | 
			
		||||
        # drop no_answers in order to not create empty MultiLabels
 | 
			
		||||
        if drop_no_answers:
 | 
			
		||||
            all_labels = [label for label in all_labels if label.no_answer == False]
 | 
			
		||||
 | 
			
		||||
        grouped_labels: dict = {}
 | 
			
		||||
        for l in all_labels:
 | 
			
		||||
            # This group_keys determines the key by which we aggregate labels. Its contents depend on
 | 
			
		||||
            # whether we are in an open / closed domain setting, on filters that are specified for labels,
 | 
			
		||||
            # or if there are fields in the meta data that we should group by dynamically (set using group_by_meta).
 | 
			
		||||
            label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else []
 | 
			
		||||
            group_keys: list = [l.query] + label_filter_keys
 | 
			
		||||
            # Filters indicate the scope within which a label is valid.
 | 
			
		||||
            # Depending on the aggregation we need to add filters dynamically.
 | 
			
		||||
            label_filters_to_add: dict = {}
 | 
			
		||||
 | 
			
		||||
            if not open_domain:
 | 
			
		||||
                group_keys.append(f"_id={l.document.id}")
 | 
			
		||||
                label_filters_to_add["_id"] = l.document.id
 | 
			
		||||
 | 
			
		||||
            for meta_key in aggregate_by_meta:
 | 
			
		||||
                meta = l.meta or {}
 | 
			
		||||
                curr_meta = meta.get(meta_key, None)
 | 
			
		||||
                if curr_meta:
 | 
			
		||||
                    curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta]
 | 
			
		||||
                    meta_str = f"{meta_key}={''.join(curr_meta)}"
 | 
			
		||||
                    group_keys.append(meta_str)
 | 
			
		||||
                    label_filters_to_add[meta_key] = curr_meta
 | 
			
		||||
 | 
			
		||||
            if label_filters_to_add:
 | 
			
		||||
                if l.filters is None:
 | 
			
		||||
                    l.filters = label_filters_to_add
 | 
			
		||||
                else:
 | 
			
		||||
                    l.filters.update(label_filters_to_add)
 | 
			
		||||
 | 
			
		||||
            group_key = tuple(group_keys)
 | 
			
		||||
            if group_key in grouped_labels:
 | 
			
		||||
                grouped_labels[group_key].append(l)
 | 
			
		||||
            else:
 | 
			
		||||
                grouped_labels[group_key] = [l]
 | 
			
		||||
 | 
			
		||||
        # Package labels that we grouped together in a MultiLabel object that allows simpler access to some
 | 
			
		||||
        # aggregated attributes like `no_answer`
 | 
			
		||||
        aggregated_labels = [
 | 
			
		||||
            MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
 | 
			
		||||
            for ls in grouped_labels.values()
 | 
			
		||||
        ]
 | 
			
		||||
        aggregated_labels = aggregate_labels(
 | 
			
		||||
            labels=all_labels,
 | 
			
		||||
            add_closed_domain_filter=not open_domain,
 | 
			
		||||
            add_meta_filters=aggregate_by_meta,
 | 
			
		||||
            drop_negative_labels=drop_negative_labels,
 | 
			
		||||
            drop_no_answers=drop_no_answers,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return aggregated_labels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -28,3 +28,4 @@ from haystack.utils.experiment_tracking import (
 | 
			
		||||
    StdoutTrackingHead,
 | 
			
		||||
)
 | 
			
		||||
from haystack.utils.early_stopping import EarlyStopping
 | 
			
		||||
from haystack.utils.labels import aggregate_labels
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										88
									
								
								haystack/utils/labels.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								haystack/utils/labels.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,88 @@
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import Dict, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
from haystack.schema import Label, MultiLabel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def aggregate_labels(
    labels: List[Label],
    add_closed_domain_filter: bool = False,
    add_meta_filters: Optional[Union[str, list]] = None,
    drop_negative_labels: bool = False,
    drop_no_answers: bool = False,
) -> List[MultiLabel]:
    """
    Aggregates Labels into MultiLabel objects (e.g. for evaluation with `Pipeline.eval()`).

    Labels are always aggregated by question and filters defined in the Label objects.
    Beyond that you have options to drop certain labels or to dynamically add filters to control the aggregation process.

    Closed domain aggregation:
    If the questions are being asked only on the document defined within the Label (i.e. SQuAD style), set `add_closed_domain_filter=True` to aggregate by question, filters and document.
    Note that Labels' filters are enriched with the document_id of the Label's document.
    Note that you don't need that step
    - if your labels already contain the document_id in their filters
    - if you're using `Pipeline.eval()`'s `add_isolated_node_eval` feature

    Dynamic metadata aggregation:
    If the questions are being asked on a subslice of your document set that is not defined with the Label's filters but with an additional meta field,
    populate `add_meta_filters` with the names of Label meta fields to aggregate by question, filters and your custom meta fields.
    Note that Labels' filters are enriched with the specified meta fields defined in the Label.
    Remarks: `add_meta_filters` is only intended for dynamic metadata aggregation (e.g. separate evaluations per document type).
    For standard question use cases, where a question is always asked on multiple files individually, consider setting the Label's filters instead.
    For example, if you want to ask a couple of standard questions for each of your products, set filters for "product_id" to your Labels.
    Thus you specify that each Label is always only valid for documents with the respective product_id.

    :param labels: List of Labels to aggregate.
    :param add_closed_domain_filter: When True, adds a filter for the document ID specified in the label.
                        Thus, labels are aggregated in a closed domain fashion based on the question text, filters,
                        and also the id of the document that the label is tied to. See "closed domain aggregation" section for more details.
    :param add_meta_filters: The names of the Label meta fields by which to aggregate in addition to question and filters. For example: ["product_id"].
                        A single field name may be passed as a plain string.
                        Note that Labels' filters are enriched with the specified meta fields defined in the Label.
    :param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
    :param drop_no_answers: When True, labels with no answers are dropped.
    :return: A list of MultiLabel objects.
    """
    # Normalize `add_meta_filters` to a list of meta field names.
    meta_keys: List[str]
    if not add_meta_filters:
        meta_keys = []
    elif isinstance(add_meta_filters, str):
        meta_keys = [add_meta_filters]
    else:
        meta_keys = list(add_meta_filters)

    # drop no_answers in order to not create empty MultiLabels
    if drop_no_answers:
        labels = [label for label in labels if not label.no_answer]

    # add filters for closed domain and dynamic metadata aggregation
    for label in labels:
        filters_to_add: dict = {}
        if add_closed_domain_filter:
            filters_to_add["_id"] = label.document.id

        label_meta = label.meta or {}
        for meta_key in meta_keys:
            meta_value = label_meta.get(meta_key)
            if meta_value:
                filters_to_add[meta_key] = meta_value if isinstance(meta_value, list) else [meta_value]

        if filters_to_add:
            # Assign a fresh dict instead of updating in place: labels that share one filters
            # dict object would otherwise overwrite each other's per-label values (e.g. `_id`).
            label.filters = {**(label.filters or {}), **filters_to_add}

    # Filters define the scope a label is valid for the query, so we group the labels by query and filters.
    grouped_labels: Dict[Tuple, List[Label]] = defaultdict(list)
    for label in labels:
        filter_keys = [f"{key}={_filter_value_to_key(value)}" for key, value in (label.filters or {}).items()]
        grouped_labels[(label.query, *filter_keys)].append(label)

    # Package each group into a MultiLabel object.
    return [
        MultiLabel(labels=group, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
        for group in grouped_labels.values()
    ]


def _filter_value_to_key(value: object) -> str:
    """
    Renders a single filter value as a string so that it can become part of a grouping key.

    Strings are returned unchanged and lists/tuples are concatenated element by element, which
    yields the same key as a plain `"".join(value)` for string-only values. Elements and any
    other value are passed through `str()`, so non-string values (e.g. integer ids) do not raise.

    :param value: The filter value to render.
    :return: The string representation used inside the grouping key.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return "".join(str(item) for item in value)
    return str(value)
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user