mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-15 19:23:25 +00:00
feat: extract label aggregation (#3363)
* extract label aggregation * refactoring * reformat * add missing param docstrings * fix comment
This commit is contained in:
parent
3b0f00a615
commit
ba30971d8d
@ -15,6 +15,7 @@ from haystack.nodes.base import BaseComponent
|
||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
||||
from haystack.nodes.preprocessor import PreProcessor
|
||||
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
||||
from haystack.utils.labels import aggregate_labels
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -274,62 +275,18 @@ class BaseDocumentStore(BaseComponent):
|
||||
might return multiple MultiLabel objects with the same question string.
|
||||
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
|
||||
:param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"]
|
||||
TODO drop params
|
||||
:param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
|
||||
:param drop_no_answers: When True, labels with no answers are dropped.
|
||||
"""
|
||||
if aggregate_by_meta:
|
||||
if type(aggregate_by_meta) == str:
|
||||
aggregate_by_meta = [aggregate_by_meta]
|
||||
else:
|
||||
aggregate_by_meta = []
|
||||
|
||||
all_labels = self.get_all_labels(index=index, filters=filters, headers=headers)
|
||||
|
||||
# drop no_answers in order to not create empty MultiLabels
|
||||
if drop_no_answers:
|
||||
all_labels = [label for label in all_labels if label.no_answer == False]
|
||||
|
||||
grouped_labels: dict = {}
|
||||
for l in all_labels:
|
||||
# This group_keys determines the key by which we aggregate labels. Its contents depend on
|
||||
# whether we are in an open / closed domain setting, on filters that are specified for labels,
|
||||
# or if there are fields in the meta data that we should group by dynamically (set using group_by_meta).
|
||||
label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else []
|
||||
group_keys: list = [l.query] + label_filter_keys
|
||||
# Filters indicate the scope within which a label is valid.
|
||||
# Depending on the aggregation we need to add filters dynamically.
|
||||
label_filters_to_add: dict = {}
|
||||
|
||||
if not open_domain:
|
||||
group_keys.append(f"_id={l.document.id}")
|
||||
label_filters_to_add["_id"] = l.document.id
|
||||
|
||||
for meta_key in aggregate_by_meta:
|
||||
meta = l.meta or {}
|
||||
curr_meta = meta.get(meta_key, None)
|
||||
if curr_meta:
|
||||
curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta]
|
||||
meta_str = f"{meta_key}={''.join(curr_meta)}"
|
||||
group_keys.append(meta_str)
|
||||
label_filters_to_add[meta_key] = curr_meta
|
||||
|
||||
if label_filters_to_add:
|
||||
if l.filters is None:
|
||||
l.filters = label_filters_to_add
|
||||
else:
|
||||
l.filters.update(label_filters_to_add)
|
||||
|
||||
group_key = tuple(group_keys)
|
||||
if group_key in grouped_labels:
|
||||
grouped_labels[group_key].append(l)
|
||||
else:
|
||||
grouped_labels[group_key] = [l]
|
||||
|
||||
# Package labels that we grouped together in a MultiLabel object that allows simpler access to some
|
||||
# aggregated attributes like `no_answer`
|
||||
aggregated_labels = [
|
||||
MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
|
||||
for ls in grouped_labels.values()
|
||||
]
|
||||
aggregated_labels = aggregate_labels(
|
||||
labels=all_labels,
|
||||
add_closed_domain_filter=not open_domain,
|
||||
add_meta_filters=aggregate_by_meta,
|
||||
drop_negative_labels=drop_negative_labels,
|
||||
drop_no_answers=drop_no_answers,
|
||||
)
|
||||
|
||||
return aggregated_labels
|
||||
|
||||
|
@ -28,3 +28,4 @@ from haystack.utils.experiment_tracking import (
|
||||
StdoutTrackingHead,
|
||||
)
|
||||
from haystack.utils.early_stopping import EarlyStopping
|
||||
from haystack.utils.labels import aggregate_labels
|
||||
|
88
haystack/utils/labels.py
Normal file
88
haystack/utils/labels.py
Normal file
@ -0,0 +1,88 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from haystack.schema import Label, MultiLabel
|
||||
|
||||
|
||||
def aggregate_labels(
|
||||
labels: List[Label],
|
||||
add_closed_domain_filter: bool = False,
|
||||
add_meta_filters: Optional[Union[str, list]] = None,
|
||||
drop_negative_labels: bool = False,
|
||||
drop_no_answers: bool = False,
|
||||
):
|
||||
"""
|
||||
Aggregates Labels into MultiLabel objects (e.g. for evaluation with `Pipeline.eval()`).
|
||||
|
||||
Labels are always aggregated by question and filters defined in the Label objects.
|
||||
Beyond that you have options to drop certain labels or to dynamically add filters to control the aggregation process.
|
||||
|
||||
Closed domain aggregation:
|
||||
If the questions are being asked only on the document defined within the Label (i.e. SQuAD style), set `add_closed_domain_filter=True` to aggregate by question, filters and document.
|
||||
Note that Labels' filters are enriched with the document_id of the Label's document.
|
||||
Note that you don't need that step
|
||||
- if your labels already contain the document_id in their filters
|
||||
- if you're using `Pipeline.eval()`'s `add_isolated_node_eval` feature
|
||||
|
||||
Dynamic metadata aggregation:
|
||||
If the questions are being asked on a subslice of your document set, that is not defined with the Label's filters but with an additional meta field,
|
||||
populate `add_meta_filters` with the names of Label meta fields to aggregate by question, filters and your custom meta fields.
|
||||
Note that Labels' filters are enriched with the specified meta fields defined in the Label.
|
||||
Remarks: `add_meta_filters` is only intended for dynamic metadata aggregation (e.g. separate evaluations per document type).
|
||||
For standard questions use-cases, where a question is always asked on multiple files individually, consider setting the Label's filters instead.
|
||||
For example, if you want to ask a couple of standard questions for each of your products, set filters for "product_id" to your Labels.
|
||||
Thus you specify that each Label is always only valid for documents with the respective product_id.
|
||||
|
||||
:param labels: List of Labels to aggregate.
|
||||
:param add_closed_domain_filter: When True, adds a filter for the document ID specified in the label.
|
||||
Thus, labels are aggregated in a closed domain fashion based on the question text, filters,
|
||||
and also the id of the document that the label is tied to. See "closed domain aggregation" section for more details.
|
||||
:param add_meta_filters: The names of the Label meta fields by which to aggregate in addition to question and filters. For example: ["product_id"].
|
||||
Note that Labels' filters are enriched with the specified meta fields defined in the Label.
|
||||
:param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
|
||||
:param drop_no_answers: When True, labels with no answers are dropped.
|
||||
:return: A list of MultiLabel objects.
|
||||
"""
|
||||
if add_meta_filters:
|
||||
if type(add_meta_filters) == str:
|
||||
add_meta_filters = [add_meta_filters]
|
||||
else:
|
||||
add_meta_filters = []
|
||||
|
||||
# drop no_answers in order to not create empty MultiLabels
|
||||
if drop_no_answers:
|
||||
labels = [label for label in labels if label.no_answer == False]
|
||||
|
||||
# add filters for closed domain and dynamic metadata aggregation
|
||||
for l in labels:
|
||||
label_filters_to_add = {}
|
||||
if add_closed_domain_filter:
|
||||
label_filters_to_add["_id"] = l.document.id
|
||||
|
||||
for meta_key in add_meta_filters:
|
||||
meta = l.meta or {}
|
||||
curr_meta = meta.get(meta_key, None)
|
||||
if curr_meta:
|
||||
curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta]
|
||||
label_filters_to_add[meta_key] = curr_meta
|
||||
|
||||
if label_filters_to_add:
|
||||
if l.filters is None:
|
||||
l.filters = label_filters_to_add
|
||||
else:
|
||||
l.filters.update(label_filters_to_add)
|
||||
|
||||
# Filters define the scope a label is valid for the query, so we group the labels by query and filters.
|
||||
grouped_labels: Dict[Tuple, List[Label]] = defaultdict(list)
|
||||
for l in labels:
|
||||
label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else []
|
||||
group_keys: list = [l.query] + label_filter_keys
|
||||
group_key = tuple(group_keys)
|
||||
grouped_labels[group_key].append(l)
|
||||
|
||||
aggregated_labels = [
|
||||
MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
|
||||
for ls in grouped_labels.values()
|
||||
]
|
||||
|
||||
return aggregated_labels
|
Loading…
x
Reference in New Issue
Block a user