diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md
index c0200de07..4d898f0dd 100644
--- a/docs/_src/api/api/pipelines.md
+++ b/docs/_src/api/api/pipelines.md
@@ -162,7 +162,7 @@ Runs the pipeline, one node at a time.
#### eval
```python
- | eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: str = None) -> EvaluationResult
+ | eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult
```
Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -186,6 +186,14 @@ and putting together all data that is needed for evaluation, e.g. calculating me
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
- Large model for German only: "deepset/gbert-large-sts"
+- `add_isolated_node_eval`: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
+ This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
+ If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
+ If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
+ The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
+ To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
+ The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results
+ by the values "integrated" or "isolated" in the column "eval_mode". The evaluation report then additionally lists the upper bound of each node's evaluation metrics.
#### get\_nodes\_by\_class
@@ -627,7 +635,7 @@ Instance of DocumentStore or None
#### eval
```python
- | eval(labels: List[MultiLabel], params: Optional[dict], sas_model_name_or_path: str = None) -> EvaluationResult
+ | eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: Optional[str] = None, add_isolated_node_eval: bool = False) -> EvaluationResult
```
Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -640,6 +648,7 @@ and putting together all data that is needed for evaluation, e.g. calculating me
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
- `sas_model_name_or_path`: SentenceTransformers semantic textual similarity model to be used for sas value calculation,
should be path or string pointing to downloadable models.
+- `add_isolated_node_eval`: Whether to additionally evaluate the reader based on labels as input, instead of the output of the previous node in the pipeline
## ExtractiveQAPipeline
diff --git a/docs/_src/api/api/primitives.md b/docs/_src/api/api/primitives.md
index 9e2a3e3ed..c1e95c829 100644
--- a/docs/_src/api/api/primitives.md
+++ b/docs/_src/api/api/primitives.md
@@ -294,7 +294,7 @@ The DataFrames have the following schema:
#### calculate\_metrics
```python
- | calculate_metrics(simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", node_input: str = "prediction") -> Dict[str, Dict[str, float]]
+ | calculate_metrics(simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", eval_mode: str = "integrated") -> Dict[str, Dict[str, float]]
```
Calculates proper metrics for each node.
@@ -324,19 +324,19 @@ as there are situations the result can heavily differ from an actual eval run wi
remarks: there might be a discrepancy between simulated reader metrics and an actual pipeline run with retriever top_k
- `doc_relevance_col`: column in the underlying eval table that contains the relevance criteria for documents.
values can be: 'gold_id_match', 'answer_match', 'gold_id_or_answer_match'
-- `node_input`: the input on which the node was evaluated on.
- Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+- `eval_mode`: the input the node was evaluated on.
+ Usually nodes get evaluated on the prediction provided by their predecessor nodes in the pipeline (value='integrated').
However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
- For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
- Values can be 'prediction', 'label'.
- Default value is 'prediction'.
+ For example, when evaluating the reader, use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+ Values can be 'integrated', 'isolated'.
+ Default value is 'integrated'.
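The intended reading of the two modes, as a sketch: assuming `eval_result` was produced with `add_isolated_node_eval=True`, the gap between isolated and integrated metrics tells you which node to optimize (the 0.1 threshold is an arbitrary example value).

```python
# Compare integrated vs. isolated Reader metrics to locate the bottleneck.
integrated = eval_result.calculate_metrics(eval_mode="integrated")
isolated = eval_result.calculate_metrics(eval_mode="isolated")  # upper bound

f1_gap = isolated["Reader"]["f1"] - integrated["Reader"]["f1"]
if f1_gap > 0.1:  # arbitrary example threshold
    print("The Reader is held back by its input -> optimize the Retriever.")
else:
    print("The Reader itself is the limiting factor -> optimize the Reader.")
```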
#### wrong\_examples
```python
- | wrong_examples(node: str, n: int = 3, simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", document_metric: str = "recall_single_hit", answer_metric: str = "f1", node_input: str = "prediction") -> List[Dict]
+ | wrong_examples(node: str, n: int = 3, simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", document_metric: str = "recall_single_hit", answer_metric: str = "f1", eval_mode: str = "integrated") -> List[Dict]
```
Returns the worst performing queries.
@@ -357,13 +357,13 @@ See calculate_metrics() for more information.
values can be: 'recall_single_hit', 'recall_multi_hit', 'mrr', 'map', 'precision'
- `document_metric`: the answer metric worst queries are calculated with.
values can be: 'f1', 'exact_match' and 'sas' if the evaluation was made using a SAS model.
-- `node_input`: the input on which the node was evaluated on.
- Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+- `eval_mode`: the input the node was evaluated on.
+ Usually nodes get evaluated on the prediction provided by their predecessor nodes in the pipeline (value='integrated').
However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
- For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
- Values can be 'prediction', 'label'.
- Default value is 'prediction'.
+ For example, when evaluating the reader, use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+ Values can be 'integrated', 'isolated'.
+ Default value is 'integrated'.
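A short sketch of the corresponding call, again assuming `eval_result` contains isolated rows and a node named "Reader":

```python
# Worst-performing queries for the Reader, once on the retriever's output
# and once on the gold (label) documents.
worst_integrated = eval_result.wrong_examples(node="Reader", n=3, eval_mode="integrated")
worst_isolated = eval_result.wrong_examples(node="Reader", n=3, eval_mode="isolated")
```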
#### save
diff --git a/haystack/nodes/reader/base.py b/haystack/nodes/reader/base.py
index e154fc0c0..11b453e86 100644
--- a/haystack/nodes/reader/base.py
+++ b/haystack/nodes/reader/base.py
@@ -7,7 +7,7 @@ from copy import deepcopy
from functools import wraps
from time import perf_counter
-from haystack.schema import Document, Answer, Span
+from haystack.schema import Document, Answer, Span, MultiLabel
from haystack.nodes.base import BaseComponent
@@ -55,7 +55,22 @@ class BaseReader(BaseComponent):
return no_ans_prediction, max_no_ans_gap
- def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): # type: ignore
+ @staticmethod
+ def add_doc_meta_data_to_answer(documents: List[Document], answer):
+ # Add corresponding document_name and more meta data, if the answer contains the document_id
+ if answer.meta is None:
+ answer.meta = {}
+ # get meta from doc
+ meta_from_doc = {}
+ for doc in documents:
+ if doc.id == answer.document_id:
+ meta_from_doc = deepcopy(doc.meta)
+ break
+ # append to "own" meta
+ answer.meta.update(meta_from_doc)
+ return answer
+
+ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False): # type: ignore
self.query_count += 1
if documents:
predict = self.timing(self.predict, "query_time")
@@ -64,17 +79,15 @@ class BaseReader(BaseComponent):
results = {"answers": []}
# Add corresponding document_name and more meta data, if an answer contains the document_id
- for ans in results["answers"]:
- if ans.meta is None:
- ans.meta = {}
- # get meta from doc
- meta_from_doc = {}
- for doc in documents:
- if doc.id == ans.document_id:
- meta_from_doc = deepcopy(doc.meta)
- break
- # append to "own" meta
- ans.meta.update(meta_from_doc)
+ results["answers"] = [BaseReader.add_doc_meta_data_to_answer(documents=documents, answer=answer) for answer in results["answers"]]
+
+ # run evaluation with labels as node inputs
+ if add_isolated_node_eval and labels is not None:
+ relevant_documents = [label.document for label in labels.labels]
+ results_label_input = predict(query=query, documents=relevant_documents, top_k=top_k)
+
+ # Add corresponding document_name and more meta data, if an answer contains the document_id
+ results["answers_isolated"] = [BaseReader.add_doc_meta_data_to_answer(documents=documents, answer=answer) for answer in results_label_input["answers"]]
return results, "output_1"
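To illustrate the extended `run()` signature, a sketch of calling a reader node directly; `reader` (any `BaseReader` subclass), `docs` (retrieved Documents), and `multilabel` (a `MultiLabel` whose labels carry the gold documents) are assumptions, not taken from the diff.

```python
# The reader returns an extra "answers_isolated" key when labels are passed
# and add_isolated_node_eval is True.
results, edge = reader.run(
    query="Who lives in Berlin?",     # example query
    documents=docs,                   # documents coming from the retriever
    labels=multilabel,                # gold labels incl. gold documents
    add_isolated_node_eval=True,
)
print(results["answers"])             # predictions on the retrieved documents
print(results["answers_isolated"])    # predictions on the gold (label) documents
```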
diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py
index fdfe07149..72e5e6398 100644
--- a/haystack/pipelines/base.py
+++ b/haystack/pipelines/base.py
@@ -368,7 +368,8 @@ class Pipeline(BasePipeline):
self,
labels: List[MultiLabel],
params: Optional[dict] = None,
- sas_model_name_or_path: str = None
+ sas_model_name_or_path: str = None,
+ add_isolated_node_eval: bool = False
) -> EvaluationResult:
"""
Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -390,8 +391,20 @@ class Pipeline(BasePipeline):
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
- Large model for German only: "deepset/gbert-large-sts"
+ :param add_isolated_node_eval: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
+ This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
+ If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
+ If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
+ The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
+ To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
+ The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results
+ by the values "integrated" or "isolated" in the column "eval_mode". The evaluation report then additionally lists the upper bound of each node's evaluation metrics.
"""
eval_result = EvaluationResult()
+ if add_isolated_node_eval:
+ if params is None:
+ params = {}
+ params["add_isolated_node_eval"] = True
queries = [label.query for label in labels]
for query, label in zip(queries, labels):
predictions = self.run(query=query, labels=label, params=params, debug=True)
@@ -420,7 +433,7 @@ class Pipeline(BasePipeline):
"gold_document_contents", "content", "gold_id_match", "answer_match", "gold_id_or_answer_match", # doc-specific
"rank", "document_id", "gold_document_ids", # generic
"offsets_in_document", "gold_offsets_in_documents", # answer-specific
- "type", "node", "node_input"] # generic
+ "type", "node", "eval_mode"] # generic
eval_result.node_results[key] = self._reorder_columns(df, desired_col_order)
return eval_result
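Since the per-node dataframes now carry an `eval_mode` column instead of `node_input`, the two result sets can also be separated directly on the dataframe level (sketch, assuming a node named "Reader"):

```python
reader_df = eval_result.node_results["Reader"]
integrated_rows = reader_df[reader_df["eval_mode"] == "integrated"]
isolated_rows = reader_df[reader_df["eval_mode"] == "isolated"]  # empty unless add_isolated_node_eval=True
```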
@@ -447,11 +460,10 @@ class Pipeline(BasePipeline):
Additional answer or document specific evaluation infos like gold labels
and metrics depicting whether the row matches the gold labels are included, too.
"""
- df: DataFrame = pd.DataFrame()
if query_labels is None or query_labels.labels is None:
logger.warning(f"There is no label for query '{query}'. Query will be omitted.")
- return df
+ return pd.DataFrame()
# remarks for no_answers:
# Single 'no_answer'-labels are not contained in MultiLabel aggregates.
@@ -467,27 +479,37 @@ class Pipeline(BasePipeline):
# - the position or offsets within the document the answer was found
# - the surrounding context of the answer within the document
# - the gold answers
- # - the positon or offsets of the gold answer within the document
+ # - the position or offsets of the gold answer within the document
# - the gold document ids containing the answer
# - the exact_match metric depicting if the answer exactly matches the gold label
# - the f1 metric depicting how well the answer overlaps with the gold label on token basis
- # - the sas metric depciting how well the answer matches the gold label on a semantic basis.
+ # - the sas metric depicting how well the answer matches the gold label on a semantic basis.
# this will be calculated on all queries in eval() for performance reasons if a sas model has been provided
- answers = node_output.get("answers", None)
- if answers is not None:
- answer_cols_to_keep = ["answer", "document_id", "offsets_in_document", "context"]
- df_answers = pd.DataFrame(answers, columns=answer_cols_to_keep)
- if len(df_answers) > 0:
- df_answers["type"] = "answer"
- df_answers["gold_answers"] = [gold_answers] * len(df_answers)
- df_answers["gold_offsets_in_documents"] = [gold_offsets_in_documents] * len(df_answers)
- df_answers["gold_document_ids"] = [gold_document_ids] * len(df_answers)
- df_answers["exact_match"] = df_answers.apply(
- lambda row: calculate_em_str_multi(gold_answers, row["answer"]), axis=1)
- df_answers["f1"] = df_answers.apply(
- lambda row: calculate_f1_str_multi(gold_answers, row["answer"]), axis=1)
- df_answers["rank"] = np.arange(1, len(df_answers)+1)
- df = pd.concat([df, df_answers])
+
+ partial_dfs = []
+ for field_name in ["answers", "answers_isolated"]:
+ df = pd.DataFrame()
+ answers = node_output.get(field_name, None)
+ if answers is not None:
+ answer_cols_to_keep = ["answer", "document_id", "offsets_in_document", "context"]
+ df_answers = pd.DataFrame(answers, columns=answer_cols_to_keep)
+ if len(df_answers) > 0:
+ df_answers["type"] = "answer"
+ df_answers["gold_answers"] = [gold_answers] * len(df_answers)
+ df_answers["gold_offsets_in_documents"] = [gold_offsets_in_documents] * len(df_answers)
+ df_answers["gold_document_ids"] = [gold_document_ids] * len(df_answers)
+ df_answers["exact_match"] = df_answers.apply(
+ lambda row: calculate_em_str_multi(gold_answers, row["answer"]), axis=1)
+ df_answers["f1"] = df_answers.apply(
+ lambda row: calculate_f1_str_multi(gold_answers, row["answer"]), axis=1)
+ df_answers["rank"] = np.arange(1, len(df_answers)+1)
+ df = pd.concat([df, df_answers])
+
+ # add general info
+ df["node"] = node_name
+ df["query"] = query
+ df["eval_mode"] = "isolated" if "isolated" in field_name else "integrated"
+ partial_dfs.append(df)
# if node returned documents, include document specific info:
# - the document_id
@@ -497,34 +519,37 @@ class Pipeline(BasePipeline):
# - the gold_id_match metric depicting whether one of the gold document ids matches the document
# - the answer_match metric depicting whether the document contains the answer
# - the gold_id_or_answer_match metric depicting whether one of the former two conditions are met
- documents = node_output.get("documents", None)
- if documents is not None:
- document_cols_to_keep = ["content", "id"]
- df_docs = pd.DataFrame(documents, columns=document_cols_to_keep)
- if len(df_docs) > 0:
- df_docs = df_docs.rename(columns={"id": "document_id"})
- df_docs["type"] = "document"
- df_docs["gold_document_ids"] = [gold_document_ids] * len(df_docs)
- df_docs["gold_document_contents"] = [gold_document_contents] * len(df_docs)
- df_docs["gold_id_match"] = df_docs.apply(
- lambda row: 1.0 if row["document_id"] in gold_document_ids else 0.0, axis=1)
- df_docs["answer_match"] = df_docs.apply(
- lambda row:
- 1.0 if not query_labels.no_answer
- and any(gold_answer in row["content"] for gold_answer in gold_answers)
- else 0.0,
- axis=1)
- df_docs["gold_id_or_answer_match"] = df_docs.apply(
- lambda row: max(row["gold_id_match"], row["answer_match"]), axis=1)
- df_docs["rank"] = np.arange(1, len(df_docs)+1)
- df = pd.concat([df, df_docs])
+ for field_name in ["documents", "documents_isolated"]:
+ df = pd.DataFrame()
+ documents = node_output.get(field_name, None)
+ if documents is not None:
+ document_cols_to_keep = ["content", "id"]
+ df_docs = pd.DataFrame(documents, columns=document_cols_to_keep)
+ if len(df_docs) > 0:
+ df_docs = df_docs.rename(columns={"id": "document_id"})
+ df_docs["type"] = "document"
+ df_docs["gold_document_ids"] = [gold_document_ids] * len(df_docs)
+ df_docs["gold_document_contents"] = [gold_document_contents] * len(df_docs)
+ df_docs["gold_id_match"] = df_docs.apply(
+ lambda row: 1.0 if row["document_id"] in gold_document_ids else 0.0, axis=1)
+ df_docs["answer_match"] = df_docs.apply(
+ lambda row:
+ 1.0 if not query_labels.no_answer
+ and any(gold_answer in row["content"] for gold_answer in gold_answers)
+ else 0.0,
+ axis=1)
+ df_docs["gold_id_or_answer_match"] = df_docs.apply(
+ lambda row: max(row["gold_id_match"], row["answer_match"]), axis=1)
+ df_docs["rank"] = np.arange(1, len(df_docs)+1)
+ df = pd.concat([df, df_docs])
- # add general info
- df["node"] = node_name
- df["query"] = query
- df["node_input"] = "prediction"
+ # add general info
+ df["node"] = node_name
+ df["query"] = query
+ df["eval_mode"] = "isolated" if "isolated" in field_name else "integrated"
+ partial_dfs.append(df)
- return df
+ return pd.concat(partial_dfs, ignore_index=True)
def get_next_nodes(self, node_id: str, stream_id: str):
current_node_edges = self.graph.edges(node_id, data=True)
@@ -755,21 +780,23 @@ class Pipeline(BasePipeline):
def _format_wrong_samples_report(self, eval_result: EvaluationResult, n_wrong_examples: int = 3):
examples = {
- node: eval_result.wrong_examples(node, doc_relevance_col="gold_id_or_answer_match", n=n_wrong_examples)
+ node: eval_result.wrong_examples(node, doc_relevance_col="gold_id_or_answer_match", n=n_wrong_examples)
for node in eval_result.node_results.keys()
}
examples_formatted = {
- node: "\n".join([self._format_wrong_sample(example) for example in examples])
+ node: "\n".join([self._format_wrong_sample(example) for example in examples])
for node, examples in examples.items()
}
return "\n".join([self._format_wrong_samples_node(node, examples) for node, examples in examples_formatted.items()])
- def _format_pipeline_node(self, node: str, metrics: dict, metrics_top_1):
- metrics = metrics.get(node, {})
- metrics_top_1 = {f"{metric}_top_1": value for metric, value in metrics_top_1.get(node, {}).items()}
- node_metrics = {**metrics, **metrics_top_1}
- node_metrics_formatted = "\n".join(sorted([f" | {metric}: {value:5.3}" for metric, value in node_metrics.items()]))
+ def _format_pipeline_node(self, node: str, calculated_metrics: dict):
+ node_metrics: dict = {}
+ for metric_mode in calculated_metrics:
+ for metric, value in calculated_metrics[metric_mode].get(node, {}).items():
+ node_metrics[f"{metric}{metric_mode}"] = value
+
+ node_metrics_formatted = "\n".join(sorted([f" | {metric}: {value:5.3}" for metric, value in node_metrics.items()]))
node_metrics_formatted = f"{node_metrics_formatted}\n" if len(node_metrics_formatted) > 0 else ""
s = (
f" {node}\n"
@@ -779,8 +806,8 @@ class Pipeline(BasePipeline):
)
return s
- def _format_pipeline_overview(self, metrics: dict, metrics_top_1: dict):
- pipeline_overview = "\n".join([self._format_pipeline_node(node, metrics, metrics_top_1) for node in self.graph.nodes])
+ def _format_pipeline_overview(self, calculated_metrics: dict):
+ pipeline_overview = "\n".join([self._format_pipeline_node(node, calculated_metrics) for node in self.graph.nodes])
s = (
f"================== Evaluation Report ==================\n"
f"=======================================================\n"
@@ -807,17 +834,18 @@ class Pipeline(BasePipeline):
if any(degree > 1 for node, degree in self.graph.out_degree):
logger.warning("Pipelines with junctions are currently not supported.")
return
-
- metrics_top_n = eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match")
- metrics_top_1 = eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", simulated_top_k_reader=1)
+
+ calculated_metrics = {"": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match"),
+ "_top_1": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", simulated_top_k_reader=1),
+ " upper bound": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", eval_mode="isolated")}
+
if metrics_filter is not None:
- metrics_top_n = {node: metrics if node not in metrics_filter
- else {metric: value for metric, value in metrics.items() if metric in metrics_filter[node]}
- for node, metrics in metrics_top_n.items()}
- metrics_top_1 = {node: metrics if node not in metrics_filter
- else {metric: value for metric, value in metrics.items() if metric in metrics_filter[node]}
- for node, metrics in metrics_top_1.items()}
- pipeline_overview = self._format_pipeline_overview(metrics_top_n, metrics_top_1)
+ for metric_mode in calculated_metrics:
+ calculated_metrics[metric_mode] = {node: metrics if node not in metrics_filter
+ else {metric: value for metric, value in metrics.items() if metric in metrics_filter[node]}
+ for node, metrics in calculated_metrics[metric_mode].items()}
+
+ pipeline_overview = self._format_pipeline_overview(calculated_metrics)
wrong_samples_report = self._format_wrong_samples_report(eval_result=eval_result, n_wrong_examples=n_wrong_examples)
print(
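For reference, the dict that `print_eval_report` now assembles and hands to `_format_pipeline_overview` has the shape sketched below (mirroring the hunk above; `eval_result` is assumed to come from `Pipeline.eval`):

```python
calculated_metrics = {
    "": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match"),
    "_top_1": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", simulated_top_k_reader=1),
    " upper bound": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", eval_mode="isolated"),
}
# _format_pipeline_node appends the key as a suffix to each metric name,
# so e.g. the report line "f1 upper bound: ..." for the Reader comes from
# calculated_metrics[" upper bound"]["Reader"]["f1"]
```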
diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py
index 179568cb2..bac88bdcb 100644
--- a/haystack/pipelines/standard_pipelines.py
+++ b/haystack/pipelines/standard_pipelines.py
@@ -151,10 +151,11 @@ class BaseStandardPipeline(ABC):
return self.pipeline.get_document_store()
def eval(self,
- labels: List[MultiLabel],
- params: Optional[dict],
- sas_model_name_or_path: str = None) -> EvaluationResult:
-
+ labels: List[MultiLabel],
+ params: Optional[dict] = None,
+ sas_model_name_or_path: Optional[str] = None,
+ add_isolated_node_eval: bool = False) -> EvaluationResult:
+
"""
Evaluates the pipeline by running the pipeline once per query in debug mode
and putting together all data that is needed for evaluation, e.g. calculating metrics.
@@ -164,9 +165,10 @@ class BaseStandardPipeline(ABC):
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
:param sas_model_name_or_path: SentenceTransformers semantic textual similarity model to be used for sas value calculation,
should be path or string pointing to downloadable models.
+ :param add_isolated_node_eval: Whether to additionally evaluate the reader based on labels as input, instead of the output of the previous node in the pipeline
"""
- output = self.pipeline.eval(labels=labels, params=params,
- sas_model_name_or_path=sas_model_name_or_path)
+ output = self.pipeline.eval(labels=labels, params=params,
+ sas_model_name_or_path=sas_model_name_or_path, add_isolated_node_eval=add_isolated_node_eval)
return output
def print_eval_report(
diff --git a/haystack/schema.py b/haystack/schema.py
index 1eed1b82b..a1eb66a33 100644
--- a/haystack/schema.py
+++ b/haystack/schema.py
@@ -670,7 +670,7 @@ class EvaluationResult:
simulated_top_k_reader: int = -1,
simulated_top_k_retriever: int = -1,
doc_relevance_col: str = "gold_id_match",
- node_input: str = "prediction"
+ eval_mode: str = "integrated"
) -> Dict[str, Dict[str, float]]:
"""
Calculates proper metrics for each node.
@@ -698,19 +698,19 @@ class EvaluationResult:
remarks: there might be a discrepancy between simulated reader metrics and an actual pipeline run with retriever top_k
:param doc_relevance_col: column in the underlying eval table that contains the relevance criteria for documents.
values can be: 'gold_id_match', 'answer_match', 'gold_id_or_answer_match'
- :param node_input: the input on which the node was evaluated on.
- Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+ :param eval_mode: the input the node was evaluated on.
+ Usually nodes get evaluated on the prediction provided by their predecessor nodes in the pipeline (value='integrated').
However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
- For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
- Values can be 'prediction', 'label'.
- Default value is 'prediction'.
+ For example, when evaluating the reader, use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+ Values can be 'integrated', 'isolated'.
+ Default value is 'integrated'.
"""
- return {node: self._calculate_node_metrics(df,
- simulated_top_k_reader=simulated_top_k_reader,
- simulated_top_k_retriever=simulated_top_k_retriever,
- doc_relevance_col=doc_relevance_col,
- node_input=node_input)
+ return {node: self._calculate_node_metrics(df,
+ simulated_top_k_reader=simulated_top_k_reader,
+ simulated_top_k_retriever=simulated_top_k_retriever,
+ doc_relevance_col=doc_relevance_col,
+ eval_mode=eval_mode)
for node, df in self.node_results.items()}
def wrong_examples(
@@ -722,7 +722,7 @@ class EvaluationResult:
doc_relevance_col: str = "gold_id_match",
document_metric: str = "recall_single_hit",
answer_metric: str = "f1",
- node_input: str = "prediction"
+ eval_mode: str = "integrated"
) -> List[Dict]:
"""
Returns the worst performing queries.
@@ -741,16 +741,16 @@ class EvaluationResult:
values can be: 'recall_single_hit', 'recall_multi_hit', 'mrr', 'map', 'precision'
:param document_metric: the answer metric worst queries are calculated with.
values can be: 'f1', 'exact_match' and 'sas' if the evaluation was made using a SAS model.
- :param node_input: the input on which the node was evaluated on.
- Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+ :param eval_mode: the input the node was evaluated on.
+ Usually nodes get evaluated on the prediction provided by their predecessor nodes in the pipeline (value='integrated').
However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
- For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
- Values can be 'prediction', 'label'.
- Default value is 'prediction'.
+ For example, when evaluating the reader, use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+ Values can be 'integrated', 'isolated'.
+ Default value is 'integrated'.
"""
node_df = self.node_results[node]
- node_df = self._filter_node_input(node_df, node_input)
+ node_df = self._filter_eval_mode(node_df, eval_mode)
answers = node_df[node_df["type"] == "answer"]
if len(answers) > 0:
@@ -802,25 +802,25 @@ class EvaluationResult:
simulated_top_k_reader: int = -1,
simulated_top_k_retriever: int = -1,
doc_relevance_col: str = "gold_id_match",
- node_input: str = "prediction"
+ eval_mode: str = "integrated"
) -> Dict[str, float]:
- df = self._filter_node_input(df, node_input)
+ df = self._filter_eval_mode(df, eval_mode)
answer_metrics = self._calculate_answer_metrics(df,
simulated_top_k_reader=simulated_top_k_reader,
simulated_top_k_retriever=simulated_top_k_retriever)
-
+
document_metrics = self._calculate_document_metrics(df,
simulated_top_k_retriever=simulated_top_k_retriever,
doc_relevance_col=doc_relevance_col)
return {**answer_metrics, **document_metrics}
- def _filter_node_input(self, df: pd.DataFrame, node_input: str) -> pd.DataFrame:
- if "node_input" in df.columns:
- df = df[df["node_input"] == node_input]
+ def _filter_eval_mode(self, df: pd.DataFrame, eval_mode: str) -> pd.DataFrame:
+ if "eval_mode" in df.columns:
+ df = df[df["eval_mode"] == eval_mode]
else:
- logger.warning("eval dataframe has no node_input column. node_input param will be ignored.")
+ logger.warning("eval dataframe has no eval_mode column. eval_mode param will be ignored.")
return df
def _calculate_answer_metrics(
@@ -939,7 +939,7 @@ class EvaluationResult:
num_relevants = len(set(gold_ids + relevance_criteria_ids))
num_retrieved_relevants = query_df[doc_relevance_col].values.sum()
rank_retrieved_relevants = query_df[query_df[doc_relevance_col] == 1]["rank"].values
- avp_retrieved_relevants = [query_df[doc_relevance_col].values[:rank].sum() / rank
+ avp_retrieved_relevants = [query_df[doc_relevance_col].values[:int(rank)].sum() / rank
for rank in rank_retrieved_relevants]
avg_precision = np.sum(avp_retrieved_relevants) / num_relevants if num_relevants > 0 else 0.0
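A toy illustration of the average-precision computation touched above; the float-typed `rank` values are an assumption inferred from the added `int()` cast (after concatenating dataframes the column can end up as float):

```python
import numpy as np

relevance = np.array([1, 0, 1, 0, 0])     # doc_relevance_col values, ordered by rank
ranks_of_relevant = np.array([1.0, 3.0])  # float ranks of the relevant documents

# precision at each rank where a relevant document was retrieved
avp = [relevance[:int(rank)].sum() / rank for rank in ranks_of_relevant]
avg_precision = np.sum(avp) / 2           # 2 relevant docs -> (1/1 + 2/3) / 2 = 0.8333...
print(avg_precision)
```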
diff --git a/test/test_eval.py b/test/test_eval.py
index 364834a2d..3b36f980c 100644
--- a/test/test_eval.py
+++ b/test/test_eval.py
@@ -358,7 +358,7 @@ def test_extractive_qa_eval_sas(reader, retriever_with_docs):
eval_result: EvaluationResult = pipeline.eval(
labels=EVAL_LABELS,
params={"Retriever": {"top_k": 5}},
- sas_model_name_or_path="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+ sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2"
)
metrics = eval_result.calculate_metrics()
@@ -399,14 +399,14 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs):
eval_result: EvaluationResult = pipeline.eval(
labels=EVAL_LABELS,
params={"Retriever": {"top_k": 5}},
- sas_model_name_or_path="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+ sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2"
)
metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1)
assert metrics_top_1["Reader"]["exact_match"] == 0.5
assert metrics_top_1["Reader"]["f1"] == 0.5
- assert metrics_top_1["Reader"]["sas"] == pytest.approx(0.6208, abs=1e-4)
+ assert metrics_top_1["Reader"]["sas"] == pytest.approx(0.5833, abs=1e-4)
assert metrics_top_1["Retriever"]["mrr"] == 0.5
assert metrics_top_1["Retriever"]["map"] == 0.5
assert metrics_top_1["Retriever"]["recall_multi_hit"] == 0.5
@@ -417,7 +417,7 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs):
assert metrics_top_2["Reader"]["exact_match"] == 0.5
assert metrics_top_2["Reader"]["f1"] == 0.5
- assert metrics_top_2["Reader"]["sas"] == pytest.approx(0.7192, abs=1e-4)
+ assert metrics_top_2["Reader"]["sas"] == pytest.approx(0.5833, abs=1e-4)
assert metrics_top_2["Retriever"]["mrr"] == 0.5
assert metrics_top_2["Retriever"]["map"] == 0.5
assert metrics_top_2["Retriever"]["recall_multi_hit"] == 0.5
@@ -534,7 +534,35 @@ def test_extractive_qa_eval_simulated_top_k_reader_and_retriever(reader, retriev
assert metrics_top_3["Retriever"]["recall_multi_hit"] == 0.5
assert metrics_top_3["Retriever"]["recall_single_hit"] == 0.5
assert metrics_top_3["Retriever"]["precision"] == 1.0/6
-
+
+
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractive_qa_eval_isolated(reader, retriever_with_docs):
+ pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+ eval_result: EvaluationResult = pipeline.eval(
+ labels=EVAL_LABELS,
+ sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2",
+ add_isolated_node_eval=True
+ )
+
+ metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1)
+
+ assert metrics_top_1["Reader"]["exact_match"] == 0.5
+ assert metrics_top_1["Reader"]["f1"] == 0.5
+ assert metrics_top_1["Reader"]["sas"] == pytest.approx(0.5833, abs=1e-4)
+ assert metrics_top_1["Retriever"]["mrr"] == 0.5
+ assert metrics_top_1["Retriever"]["map"] == 0.5
+ assert metrics_top_1["Retriever"]["recall_multi_hit"] == 0.5
+ assert metrics_top_1["Retriever"]["recall_single_hit"] == 0.5
+ assert metrics_top_1["Retriever"]["precision"] == 1.0 / 6
+
+ metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1, eval_mode="isolated")
+
+ assert metrics_top_1["Reader"]["exact_match"] == 1.0
+ assert metrics_top_1["Reader"]["f1"] == 1.0
+ assert metrics_top_1["Reader"]["sas"] == pytest.approx(1.0, abs=1e-4)
+
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@@ -578,9 +606,16 @@ def test_extractive_qa_print_eval_report(reader, retriever_with_docs):
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
eval_result: EvaluationResult = pipeline.eval(
labels=labels,
- params={"Retriever": {"top_k": 5}},
+ params={"Retriever": {"top_k": 5}}
)
+ pipeline.print_eval_report(eval_result)
+ # in addition with labels as input to reader node rather than output of retriever node
+ eval_result: EvaluationResult = pipeline.eval(
+ labels=labels,
+ params={"Retriever": {"top_k": 5}},
+ add_isolated_node_eval=True
+ )
pipeline.print_eval_report(eval_result)