feat: extract label aggregation (#3363)

* extract label aggregation

* refactoring

* reformat

* add missing param docstrings

* fix comment
This commit is contained in:
tstadel 2022-10-13 19:09:14 +02:00 committed by GitHub
parent 3b0f00a615
commit ba30971d8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 53 deletions

View File

@ -15,6 +15,7 @@ from haystack.nodes.base import BaseComponent
from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.nodes.preprocessor import PreProcessor
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
from haystack.utils.labels import aggregate_labels
logger = logging.getLogger(__name__)
@ -274,62 +275,18 @@ class BaseDocumentStore(BaseComponent):
might return multiple MultiLabel objects with the same question string.
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
:param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"]
TODO drop params
:param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
:param drop_no_answers: When True, labels with no answers are dropped.
"""
if aggregate_by_meta:
if type(aggregate_by_meta) == str:
aggregate_by_meta = [aggregate_by_meta]
else:
aggregate_by_meta = []
all_labels = self.get_all_labels(index=index, filters=filters, headers=headers)
# drop no_answers in order to not create empty MultiLabels
if drop_no_answers:
all_labels = [label for label in all_labels if label.no_answer == False]
grouped_labels: dict = {}
for l in all_labels:
# This group_keys determines the key by which we aggregate labels. Its contents depend on
# whether we are in an open / closed domain setting, on filters that are specified for labels,
# or if there are fields in the meta data that we should group by dynamically (set using group_by_meta).
label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else []
group_keys: list = [l.query] + label_filter_keys
# Filters indicate the scope within which a label is valid.
# Depending on the aggregation we need to add filters dynamically.
label_filters_to_add: dict = {}
if not open_domain:
group_keys.append(f"_id={l.document.id}")
label_filters_to_add["_id"] = l.document.id
for meta_key in aggregate_by_meta:
meta = l.meta or {}
curr_meta = meta.get(meta_key, None)
if curr_meta:
curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta]
meta_str = f"{meta_key}={''.join(curr_meta)}"
group_keys.append(meta_str)
label_filters_to_add[meta_key] = curr_meta
if label_filters_to_add:
if l.filters is None:
l.filters = label_filters_to_add
else:
l.filters.update(label_filters_to_add)
group_key = tuple(group_keys)
if group_key in grouped_labels:
grouped_labels[group_key].append(l)
else:
grouped_labels[group_key] = [l]
# Package labels that we grouped together in a MultiLabel object that allows simpler access to some
# aggregated attributes like `no_answer`
aggregated_labels = [
MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
for ls in grouped_labels.values()
]
aggregated_labels = aggregate_labels(
labels=all_labels,
add_closed_domain_filter=not open_domain,
add_meta_filters=aggregate_by_meta,
drop_negative_labels=drop_negative_labels,
drop_no_answers=drop_no_answers,
)
return aggregated_labels

View File

@ -28,3 +28,4 @@ from haystack.utils.experiment_tracking import (
StdoutTrackingHead,
)
from haystack.utils.early_stopping import EarlyStopping
from haystack.utils.labels import aggregate_labels

88
haystack/utils/labels.py Normal file
View File

@ -0,0 +1,88 @@
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from haystack.schema import Label, MultiLabel
def _filter_group_key(name: object, value: object) -> str:
    """
    Build the grouping-key fragment for a single filter entry.

    A bare value is treated like a one-element list, so `"abc"` and `["abc"]` still end up in the same group.
    The values are rendered with `repr` instead of being concatenated with `''.join`, because `str.join`
    raises a TypeError for non-string items (e.g. an integer meta value), merges distinct lists such as
    `["a", "bc"]` and `["ab", "c"]` into the same key, and only looks at the keys of a dict value.

    :param name: Name of the filter field.
    :param value: Filter value of any type (scalar, list, nested dict, ...).
    :return: A string of the form `name=[value, ...]`.
    """
    values = value if isinstance(value, list) else [value]
    return f"{name}={values!r}"


def aggregate_labels(
    labels: List[Label],
    add_closed_domain_filter: bool = False,
    add_meta_filters: Optional[Union[str, list]] = None,
    drop_negative_labels: bool = False,
    drop_no_answers: bool = False,
) -> List[MultiLabel]:
    """
    Aggregates Labels into MultiLabel objects (e.g. for evaluation with `Pipeline.eval()`).

    Labels are always aggregated by question and filters defined in the Label objects.
    Beyond that you have options to drop certain labels or to dynamically add filters to control the aggregation process.

    Closed domain aggregation:
        If the questions are being asked only on the document defined within the Label (i.e. SQuAD style),
        set `add_closed_domain_filter=True` to aggregate by question, filters and document.
        Note that Labels' filters are enriched with the document_id of the Label's document.
        Note that you don't need that step
        - if your labels already contain the document_id in their filters
        - if you're using `Pipeline.eval()`'s `add_isolated_node_eval` feature

    Dynamic metadata aggregation:
        If the questions are being asked on a subslice of your document set that is not defined with the Label's filters
        but with an additional meta field, populate `add_meta_filters` with the names of Label meta fields
        to aggregate by question, filters and your custom meta fields.
        Note that Labels' filters are enriched with the specified meta fields defined in the Label.
        Remarks: `add_meta_filters` is only intended for dynamic metadata aggregation (e.g. separate evaluations per document type).
        For standard questions use-cases, where a question is always asked on multiple files individually,
        consider setting the Label's filters instead.
        For example, if you want to ask a couple of standard questions for each of your products,
        set filters for "product_id" to your Labels.
        Thus you specify that each Label is always only valid for documents with the respective product_id.

    Side effect: the `filters` attribute of the passed Label objects is modified in place whenever
    a closed domain filter or a meta filter is added.

    :param labels: List of Labels to aggregate.
    :param add_closed_domain_filter: When True, adds a filter for the document ID specified in the label.
                                     Thus, labels are aggregated in a closed domain fashion based on the question text, filters,
                                     and also the id of the document that the label is tied to.
                                     See "closed domain aggregation" section for more details.
    :param add_meta_filters: The names of the Label meta fields by which to aggregate in addition to question and filters.
                             For example: ["product_id"]. A single name can be passed as a plain string.
                             Note that Labels' filters are enriched with the specified meta fields defined in the Label.
    :param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
    :param drop_no_answers: When True, labels with no answers are dropped.
    :return: A list of MultiLabel objects, one per distinct combination of query and filters,
             in the order in which each combination first occurs in `labels`.
    """
    # Normalize `add_meta_filters` to a list of meta field names.
    if not add_meta_filters:
        meta_keys: list = []
    elif isinstance(add_meta_filters, str):
        meta_keys = [add_meta_filters]
    else:
        meta_keys = add_meta_filters

    # Drop no_answers up front in order to not create empty MultiLabels.
    # The explicit comparison is intentional: only labels whose `no_answer` equals False are kept,
    # so a label with `no_answer=None` is dropped as well.
    if drop_no_answers:
        labels = [label for label in labels if label.no_answer == False]  # noqa: E712

    # Add filters for closed domain and dynamic metadata aggregation.
    for label in labels:
        label_filters_to_add = {}
        if add_closed_domain_filter:
            label_filters_to_add["_id"] = label.document.id

        meta = label.meta or {}
        for meta_key in meta_keys:
            curr_meta = meta.get(meta_key, None)
            # Missing or empty meta values do not narrow the scope of the label.
            if curr_meta:
                label_filters_to_add[meta_key] = curr_meta if isinstance(curr_meta, list) else [curr_meta]

        if label_filters_to_add:
            if label.filters is None:
                label.filters = label_filters_to_add
            else:
                label.filters.update(label_filters_to_add)

    # Filters define the scope a label is valid for the query, so we group the labels by query and filters.
    grouped_labels: Dict[Tuple, List[Label]] = defaultdict(list)
    for label in labels:
        filter_keys = [_filter_group_key(k, v) for k, v in label.filters.items()] if label.filters else []
        grouped_labels[(label.query, *filter_keys)].append(label)

    # Package each group in a MultiLabel object.
    return [
        MultiLabel(labels=group, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
        for group in grouped_labels.values()
    ]