Add isolated node eval mode in pipeline eval (#1962)
* run predictions on ground-truth docs in reader
* build dataframe for closed/open domain eval
* fix looping through multilabel
* fix looping through multilabel's list of labels
* simplify collecting relevant docs
* switch closed-domain eval off by default
* Add latest docstring and tutorial changes
* handle edge case params not given
* renaming & generate pipeline eval report
* add test case for closed-domain eval metrics
* Add latest docstring and tutorial changes
* test report of closed-domain eval
* report closed-domain metrics only for answer metrics not doc metrics
* refactoring
* fix mypy & remove comment
* add second for-loop & use answer as method input
* renaming & add separate loop building docs eval df
* Add latest docstring and tutorial changes
* change column order for evaluation dataframe (#1957)
* change column order for evaluation dataframe
* added missing eval column node_input
* generic order for both document and answer returning nodes; ensure no columns get lost
* fix column reordering after renaming of node_input
* simplify tests & add docu
* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: ju-gu <87523290+ju-gu@users.noreply.github.com>
Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com>
Co-authored-by: Thomas Stadelmann <thomas.stadelmann@deepset.ai>
parent: e28bf618d7
commit: a3147cae47
@@ -162,7 +162,7 @@ Runs the pipeline, one node at a time.
 #### eval
 
 ```python
- | eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: str = None) -> EvaluationResult
+ | eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult
 ```
 
 Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -186,6 +186,14 @@ and putting together all data that is needed for evaluation, e.g. calculating me
 - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
 - Large model for German only: "deepset/gbert-large-sts"
+- `add_isolated_node_eval`: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
+  This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
+  If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
+  If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
+  The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
+  To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
+  The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results based on the
+  values "integrated" or "isolated" in the column "eval_mode" and the evaluation report then additionally lists the upper bound of each node's evaluation metrics.
 
 <a name="base.Pipeline.get_nodes_by_class"></a>
 #### get\_nodes\_by\_class
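The new flag can be exercised end to end. Below is a minimal usage sketch, assuming `pipeline` is an already built query pipeline and `eval_labels` is a prepared `List[MultiLabel]` (both hypothetical names); only `eval()`, `add_isolated_node_eval`, `node_results`, and the `eval_mode` column are taken from this change.

```python
# Sketch only: `pipeline` and `eval_labels` are assumed to exist already.
eval_result = pipeline.eval(
    labels=eval_labels,
    params={"Retriever": {"top_k": 5}},
    add_isolated_node_eval=True,  # adds the "isolated" rows described above
)

# Each node's dataframe now carries both evaluation modes.
reader_df = eval_result.node_results["Reader"]
print(reader_df["eval_mode"].unique())  # expected: ["integrated", "isolated"]
```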
@@ -627,7 +635,7 @@ Instance of DocumentStore or None
 #### eval
 
 ```python
- | eval(labels: List[MultiLabel], params: Optional[dict], sas_model_name_or_path: str = None) -> EvaluationResult
+ | eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: Optional[str] = None, add_isolated_node_eval: bool = False) -> EvaluationResult
 ```
 
 Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -640,6 +648,7 @@ and putting together all data that is needed for evaluation, e.g. calculating me
 params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
 - `sas_model_name_or_path`: SentenceTransformers semantic textual similarity model to be used for sas value calculation,
 should be path or string pointing to downloadable models.
+- `add_isolated_node_eval`: Whether to additionally evaluate the reader based on labels as input instead of output of previous node in pipeline
 
 <a name="standard_pipelines.ExtractiveQAPipeline"></a>
 ## ExtractiveQAPipeline
@@ -294,7 +294,7 @@ The DataFrames have the following schema:
 #### calculate\_metrics
 
 ```python
- | calculate_metrics(simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", node_input: str = "prediction") -> Dict[str, Dict[str, float]]
+ | calculate_metrics(simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", eval_mode: str = "integrated") -> Dict[str, Dict[str, float]]
 ```
 
 Calculates proper metrics for each node.
@@ -324,19 +324,19 @@ as there are situations the result can heavily differ from an actual eval run wi
 remarks: there might be a discrepancy between simulated reader metrics and an actual pipeline run with retriever top_k
 - `doc_relevance_col`: column in the underlying eval table that contains the relevance criteria for documents.
 values can be: 'gold_id_match', 'answer_match', 'gold_id_or_answer_match'
-- `node_input`: the input on which the node was evaluated on.
-Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+- `eval_mode`: the input on which the node was evaluated on.
+Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='integrated').
 However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
 you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
-For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
-Values can be 'prediction', 'label'.
-Default value is 'prediction'.
+For example when evaluating the reader use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+Values can be 'integrated', 'isolated'.
+Default value is 'integrated'.
 
 <a name="schema.EvaluationResult.wrong_examples"></a>
 #### wrong\_examples
 
 ```python
- | wrong_examples(node: str, n: int = 3, simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", document_metric: str = "recall_single_hit", answer_metric: str = "f1", node_input: str = "prediction") -> List[Dict]
+ | wrong_examples(node: str, n: int = 3, simulated_top_k_reader: int = -1, simulated_top_k_retriever: int = -1, doc_relevance_col: str = "gold_id_match", document_metric: str = "recall_single_hit", answer_metric: str = "f1", eval_mode: str = "integrated") -> List[Dict]
 ```
 
 Returns the worst performing queries.
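A short sketch of how the two `eval_mode` values are intended to be used with `calculate_metrics`, based on the signature documented above; `eval_result` is assumed to come from a pipeline evaluated with `add_isolated_node_eval=True`.

```python
# Integrated: each node is scored on the real output of its predecessor.
integrated = eval_result.calculate_metrics()  # eval_mode defaults to "integrated"

# Isolated: upper bound with a simulated perfect predecessor.
isolated = eval_result.calculate_metrics(eval_mode="isolated")

# A large gap for the Reader suggests the Retriever is the bottleneck.
print(integrated["Reader"]["f1"], isolated["Reader"]["f1"])
```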
@@ -357,13 +357,13 @@ See calculate_metrics() for more information.
 values can be: 'recall_single_hit', 'recall_multi_hit', 'mrr', 'map', 'precision'
 - `document_metric`: the answer metric worst queries are calculated with.
 values can be: 'f1', 'exact_match' and 'sas' if the evaluation was made using a SAS model.
-- `node_input`: the input on which the node was evaluated on.
-Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+- `eval_mode`: the input on which the node was evaluated on.
+Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='integrated').
 However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
 you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
-For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
-Values can be 'prediction', 'label'.
-Default value is 'prediction'.
+For example when evaluating the reader use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+Values can be 'integrated', 'isolated'.
+Default value is 'integrated'.
 
 <a name="schema.EvaluationResult.save"></a>
 #### save
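The same renaming applies to `wrong_examples`. A hedged sketch follows; the node name "Reader" is the usual ExtractiveQAPipeline node and `eval_result` is assumed to contain isolated rows.

```python
# Worst-performing queries for the Reader when it was fed the gold documents.
worst_isolated = eval_result.wrong_examples(node="Reader", n=3, eval_mode="isolated")
for example in worst_isolated:
    print(example)  # each entry is a dict describing one poorly answered query
```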
@@ -7,7 +7,7 @@ from copy import deepcopy
 from functools import wraps
 from time import perf_counter
 
-from haystack.schema import Document, Answer, Span
+from haystack.schema import Document, Answer, Span, MultiLabel
 from haystack.nodes.base import BaseComponent
 
 
@@ -55,7 +55,22 @@ class BaseReader(BaseComponent):
 
         return no_ans_prediction, max_no_ans_gap
 
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):  # type: ignore
+    @staticmethod
+    def add_doc_meta_data_to_answer(documents: List[Document], answer):
+        # Add corresponding document_name and more meta data, if the answer contains the document_id
+        if answer.meta is None:
+            answer.meta = {}
+        # get meta from doc
+        meta_from_doc = {}
+        for doc in documents:
+            if doc.id == answer.document_id:
+                meta_from_doc = deepcopy(doc.meta)
+                break
+        # append to "own" meta
+        answer.meta.update(meta_from_doc)
+        return answer
+
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False):  # type: ignore
         self.query_count += 1
         if documents:
             predict = self.timing(self.predict, "query_time")
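A rough illustration of what the extracted `add_doc_meta_data_to_answer` helper does. The import paths and the `Document`/`Answer` keyword arguments are assumptions based on Haystack's schema, not part of this diff.

```python
from haystack.schema import Document, Answer
from haystack.nodes.reader.base import BaseReader  # assumed import path

doc = Document(content="Berlin is the capital of Germany.", meta={"name": "capitals.txt"})
ans = Answer(answer="Berlin", document_id=doc.id)  # answer meta starts out empty

# The static method copies the matching document's meta into the answer.
ans = BaseReader.add_doc_meta_data_to_answer(documents=[doc], answer=ans)
print(ans.meta)  # expected to now include {"name": "capitals.txt"}
```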
@@ -64,17 +79,15 @@ class BaseReader(BaseComponent):
             results = {"answers": []}
 
         # Add corresponding document_name and more meta data, if an answer contains the document_id
-        for ans in results["answers"]:
-            if ans.meta is None:
-                ans.meta = {}
-            # get meta from doc
-            meta_from_doc = {}
-            for doc in documents:
-                if doc.id == ans.document_id:
-                    meta_from_doc = deepcopy(doc.meta)
-                    break
-            # append to "own" meta
-            ans.meta.update(meta_from_doc)
+        results["answers"] = [BaseReader.add_doc_meta_data_to_answer(documents=documents, answer=answer) for answer in results["answers"]]
+
+        # run evaluation with labels as node inputs
+        if add_isolated_node_eval and labels is not None:
+            relevant_documents = [label.document for label in labels.labels]
+            results_label_input = predict(query=query, documents=relevant_documents, top_k=top_k)
+
+            # Add corresponding document_name and more meta data, if an answer contains the document_id
+            results["answers_isolated"] = [BaseReader.add_doc_meta_data_to_answer(documents=documents, answer=answer) for answer in results_label_input["answers"]]
 
         return results, "output_1"
 
@@ -368,7 +368,8 @@ class Pipeline(BasePipeline):
             self,
             labels: List[MultiLabel],
             params: Optional[dict] = None,
-            sas_model_name_or_path: str = None
+            sas_model_name_or_path: str = None,
+            add_isolated_node_eval: bool = False
     ) -> EvaluationResult:
         """
         Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -390,8 +391,20 @@ class Pipeline(BasePipeline):
             - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
             - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
             - Large model for German only: "deepset/gbert-large-sts"
+        :param add_isolated_node_eval: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
+            This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
+            If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
+            If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
+            The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
+            To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
+            The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results based on the
+            values "integrated" or "isolated" in the column "eval_mode" and the evaluation report then additionally lists the upper bound of each node's evaluation metrics.
         """
         eval_result = EvaluationResult()
+        if add_isolated_node_eval:
+            if params is None:
+                params = {}
+            params["add_isolated_node_eval"] = True
         queries = [label.query for label in labels]
         for query, label in zip(queries, labels):
             predictions = self.run(query=query, labels=label, params=params, debug=True)
@@ -420,7 +433,7 @@ class Pipeline(BasePipeline):
                 "gold_document_contents", "content", "gold_id_match", "answer_match", "gold_id_or_answer_match", # doc-specific
                 "rank", "document_id", "gold_document_ids", # generic
                 "offsets_in_document", "gold_offsets_in_documents", # answer-specific
-                "type", "node", "node_input"] # generic
+                "type", "node", "eval_mode"] # generic
             eval_result.node_results[key] = self._reorder_columns(df, desired_col_order)
 
         return eval_result
@@ -447,11 +460,10 @@ class Pipeline(BasePipeline):
         Additional answer or document specific evaluation infos like gold labels
         and metrics depicting whether the row matches the gold labels are included, too.
         """
-        df: DataFrame = pd.DataFrame()
 
         if query_labels is None or query_labels.labels is None:
             logger.warning(f"There is no label for query '{query}'. Query will be omitted.")
-            return df
+            return pd.DataFrame()
 
         # remarks for no_answers:
         # Single 'no_answer'-labels are not contained in MultiLabel aggregates.
@@ -467,27 +479,37 @@ class Pipeline(BasePipeline):
         # - the position or offsets within the document the answer was found
         # - the surrounding context of the answer within the document
         # - the gold answers
-        # - the positon or offsets of the gold answer within the document
+        # - the position or offsets of the gold answer within the document
         # - the gold document ids containing the answer
         # - the exact_match metric depicting if the answer exactly matches the gold label
         # - the f1 metric depicting how well the answer overlaps with the gold label on token basis
-        # - the sas metric depciting how well the answer matches the gold label on a semantic basis.
+        # - the sas metric depicting how well the answer matches the gold label on a semantic basis.
         # this will be calculated on all queries in eval() for performance reasons if a sas model has been provided
-        answers = node_output.get("answers", None)
-        if answers is not None:
-            answer_cols_to_keep = ["answer", "document_id", "offsets_in_document", "context"]
-            df_answers = pd.DataFrame(answers, columns=answer_cols_to_keep)
-            if len(df_answers) > 0:
-                df_answers["type"] = "answer"
-                df_answers["gold_answers"] = [gold_answers] * len(df_answers)
-                df_answers["gold_offsets_in_documents"] = [gold_offsets_in_documents] * len(df_answers)
-                df_answers["gold_document_ids"] = [gold_document_ids] * len(df_answers)
-                df_answers["exact_match"] = df_answers.apply(
-                    lambda row: calculate_em_str_multi(gold_answers, row["answer"]), axis=1)
-                df_answers["f1"] = df_answers.apply(
-                    lambda row: calculate_f1_str_multi(gold_answers, row["answer"]), axis=1)
-                df_answers["rank"] = np.arange(1, len(df_answers)+1)
-                df = pd.concat([df, df_answers])
+
+        partial_dfs = []
+        for field_name in ["answers", "answers_isolated"]:
+            df = pd.DataFrame()
+            answers = node_output.get(field_name, None)
+            if answers is not None:
+                answer_cols_to_keep = ["answer", "document_id", "offsets_in_document", "context"]
+                df_answers = pd.DataFrame(answers, columns=answer_cols_to_keep)
+                if len(df_answers) > 0:
+                    df_answers["type"] = "answer"
+                    df_answers["gold_answers"] = [gold_answers] * len(df_answers)
+                    df_answers["gold_offsets_in_documents"] = [gold_offsets_in_documents] * len(df_answers)
+                    df_answers["gold_document_ids"] = [gold_document_ids] * len(df_answers)
+                    df_answers["exact_match"] = df_answers.apply(
+                        lambda row: calculate_em_str_multi(gold_answers, row["answer"]), axis=1)
+                    df_answers["f1"] = df_answers.apply(
+                        lambda row: calculate_f1_str_multi(gold_answers, row["answer"]), axis=1)
+                    df_answers["rank"] = np.arange(1, len(df_answers)+1)
+                    df = pd.concat([df, df_answers])
+
+            # add general info
+            df["node"] = node_name
+            df["query"] = query
+            df["eval_mode"] = "isolated" if "isolated" in field_name else "integrated"
+            partial_dfs.append(df)
 
         # if node returned documents, include document specific info:
         # - the document_id
@@ -497,34 +519,37 @@ class Pipeline(BasePipeline):
         # - the gold_id_match metric depicting whether one of the gold document ids matches the document
         # - the answer_match metric depicting whether the document contains the answer
        # - the gold_id_or_answer_match metric depicting whether one of the former two conditions are met
-        documents = node_output.get("documents", None)
-        if documents is not None:
-            document_cols_to_keep = ["content", "id"]
-            df_docs = pd.DataFrame(documents, columns=document_cols_to_keep)
-            if len(df_docs) > 0:
-                df_docs = df_docs.rename(columns={"id": "document_id"})
-                df_docs["type"] = "document"
-                df_docs["gold_document_ids"] = [gold_document_ids] * len(df_docs)
-                df_docs["gold_document_contents"] = [gold_document_contents] * len(df_docs)
-                df_docs["gold_id_match"] = df_docs.apply(
-                    lambda row: 1.0 if row["document_id"] in gold_document_ids else 0.0, axis=1)
-                df_docs["answer_match"] = df_docs.apply(
-                    lambda row:
-                        1.0 if not query_labels.no_answer
-                        and any(gold_answer in row["content"] for gold_answer in gold_answers)
-                        else 0.0,
-                    axis=1)
-                df_docs["gold_id_or_answer_match"] = df_docs.apply(
-                    lambda row: max(row["gold_id_match"], row["answer_match"]), axis=1)
-                df_docs["rank"] = np.arange(1, len(df_docs)+1)
-                df = pd.concat([df, df_docs])
-
-        # add general info
-        df["node"] = node_name
-        df["query"] = query
-        df["node_input"] = "prediction"
-
-        return df
+        for field_name in ["documents", "documents_isolated"]:
+            df = pd.DataFrame()
+            documents = node_output.get(field_name, None)
+            if documents is not None:
+                document_cols_to_keep = ["content", "id"]
+                df_docs = pd.DataFrame(documents, columns=document_cols_to_keep)
+                if len(df_docs) > 0:
+                    df_docs = df_docs.rename(columns={"id": "document_id"})
+                    df_docs["type"] = "document"
+                    df_docs["gold_document_ids"] = [gold_document_ids] * len(df_docs)
+                    df_docs["gold_document_contents"] = [gold_document_contents] * len(df_docs)
+                    df_docs["gold_id_match"] = df_docs.apply(
+                        lambda row: 1.0 if row["document_id"] in gold_document_ids else 0.0, axis=1)
+                    df_docs["answer_match"] = df_docs.apply(
+                        lambda row:
+                            1.0 if not query_labels.no_answer
+                            and any(gold_answer in row["content"] for gold_answer in gold_answers)
+                            else 0.0,
+                        axis=1)
+                    df_docs["gold_id_or_answer_match"] = df_docs.apply(
+                        lambda row: max(row["gold_id_match"], row["answer_match"]), axis=1)
+                    df_docs["rank"] = np.arange(1, len(df_docs)+1)
+                    df = pd.concat([df, df_docs])
+
+            # add general info
+            df["node"] = node_name
+            df["query"] = query
+            df["eval_mode"] = "isolated" if "isolated" in field_name else "integrated"
+            partial_dfs.append(df)
+
+        return pd.concat(partial_dfs, ignore_index=True)
 
     def get_next_nodes(self, node_id: str, stream_id: str):
         current_node_edges = self.graph.edges(node_id, data=True)
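The refactoring above tags every row with an `eval_mode` value and concatenates the per-field frames. A small sketch of slicing the result afterwards; the column names are taken from this hunk, everything else is assumed.

```python
node_df = eval_result.node_results["Reader"]

# Rows from the normal pipeline run vs. rows from the label-fed, isolated run.
integrated_answers = node_df[(node_df["type"] == "answer") & (node_df["eval_mode"] == "integrated")]
isolated_answers = node_df[(node_df["type"] == "answer") & (node_df["eval_mode"] == "isolated")]
print(len(integrated_answers), len(isolated_answers))
```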
@@ -755,21 +780,23 @@ class Pipeline(BasePipeline):
 
     def _format_wrong_samples_report(self, eval_result: EvaluationResult, n_wrong_examples: int = 3):
         examples = {
             node: eval_result.wrong_examples(node, doc_relevance_col="gold_id_or_answer_match", n=n_wrong_examples)
             for node in eval_result.node_results.keys()
         }
         examples_formatted = {
             node: "\n".join([self._format_wrong_sample(example) for example in examples])
             for node, examples in examples.items()
         }
 
         return "\n".join([self._format_wrong_samples_node(node, examples) for node, examples in examples_formatted.items()])
 
-    def _format_pipeline_node(self, node: str, metrics: dict, metrics_top_1):
-        metrics = metrics.get(node, {})
-        metrics_top_1 = {f"{metric}_top_1": value for metric, value in metrics_top_1.get(node, {}).items()}
-        node_metrics = {**metrics, **metrics_top_1}
-        node_metrics_formatted = "\n".join(sorted([f" | {metric}: {value:5.3}" for metric, value in node_metrics.items()]))
+    def _format_pipeline_node(self, node: str, calculated_metrics: dict):
+        node_metrics: dict = {}
+        for metric_mode in calculated_metrics:
+            for metric, value in calculated_metrics[metric_mode].get(node, {}).items():
+                node_metrics[f"{metric}{metric_mode}"] = value
+
+        node_metrics_formatted = "\n".join(sorted([f" | {metric}: {value:5.3}" for metric, value in node_metrics.items()]))
         node_metrics_formatted = f"{node_metrics_formatted}\n" if len(node_metrics_formatted) > 0 else ""
         s = (
             f"                      {node}\n"
@@ -779,8 +806,8 @@ class Pipeline(BasePipeline):
         )
         return s
 
-    def _format_pipeline_overview(self, metrics: dict, metrics_top_1: dict):
-        pipeline_overview = "\n".join([self._format_pipeline_node(node, metrics, metrics_top_1) for node in self.graph.nodes])
+    def _format_pipeline_overview(self, calculated_metrics: dict):
+        pipeline_overview = "\n".join([self._format_pipeline_node(node, calculated_metrics) for node in self.graph.nodes])
         s = (
             f"================== Evaluation Report ==================\n"
             f"=======================================================\n"
@@ -807,17 +834,18 @@ class Pipeline(BasePipeline):
         if any(degree > 1 for node, degree in self.graph.out_degree):
             logger.warning("Pipelines with junctions are currently not supported.")
             return
 
-        metrics_top_n = eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match")
-        metrics_top_1 = eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", simulated_top_k_reader=1)
+        calculated_metrics = {"": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match"),
+                              "_top_1": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", simulated_top_k_reader=1),
+                              " upper bound": eval_result.calculate_metrics(doc_relevance_col="gold_id_or_answer_match", eval_mode="isolated")}
 
         if metrics_filter is not None:
-            metrics_top_n = {node: metrics if node not in metrics_filter
-                             else {metric: value for metric, value in metrics.items() if metric in metrics_filter[node]}
-                             for node, metrics in metrics_top_n.items()}
-            metrics_top_1 = {node: metrics if node not in metrics_filter
-                             else {metric: value for metric, value in metrics.items() if metric in metrics_filter[node]}
-                             for node, metrics in metrics_top_1.items()}
-        pipeline_overview = self._format_pipeline_overview(metrics_top_n, metrics_top_1)
+            for metric_mode in calculated_metrics:
+                calculated_metrics[metric_mode] = {node: metrics if node not in metrics_filter
+                                                   else {metric: value for metric, value in metrics.items() if metric in metrics_filter[node]}
+                                                   for node, metrics in calculated_metrics[metric_mode].items()}
+
+        pipeline_overview = self._format_pipeline_overview(calculated_metrics)
+
         wrong_samples_report = self._format_wrong_samples_report(eval_result=eval_result, n_wrong_examples=n_wrong_examples)
 
         print(
@@ -151,10 +151,11 @@ class BaseStandardPipeline(ABC):
         return self.pipeline.get_document_store()
 
     def eval(self,
              labels: List[MultiLabel],
-             params: Optional[dict],
-             sas_model_name_or_path: str = None) -> EvaluationResult:
+             params: Optional[dict] = None,
+             sas_model_name_or_path: Optional[str] = None,
+             add_isolated_node_eval: bool = False) -> EvaluationResult:
 
         """
         Evaluates the pipeline by running the pipeline once per query in debug mode
         and putting together all data that is needed for evaluation, e.g. calculating metrics.
@@ -164,9 +165,10 @@ class BaseStandardPipeline(ABC):
                     params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
         :param sas_model_name_or_path: SentenceTransformers semantic textual similarity model to be used for sas value calculation,
             should be path or string pointing to downloadable models.
+        :param add_isolated_node_eval: Whether to additionally evaluate the reader based on labels as input instead of output of previous node in pipeline
         """
         output = self.pipeline.eval(labels=labels, params=params,
-                                    sas_model_name_or_path=sas_model_name_or_path)
+                                    sas_model_name_or_path=sas_model_name_or_path, add_isolated_node_eval=add_isolated_node_eval)
         return output
 
     def print_eval_report(
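For standard pipelines the new flag is simply forwarded. A minimal end-to-end sketch mirroring the test added at the bottom of this commit; `reader`, `retriever`, and `eval_labels` are assumed to be set up already.

```python
from haystack.pipelines import ExtractiveQAPipeline

pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)
eval_result = pipeline.eval(labels=eval_labels, add_isolated_node_eval=True)

# The printed report now also lists each node's isolated "upper bound" metrics.
pipeline.print_eval_report(eval_result)
```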
@@ -670,7 +670,7 @@ class EvaluationResult:
             simulated_top_k_reader: int = -1,
             simulated_top_k_retriever: int = -1,
             doc_relevance_col: str = "gold_id_match",
-            node_input: str = "prediction"
+            eval_mode: str = "integrated"
     ) -> Dict[str, Dict[str, float]]:
         """
         Calculates proper metrics for each node.
@@ -698,19 +698,19 @@ class EvaluationResult:
            remarks: there might be a discrepancy between simulated reader metrics and an actual pipeline run with retriever top_k
        :param doc_relevance_col: column in the underlying eval table that contains the relevance criteria for documents.
            values can be: 'gold_id_match', 'answer_match', 'gold_id_or_answer_match'
-       :param node_input: the input on which the node was evaluated on.
-           Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+       :param eval_mode: the input on which the node was evaluated on.
+           Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='integrated').
            However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
            you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
-           For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
-           Values can be 'prediction', 'label'.
-           Default value is 'prediction'.
+           For example when evaluating the reader use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+           Values can be 'integrated', 'isolated'.
+           Default value is 'integrated'.
        """
        return {node: self._calculate_node_metrics(df,
                        simulated_top_k_reader=simulated_top_k_reader,
                        simulated_top_k_retriever=simulated_top_k_retriever,
                        doc_relevance_col=doc_relevance_col,
-                       node_input=node_input)
+                       eval_mode=eval_mode)
                for node, df in self.node_results.items()}
 
    def wrong_examples(
@@ -722,7 +722,7 @@ class EvaluationResult:
             doc_relevance_col: str = "gold_id_match",
             document_metric: str = "recall_single_hit",
             answer_metric: str = "f1",
-            node_input: str = "prediction"
+            eval_mode: str = "integrated"
     ) -> List[Dict]:
         """
         Returns the worst performing queries.
@@ -741,16 +741,16 @@ class EvaluationResult:
            values can be: 'recall_single_hit', 'recall_multi_hit', 'mrr', 'map', 'precision'
        :param document_metric: the answer metric worst queries are calculated with.
            values can be: 'f1', 'exact_match' and 'sas' if the evaluation was made using a SAS model.
-       :param node_input: the input on which the node was evaluated on.
-           Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='prediction').
+       :param eval_mode: the input on which the node was evaluated on.
+           Usually nodes get evaluated on the prediction provided by its predecessor nodes in the pipeline (value='integrated').
            However, as the quality of the node itself can heavily depend on the node's input and thus the predecessor's quality,
            you might want to simulate a perfect predecessor in order to get an independent upper bound of the quality of your node.
-           For example when evaluating the reader use value='label' to simulate a perfect retriever in an ExtractiveQAPipeline.
-           Values can be 'prediction', 'label'.
-           Default value is 'prediction'.
+           For example when evaluating the reader use value='isolated' to simulate a perfect retriever in an ExtractiveQAPipeline.
+           Values can be 'integrated', 'isolated'.
+           Default value is 'integrated'.
        """
        node_df = self.node_results[node]
-       node_df = self._filter_node_input(node_df, node_input)
+       node_df = self._filter_eval_mode(node_df, eval_mode)
 
        answers = node_df[node_df["type"] == "answer"]
        if len(answers) > 0:
@@ -802,25 +802,25 @@ class EvaluationResult:
             simulated_top_k_reader: int = -1,
             simulated_top_k_retriever: int = -1,
             doc_relevance_col: str = "gold_id_match",
-            node_input: str = "prediction"
+            eval_mode: str = "integrated"
     ) -> Dict[str, float]:
-        df = self._filter_node_input(df, node_input)
+        df = self._filter_eval_mode(df, eval_mode)
 
         answer_metrics = self._calculate_answer_metrics(df,
             simulated_top_k_reader=simulated_top_k_reader,
             simulated_top_k_retriever=simulated_top_k_retriever)
 
         document_metrics = self._calculate_document_metrics(df,
             simulated_top_k_retriever=simulated_top_k_retriever,
             doc_relevance_col=doc_relevance_col)
 
         return {**answer_metrics, **document_metrics}
 
-    def _filter_node_input(self, df: pd.DataFrame, node_input: str) -> pd.DataFrame:
-        if "node_input" in df.columns:
-            df = df[df["node_input"] == node_input]
+    def _filter_eval_mode(self, df: pd.DataFrame, eval_mode: str) -> pd.DataFrame:
+        if "eval_mode" in df.columns:
+            df = df[df["eval_mode"] == eval_mode]
         else:
-            logger.warning("eval dataframe has no node_input column. node_input param will be ignored.")
+            logger.warning("eval dataframe has no eval_mode column. eval_mode param will be ignored.")
         return df
 
     def _calculate_answer_metrics(
@@ -939,7 +939,7 @@ class EvaluationResult:
         num_relevants = len(set(gold_ids + relevance_criteria_ids))
         num_retrieved_relevants = query_df[doc_relevance_col].values.sum()
         rank_retrieved_relevants = query_df[query_df[doc_relevance_col] == 1]["rank"].values
-        avp_retrieved_relevants = [query_df[doc_relevance_col].values[:rank].sum() / rank
+        avp_retrieved_relevants = [query_df[doc_relevance_col].values[:int(rank)].sum() / rank
                                    for rank in rank_retrieved_relevants]
 
         avg_precision = np.sum(avp_retrieved_relevants) / num_relevants if num_relevants > 0 else 0.0
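The `int(rank)` cast above only makes the slice index an integer; the average-precision arithmetic is unchanged. A small numeric sketch with toy values:

```python
import numpy as np

# 1 = relevant, 0 = not relevant, in retrieval order (ranks 1..5).
relevance = np.array([1, 0, 1, 0, 0])
ranks_of_relevant = np.where(relevance == 1)[0] + 1            # -> [1, 3]
num_relevants = 2

# precision@rank at every rank where a relevant document was retrieved
avp = [relevance[:int(rank)].sum() / rank for rank in ranks_of_relevant]  # [1.0, 2/3]
avg_precision = np.sum(avp) / num_relevants                    # ~0.833
print(avg_precision)
```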
@@ -358,7 +358,7 @@ def test_extractive_qa_eval_sas(reader, retriever_with_docs):
     eval_result: EvaluationResult = pipeline.eval(
         labels=EVAL_LABELS,
         params={"Retriever": {"top_k": 5}},
-        sas_model_name_or_path="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2"
     )
 
     metrics = eval_result.calculate_metrics()
@@ -399,14 +399,14 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs):
     eval_result: EvaluationResult = pipeline.eval(
         labels=EVAL_LABELS,
         params={"Retriever": {"top_k": 5}},
-        sas_model_name_or_path="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2"
     )
 
     metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1)
 
     assert metrics_top_1["Reader"]["exact_match"] == 0.5
     assert metrics_top_1["Reader"]["f1"] == 0.5
-    assert metrics_top_1["Reader"]["sas"] == pytest.approx(0.6208, abs=1e-4)
+    assert metrics_top_1["Reader"]["sas"] == pytest.approx(0.5833, abs=1e-4)
     assert metrics_top_1["Retriever"]["mrr"] == 0.5
     assert metrics_top_1["Retriever"]["map"] == 0.5
     assert metrics_top_1["Retriever"]["recall_multi_hit"] == 0.5
@@ -417,7 +417,7 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs):
 
     assert metrics_top_2["Reader"]["exact_match"] == 0.5
     assert metrics_top_2["Reader"]["f1"] == 0.5
-    assert metrics_top_2["Reader"]["sas"] == pytest.approx(0.7192, abs=1e-4)
+    assert metrics_top_2["Reader"]["sas"] == pytest.approx(0.5833, abs=1e-4)
     assert metrics_top_2["Retriever"]["mrr"] == 0.5
     assert metrics_top_2["Retriever"]["map"] == 0.5
     assert metrics_top_2["Retriever"]["recall_multi_hit"] == 0.5
@@ -534,7 +534,35 @@ def test_extractive_qa_eval_simulated_top_k_reader_and_retriever(reader, retriev
     assert metrics_top_3["Retriever"]["recall_multi_hit"] == 0.5
     assert metrics_top_3["Retriever"]["recall_single_hit"] == 0.5
     assert metrics_top_3["Retriever"]["precision"] == 1.0/6
 
 
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractive_qa_eval_isolated(reader, retriever_with_docs):
+    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=EVAL_LABELS,
+        sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2",
+        add_isolated_node_eval=True
+    )
+
+    metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1)
+
+    assert metrics_top_1["Reader"]["exact_match"] == 0.5
+    assert metrics_top_1["Reader"]["f1"] == 0.5
+    assert metrics_top_1["Reader"]["sas"] == pytest.approx(0.5833, abs=1e-4)
+    assert metrics_top_1["Retriever"]["mrr"] == 0.5
+    assert metrics_top_1["Retriever"]["map"] == 0.5
+    assert metrics_top_1["Retriever"]["recall_multi_hit"] == 0.5
+    assert metrics_top_1["Retriever"]["recall_single_hit"] == 0.5
+    assert metrics_top_1["Retriever"]["precision"] == 1.0 / 6
+
+    metrics_top_1 = eval_result.calculate_metrics(simulated_top_k_reader=1, eval_mode="isolated")
+
+    assert metrics_top_1["Reader"]["exact_match"] == 1.0
+    assert metrics_top_1["Reader"]["f1"] == 1.0
+    assert metrics_top_1["Reader"]["sas"] == pytest.approx(1.0, abs=1e-4)
+
+
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@@ -578,9 +606,16 @@ def test_extractive_qa_print_eval_report(reader, retriever_with_docs):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result: EvaluationResult = pipeline.eval(
         labels=labels,
-        params={"Retriever": {"top_k": 5}},
+        params={"Retriever": {"top_k": 5}}
     )
+    pipeline.print_eval_report(eval_result)
+
+    # in addition with labels as input to reader node rather than output of retriever node
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=labels,
+        params={"Retriever": {"top_k": 5}},
+        add_isolated_node_eval=True
+    )
     pipeline.print_eval_report(eval_result)
 
 