mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-16 19:53:23 +00:00
feat: extract label aggregation (#3363)
* extract label aggregation * refactoring * reformat * add missing param docstrings * fix comment
This commit is contained in:
parent
3b0f00a615
commit
ba30971d8d
@ -15,6 +15,7 @@ from haystack.nodes.base import BaseComponent
|
|||||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
||||||
from haystack.nodes.preprocessor import PreProcessor
|
from haystack.nodes.preprocessor import PreProcessor
|
||||||
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
||||||
|
from haystack.utils.labels import aggregate_labels
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -274,62 +275,18 @@ class BaseDocumentStore(BaseComponent):
|
|||||||
might return multiple MultiLabel objects with the same question string.
|
might return multiple MultiLabel objects with the same question string.
|
||||||
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
|
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
|
||||||
:param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"]
|
:param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"]
|
||||||
TODO drop params
|
:param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
|
||||||
|
:param drop_no_answers: When True, labels with no answers are dropped.
|
||||||
"""
|
"""
|
||||||
if aggregate_by_meta:
|
|
||||||
if type(aggregate_by_meta) == str:
|
|
||||||
aggregate_by_meta = [aggregate_by_meta]
|
|
||||||
else:
|
|
||||||
aggregate_by_meta = []
|
|
||||||
|
|
||||||
all_labels = self.get_all_labels(index=index, filters=filters, headers=headers)
|
all_labels = self.get_all_labels(index=index, filters=filters, headers=headers)
|
||||||
|
|
||||||
# drop no_answers in order to not create empty MultiLabels
|
aggregated_labels = aggregate_labels(
|
||||||
if drop_no_answers:
|
labels=all_labels,
|
||||||
all_labels = [label for label in all_labels if label.no_answer == False]
|
add_closed_domain_filter=not open_domain,
|
||||||
|
add_meta_filters=aggregate_by_meta,
|
||||||
grouped_labels: dict = {}
|
drop_negative_labels=drop_negative_labels,
|
||||||
for l in all_labels:
|
drop_no_answers=drop_no_answers,
|
||||||
# This group_keys determines the key by which we aggregate labels. Its contents depend on
|
)
|
||||||
# whether we are in an open / closed domain setting, on filters that are specified for labels,
|
|
||||||
# or if there are fields in the meta data that we should group by dynamically (set using group_by_meta).
|
|
||||||
label_filter_keys = [f"{k}={''.join(v)}" for k, v in l.filters.items()] if l.filters else []
|
|
||||||
group_keys: list = [l.query] + label_filter_keys
|
|
||||||
# Filters indicate the scope within which a label is valid.
|
|
||||||
# Depending on the aggregation we need to add filters dynamically.
|
|
||||||
label_filters_to_add: dict = {}
|
|
||||||
|
|
||||||
if not open_domain:
|
|
||||||
group_keys.append(f"_id={l.document.id}")
|
|
||||||
label_filters_to_add["_id"] = l.document.id
|
|
||||||
|
|
||||||
for meta_key in aggregate_by_meta:
|
|
||||||
meta = l.meta or {}
|
|
||||||
curr_meta = meta.get(meta_key, None)
|
|
||||||
if curr_meta:
|
|
||||||
curr_meta = curr_meta if isinstance(curr_meta, list) else [curr_meta]
|
|
||||||
meta_str = f"{meta_key}={''.join(curr_meta)}"
|
|
||||||
group_keys.append(meta_str)
|
|
||||||
label_filters_to_add[meta_key] = curr_meta
|
|
||||||
|
|
||||||
if label_filters_to_add:
|
|
||||||
if l.filters is None:
|
|
||||||
l.filters = label_filters_to_add
|
|
||||||
else:
|
|
||||||
l.filters.update(label_filters_to_add)
|
|
||||||
|
|
||||||
group_key = tuple(group_keys)
|
|
||||||
if group_key in grouped_labels:
|
|
||||||
grouped_labels[group_key].append(l)
|
|
||||||
else:
|
|
||||||
grouped_labels[group_key] = [l]
|
|
||||||
|
|
||||||
# Package labels that we grouped together in a MultiLabel object that allows simpler access to some
|
|
||||||
# aggregated attributes like `no_answer`
|
|
||||||
aggregated_labels = [
|
|
||||||
MultiLabel(labels=ls, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
|
|
||||||
for ls in grouped_labels.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
return aggregated_labels
|
return aggregated_labels
|
||||||
|
|
||||||
|
@ -28,3 +28,4 @@ from haystack.utils.experiment_tracking import (
|
|||||||
StdoutTrackingHead,
|
StdoutTrackingHead,
|
||||||
)
|
)
|
||||||
from haystack.utils.early_stopping import EarlyStopping
|
from haystack.utils.early_stopping import EarlyStopping
|
||||||
|
from haystack.utils.labels import aggregate_labels
|
||||||
|
88
haystack/utils/labels.py
Normal file
88
haystack/utils/labels.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from haystack.schema import Label, MultiLabel
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_labels(
    labels: List[Label],
    add_closed_domain_filter: bool = False,
    add_meta_filters: Optional[Union[str, list]] = None,
    drop_negative_labels: bool = False,
    drop_no_answers: bool = False,
) -> List[MultiLabel]:
    """
    Aggregates Labels into MultiLabel objects (e.g. for evaluation with `Pipeline.eval()`).

    Labels are always aggregated by question and by the filters defined in the Label objects.
    Beyond that, you can drop certain labels or dynamically add filters to control the aggregation process.

    Closed domain aggregation:
        If the questions are only asked on the document defined within the Label (i.e. SQuAD style),
        set `add_closed_domain_filter=True` to aggregate by question, filters and document.
        Note that the Labels' filters are enriched (in place) with the id of the Label's document under the key "_id".
        You don't need this step
        - if your labels already contain the document id in their filters
        - if you're using `Pipeline.eval()`'s `add_isolated_node_eval` feature

    Dynamic metadata aggregation:
        If the questions are asked on a subslice of your document set that is not defined by the Label's filters
        but by an additional meta field, populate `add_meta_filters` with the names of the Label meta fields
        to aggregate by question, filters and your custom meta fields.
        Note that the Labels' filters are enriched (in place) with the specified meta fields defined in the Label.
        Remarks: `add_meta_filters` is only intended for dynamic metadata aggregation
        (e.g. separate evaluations per document type).
        For standard-question use cases, where a question is always asked on multiple files individually,
        consider setting the Label's filters instead.
        For example, if you want to ask a couple of standard questions for each of your products,
        set filters for "product_id" on your Labels.
        Thus you specify that each Label is only valid for documents with the respective product_id.

    :param labels: List of Labels to aggregate. The `filters` of these Labels are modified in place
                   when `add_closed_domain_filter` or `add_meta_filters` is used.
    :param add_closed_domain_filter: When True, adds a filter for the document ID specified in the label.
                                     Thus, labels are aggregated in a closed domain fashion based on the question
                                     text, filters, and also the id of the document that the label is tied to.
                                     See the "Closed domain aggregation" section for more details.
    :param add_meta_filters: The name(s) of the Label meta fields by which to aggregate in addition to question
                             and filters. Either a single name or a list of names. For example: ["product_id"].
                             Note that Labels' filters are enriched with the specified meta fields defined in the Label.
    :param drop_negative_labels: When True, labels with incorrect answers and documents are dropped.
    :param drop_no_answers: When True, labels with no answers are dropped.
    :return: A list of MultiLabel objects, one per distinct combination of query and filters.
    """
    # Normalize `add_meta_filters` to a list of meta field names (a bare string means a single field).
    if not add_meta_filters:
        meta_keys: list = []
    elif isinstance(add_meta_filters, str):
        meta_keys = [add_meta_filters]
    else:
        meta_keys = list(add_meta_filters)

    # Drop no_answers up front in order to not create empty MultiLabels.
    if drop_no_answers:
        # The explicit comparison is intentional: only labels whose `no_answer` equals False are kept,
        # so a label with `no_answer=None` is dropped exactly as before.
        labels = [label for label in labels if label.no_answer == False]  # noqa: E712

    # Add filters for closed domain and dynamic metadata aggregation.
    for label in labels:
        filters_to_add: dict = {}
        if add_closed_domain_filter:
            filters_to_add["_id"] = label.document.id

        meta = label.meta or {}
        for meta_key in meta_keys:
            meta_value = meta.get(meta_key)
            if meta_value:
                # Filter values for meta fields are always stored as lists.
                filters_to_add[meta_key] = meta_value if isinstance(meta_value, list) else [meta_value]

        if filters_to_add:
            if label.filters is None:
                label.filters = filters_to_add
            else:
                label.filters.update(filters_to_add)

    # Filters define the scope in which a label is valid for the query, so we group the labels by query and filters.
    grouped_labels: Dict[Tuple, List[Label]] = defaultdict(list)
    for label in labels:
        filter_keys = (
            [f"{key}={_filter_value_to_key(value)}" for key, value in label.filters.items()] if label.filters else []
        )
        grouped_labels[(label.query, *filter_keys)].append(label)

    # Package each group in a MultiLabel, which gives simpler access to aggregated attributes.
    return [
        MultiLabel(labels=group, drop_negative_labels=drop_negative_labels, drop_no_answers=drop_no_answers)
        for group in grouped_labels.values()
    ]


def _filter_value_to_key(value: object) -> str:
    """Render a single filter value as a string fragment of a grouping key."""
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        # Stringify every item: `str.join` raises TypeError on non-string items (e.g. numeric ids).
        return "".join(str(item) for item in value)
    return str(value)
|
Loading…
x
Reference in New Issue
Block a user