From ba30971d8d77827da9d2c81d82f7d02bf1917d8c Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Thu, 13 Oct 2022 19:09:14 +0200 Subject: [PATCH] feat: extract label aggregation (#3363) * extract label aggregation * refactoring * reformat * add missing param docstrings * fix comment --- haystack/document_stores/base.py | 63 ++++------------------- haystack/utils/__init__.py | 1 + haystack/utils/labels.py | 88 ++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 53 deletions(-) create mode 100644 haystack/utils/labels.py diff --git a/haystack/document_stores/base.py b/haystack/document_stores/base.py index e60b283af..042420107 100644 --- a/haystack/document_stores/base.py +++ b/haystack/document_stores/base.py @@ -15,6 +15,7 @@ from haystack.nodes.base import BaseComponent from haystack.errors import DuplicateDocumentError, DocumentStoreError from haystack.nodes.preprocessor import PreProcessor from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl +from haystack.utils.labels import aggregate_labels logger = logging.getLogger(__name__) @@ -274,62 +275,18 @@ class BaseDocumentStore(BaseComponent): might return multiple MultiLabel objects with the same question string. :param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication) :param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"] - TODO drop params + :param drop_negative_labels: When True, labels with incorrect answers and documents are dropped. + :param drop_no_answers: When True, labels with no answers are dropped. """ - if aggregate_by_meta: - if type(aggregate_by_meta) == str: - aggregate_by_meta = [aggregate_by_meta] - else: - aggregate_by_meta = [] - all_labels = self.get_all_labels(index=index, filters=filters, headers=headers) - # drop no_answers in order to not create empty MultiLabels - if drop_no_answers: - all_labels = [label for label in all_labels if label.no_answer == False] - - grouped_labels: dict = {} - for l in all_labels: - # This group_keys determines the key by which we aggregate labels. Its contents depend on - # whether we are in an open / closed domain setting, on filters that are specified for labels, - # or if there are fields in the meta data that we should group by dynamically (set using group_by_meta). - label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else [] - group_keys: list = [l.query] + label_filter_keys - # Filters indicate the scope within which a label is valid. - # Depending on the aggregation we need to add filters dynamically. - label_filters_to_add: dict = {} - - if not open_domain: - group_keys.append(f"_id={l.document.id}") - label_filters_to_add["_id"] = l.document.id - - for meta_key in aggregate_by_meta: - meta = l.meta or {} - curr_meta = meta.get(meta_key, None) - if curr_meta: - curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta] - meta_str = f"{meta_key}={''.join(curr_meta)}" - group_keys.append(meta_str) - label_filters_to_add[meta_key] = curr_meta - - if label_filters_to_add: - if l.filters is None: - l.filters = label_filters_to_add - else: - l.filters.update(label_filters_to_add) - - group_key = tuple(group_keys) - if group_key in grouped_labels: - grouped_labels[group_key].append(l) - else: - grouped_labels[group_key] = [l] - - # Package labels that we grouped together in a MultiLabel object that allows simpler access to some - # aggregated attributes like `no_answer` - aggregated_labels = [ - MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers) - for ls in grouped_labels.values() - ] + aggregated_labels = aggregate_labels( + labels=all_labels, + add_closed_domain_filter=not open_domain, + add_meta_filters=aggregate_by_meta, + drop_negative_labels=drop_negative_labels, + drop_no_answers=drop_no_answers, + ) return aggregated_labels diff --git a/haystack/utils/__init__.py b/haystack/utils/__init__.py index 4dc26b12b..392218a80 100644 --- a/haystack/utils/__init__.py +++ b/haystack/utils/__init__.py @@ -28,3 +28,4 @@ from haystack.utils.experiment_tracking import ( StdoutTrackingHead, ) from haystack.utils.early_stopping import EarlyStopping +from haystack.utils.labels import aggregate_labels diff --git a/haystack/utils/labels.py b/haystack/utils/labels.py new file mode 100644 index 000000000..9f8474c7a --- /dev/null +++ b/haystack/utils/labels.py @@ -0,0 +1,88 @@ +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +from haystack.schema import Label, MultiLabel + + +def aggregate_labels( + labels: List[Label], + add_closed_domain_filter: bool = False, + add_meta_filters: Optional[Union[str, list]] = None, + drop_negative_labels: bool = False, + drop_no_answers: bool = False, +): + """ + Aggregates Labels into MultiLabel objects (e.g. for evaluation with `Pipeline.eval()`). + + Labels are always aggregated by question and filters defined in the Label objects. + Beyond that you have options to drop certain labels or to dynamically add filters to control the aggregation process. + + Closed domain aggregation: + If the questions are being asked only on the document defined within the Label (i.e. SQuAD style), set `add_closed_domain_filter=True` to aggregate by question, filters and document. + Note that Labels' filters are enriched with the document_id of the Label's document. + Note that you don't need that step + - if your labels already contain the document_id in their filters + - if you're using `Pipeline.eval()`'s `add_isolated_node_eval` feature + + Dynamic metadata aggregation: + If the questions are being asked on a subslice of your document set, that is not defined with the Label's filters but with an additional meta field, + populate `add_meta_filters` with the names of Label meta fields to aggregate by question, filters and your custom meta fields. + Note that Labels' filters are enriched with the specified meta fields defined in the Label. + Remarks: `add_meta_filters` is only intended for dynamic metadata aggregation (e.g. separate evaluations per document type). + For standard questions use-cases, where a question is always asked on multiple files individually, consider setting the Label's filters instead. + For example, if you want to ask a couple of standard questions for each of your products, set filters for "product_id" to your Labels. + Thus you specify that each Label is always only valid for documents with the respective product_id. + + :param labels: List of Labels to aggregate. + :param add_closed_domain_filter: When True, adds a filter for the document ID specified in the label. + Thus, labels are aggregated in a closed domain fashion based on the question text, filters, + and also the id of the document that the label is tied to. See "closed domain aggregation" section for more details. + :param add_meta_filters: The names of the Label meta fields by which to aggregate in addition to question and filters. For example: ["product_id"]. + Note that Labels' filters are enriched with the specified meta fields defined in the Label. + :param drop_negative_labels: When True, labels with incorrect answers and documents are dropped. + :param drop_no_answers: When True, labels with no answers are dropped. + :return: A list of MultiLabel objects. + """ + if add_meta_filters: + if type(add_meta_filters) == str: + add_meta_filters = [add_meta_filters] + else: + add_meta_filters = [] + + # drop no_answers in order to not create empty MultiLabels + if drop_no_answers: + labels = [label for label in labels if label.no_answer == False] + + # add filters for closed domain and dynamic metadata aggregation + for l in labels: + label_filters_to_add = {} + if add_closed_domain_filter: + label_filters_to_add["_id"] = l.document.id + + for meta_key in add_meta_filters: + meta = l.meta or {} + curr_meta = meta.get(meta_key, None) + if curr_meta: + curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta] + label_filters_to_add[meta_key] = curr_meta + + if label_filters_to_add: + if l.filters is None: + l.filters = label_filters_to_add + else: + l.filters.update(label_filters_to_add) + + # Filters define the scope a label is valid for the query, so we group the labels by query and filters. + grouped_labels: Dict[Tuple, List[Label]] = defaultdict(list) + for l in labels: + label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else [] + group_keys: list = [l.query] + label_filter_keys + group_key = tuple(group_keys) + grouped_labels[group_key].append(l) + + aggregated_labels = [ + MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers) + for ls in grouped_labels.values() + ] + + return aggregated_labels