diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md
index b8f326908..caab86b71 100644
--- a/docs/_src/api/api/pipelines.md
+++ b/docs/_src/api/api/pipelines.md
@@ -158,6 +158,25 @@ Runs the pipeline, one node at a time.
they received and the output they generated. All debug information can
then be found in the dict returned by this method under the key "_debug"
+
+#### eval
+
+```python
+ | eval(queries: List[str], labels: List[MultiLabel], params: Optional[dict] = None) -> EvaluationResult
+```
+
+Evaluates the pipeline by running it once per query in debug mode
+and collecting all the data needed for evaluation, e.g. for calculating metrics.
+
+**Arguments**:
+
+- `queries`: The queries to evaluate
+- `labels`: The labels to evaluate on
+- `params`: Dictionary of parameters to be dispatched to the nodes.
+ If you want to pass a param to all nodes, you can just use: {"top_k":10}
+ If you want to pass it to targeted nodes, you can do:
+ {"Retriever": {"top_k": 10}, "Reader": {"top_k": 3, "debug": True}}
+
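+A minimal usage sketch (it assumes `queries` and matching `MultiLabel` objects in `labels`
+have been prepared beforehand; the node names `"Retriever"` and `"Reader"` depend on your concrete pipeline):
+
+```python
+eval_result = pipeline.eval(queries=queries, labels=labels, params={"Retriever": {"top_k": 5}})
+metrics = eval_result.calculate_metrics()
+reader_df = eval_result["Reader"]  # per-node results as pandas DataFrames
+```
+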
#### get\_nodes\_by\_class
@@ -520,6 +539,23 @@ Pipeline for Extractive Question Answering.
All debug information can then be found in the dict returned
by this method under the key "_debug"
+
+#### eval
+
+```python
+ | eval(queries: List[str], labels: List[MultiLabel], params: Optional[dict] = None) -> EvaluationResult
+```
+
+Evaluates the pipeline by running it once per query in debug mode
+and collecting all the data needed for evaluation, e.g. for calculating metrics.
+
+**Arguments**:
+
+- `queries`: The queries to evaluate
+- `labels`: The labels to evaluate on
+- `params`: Params for the `retriever` and `reader`. For instance,
+ params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
+
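+For illustration, a sketch assuming `reader`, `retriever` and gold `labels`
+(one `MultiLabel` per query) have been prepared beforehand:
+
+```python
+pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)
+eval_result = pipeline.eval(queries=queries, labels=labels, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}})
+print(eval_result.calculate_metrics())  # e.g. {"MatchInTop1": 1.0}
+```
+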
## DocumentSearchPipeline
diff --git a/docs/_src/api/api/primitives.md b/docs/_src/api/api/primitives.md
index 294d75afa..827b422a4 100644
--- a/docs/_src/api/api/primitives.md
+++ b/docs/_src/api/api/primitives.md
@@ -230,3 +230,20 @@ underlying Labels provided a text answer and therefore demonstrates that there i
- `drop_negative_labels`: Whether to drop negative labels from that group (e.g. thumbs down feedback from UI)
- `drop_no_answers`: Whether to drop labels that specify the answer is impossible
+
+## EvaluationResult
+
+```python
+class EvaluationResult()
+```
+
+
+#### calculate\_metrics
+
+```python
+ | calculate_metrics() -> Dict[str, float]
+```
+
+First dummy implementation of metrics calculation, just to show the general approach.
+TODO: implement retriever- and reader-specific metrics that must not rely on node names.
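+
+A short usage sketch (it assumes a pipeline with a `"Reader"` node and prepared `queries`/`labels`;
+note that the target directory passed to `save()` must already exist):
+
+```python
+eval_result = pipeline.eval(queries=queries, labels=labels)
+metrics = eval_result.calculate_metrics()        # e.g. {"MatchInTop1": 1.0}
+eval_result.save("eval_results")                 # writes one CSV per node
+reloaded = EvaluationResult.load("eval_results")
+reader_df = reloaded["Reader"]                   # results as a pandas DataFrame
+```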
+
diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py
index cabe837b1..3836cc487 100644
--- a/haystack/pipelines/base.py
+++ b/haystack/pipelines/base.py
@@ -1,12 +1,15 @@
-from typing import List, Optional, Any
+from typing import Dict, List, Optional, Any, Union
import copy
import inspect
import logging
import os
import traceback
+import numpy as np
+import pandas as pd
from pathlib import Path
import networkx as nx
+from pandas import DataFrame
import yaml
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph
@@ -18,7 +21,7 @@ except:
ray = None # type: ignore
serve = None # type: ignore
-from haystack.schema import MultiLabel, Document
+from haystack.schema import EvaluationResult, MultiLabel, Document
from haystack.nodes.base import BaseComponent
from haystack.document_stores.base import BaseDocumentStore
@@ -360,6 +363,68 @@ class Pipeline(BasePipeline):
i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
return node_output
+ def eval(
+ self,
+ queries: List[str],
+ labels: List[MultiLabel],
+ params: Optional[dict] = None
+ ) -> EvaluationResult:
+ """
+ Evaluates the pipeline by running it once per query in debug mode
+ and collecting all the data needed for evaluation, e.g. for calculating metrics.
+
+ :param queries: The queries to evaluate
+ :param labels: The labels to evaluate on
+ :param params: Dictionary of parameters to be dispatched to the nodes.
+ If you want to pass a param to all nodes, you can just use: {"top_k":10}
+ If you want to pass it to targeted nodes, you can do:
+ {"Retriever": {"top_k": 10}, "Reader": {"top_k": 3, "debug": True}}
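+
+ Example (a sketch; it assumes `queries` and matching `MultiLabel` objects in `labels`
+ have been prepared beforehand, and the node names "Retriever" and "Reader" depend on your concrete pipeline):
+
+ ```python
+ eval_result = pipeline.eval(queries=queries, labels=labels, params={"Retriever": {"top_k": 5}})
+ metrics = eval_result.calculate_metrics()
+ reader_df = eval_result["Reader"]  # per-node results as pandas DataFrames
+ ```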
+ """
+ if len(queries) != len(labels):
+ raise ValueError("length of queries must match length of labels")
+
+ eval_result = EvaluationResult()
+ for query, label in zip(queries, labels):
+ predictions = self.run(query=query, labels=label, params=params, debug=True)
+
+ for node_name in predictions["_debug"].keys():
+ node_output = predictions["_debug"][node_name]["output"]
+ df = self._build_eval_dataframe(
+ query, label, node_name, node_output)
+ eval_result.append(node_name, df)
+
+ return eval_result
+
+ def _build_eval_dataframe(self, query: str, labels: MultiLabel, node_name: str, node_output: dict) -> DataFrame:
+ answer_cols = ["answer", "document_id", "offsets_in_document"]
+ document_cols = ["content", "id"]
+
+ df: Optional[DataFrame] = None
+ answers = node_output.get("answers", None)
+ if answers is not None:
+ df = pd.DataFrame(answers, columns=answer_cols)
+ if labels is not None:
+ df["gold_answers"] = df.apply(
+ lambda x: [label.answer.answer for label in labels.labels if label.answer is not None], axis=1)
+ df["gold_offsets_in_documents"] = df.apply(
+ lambda x: [label.answer.offsets_in_document for label in labels.labels if label.answer is not None], axis=1)
+
+ documents = node_output.get("documents", None)
+ if documents is not None:
+ df = pd.DataFrame(documents, columns=document_cols)
+ if labels is not None:
+ df["gold_document_ids"] = df.apply(
+ lambda x: [label.document.id for label in labels.labels], axis=1)
+ df["gold_document_contents"] = df.apply(
+ lambda x: [label.document.content for label in labels.labels], axis=1)
+
+ if df is not None:
+ df["node"] = node_name
+ df["query"] = query
+ df["rank"] = np.arange(1, len(df)+1)
+
+ return df
+
def get_next_nodes(self, node_id: str, stream_id: str):
current_node_edges = self.graph.edges(node_id, data=True)
next_nodes = [
diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py
index e9f41c1c5..2cb8de582 100644
--- a/haystack/pipelines/standard_pipelines.py
+++ b/haystack/pipelines/standard_pipelines.py
@@ -5,7 +5,7 @@ from pathlib import Path
from typing import List, Optional, Dict
from functools import wraps
-from haystack.schema import Document
+from haystack.schema import Document, EvaluationResult, MultiLabel
from haystack.nodes.answer_generator import BaseGenerator
from haystack.nodes.other import Docs2Answers
from haystack.nodes.reader import BaseReader
@@ -101,6 +101,22 @@ class ExtractiveQAPipeline(BaseStandardPipeline):
output = self.pipeline.run(query=query, params=params, debug=debug)
return output
+ def eval(self,
+ queries: List[str],
+ labels: List[MultiLabel],
+ params: Optional[dict] = None) -> EvaluationResult:
+ """
+ Evaluates the pipeline by running it once per query in debug mode
+ and collecting all the data needed for evaluation, e.g. for calculating metrics.
+
+ :param queries: The queries to evaluate
+ :param labels: The labels to evaluate on
+ :param params: Params for the `retriever` and `reader`. For instance,
+ params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
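+
+ For illustration, a sketch assuming `reader`, `retriever` and gold `labels`
+ (one MultiLabel per query) have been prepared beforehand:
+
+ ```python
+ pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)
+ eval_result = pipeline.eval(queries=queries, labels=labels, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}})
+ print(eval_result.calculate_metrics())  # e.g. {"MatchInTop1": 1.0}
+ ```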
+ """
+ output = self.pipeline.eval(queries=queries, labels=labels, params=params)
+ return output
class DocumentSearchPipeline(BaseStandardPipeline):
"""
diff --git a/haystack/schema.py b/haystack/schema.py
index 8467959b5..9f915ce42 100644
--- a/haystack/schema.py
+++ b/haystack/schema.py
@@ -15,6 +15,7 @@ else:
from pydantic.dataclasses import dataclass
from pydantic.json import pydantic_encoder
+from pathlib import Path
from uuid import uuid4
import mmh3
import numpy as np
@@ -531,3 +532,54 @@ class NumpyEncoder(json.JSONEncoder):
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
+
+
+class EvaluationResult:
+ def __init__(self, node_results: Optional[Dict[str, pd.DataFrame]] = None) -> None:
+ self.node_results: Dict[str, pd.DataFrame] = {} if node_results is None else node_results
+
+ def __getitem__(self, key: str):
+ return self.node_results.__getitem__(key)
+
+ def __delitem__(self, key: str):
+ self.node_results.__delitem__(key)
+
+ def __setitem__(self, key: str, value: pd.DataFrame):
+ self.node_results.__setitem__(key, value)
+
+ def __contains__(self, key: str):
+ return key in self.node_results
+
+ def append(self, key: str, value: pd.DataFrame):
+ if key in self.node_results:
+ self.node_results[key] = pd.concat([self.node_results[key], value])
+ else:
+ self.node_results[key] = value
+
+ def calculate_metrics(self) -> Dict[str, float]:
+ """
+ First dummy implementation of metrics calculation, just to show the general approach.
+ TODO: implement retriever- and reader-specific metrics that must not rely on node names.
+ """
+ reader_df = self.node_results["Reader"]
+ first_answers = reader_df[reader_df["rank"] == 1]
+ first_correct_answers = first_answers[first_answers.apply(
+ lambda x: x["answer"] in x["gold_answers"], axis=1)]
+
+ return {
+ "MatchInTop1": len(first_correct_answers) / len(first_answers) if len(first_answers) > 0 else 0.0
+ }
+
+ def save(self, out_dir: Union[str, Path]):
+ out_dir = out_dir if isinstance(out_dir, Path) else Path(out_dir)
+ for node_name, df in self.node_results.items():
+ target_path = out_dir / f"{node_name}.csv"
+ df.to_csv(target_path, index=False, header=True)
+
+ @classmethod
+ def load(cls, load_dir: Union[str, Path]):
+ load_dir = load_dir if isinstance(load_dir, Path) else Path(load_dir)
+ csv_files = [file for file in load_dir.iterdir() if file.is_file() and file.suffix == ".csv"]
+ node_results = {file.stem: pd.read_csv(file, header=0) for file in csv_files}
+ result = cls(node_results)
+ return result
diff --git a/test/test_pipeline_extractive_qa.py b/test/test_pipeline_extractive_qa.py
index 48910f70f..4b2f783a8 100644
--- a/test/test_pipeline_extractive_qa.py
+++ b/test/test_pipeline_extractive_qa.py
@@ -4,8 +4,9 @@ from haystack.pipeline import (
TranslationWrapperPipeline,
ExtractiveQAPipeline
)
+from haystack.pipelines.base import EvaluationResult
-from haystack.schema import Answer
+from haystack.schema import Answer, Document, Label, MultiLabel, Span
@pytest.mark.slow
@@ -98,3 +99,86 @@ def test_extractive_qa_answers_with_translator(
assert (
prediction["answers"][0].context == "My name is Carla and I live in Berlin"
)
+
+
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
+ queries = ["Who lives in Berlin?"]
+ labels = [
+ MultiLabel(labels=[Label(query="Who lives in Berlin?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+ document=Document(id='a0747b83aea0b60c4b114b15476dd32d', content_type="text", content='My name is Carla and I live in Berlin'),
+ is_correct_answer=True, is_correct_document=True, origin="gold-label")])
+ ]
+
+ pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+ eval_result = pipeline.eval(
+ queries=queries,
+ labels=labels,
+ params={"Retriever": {"top_k": 5}},
+ )
+
+ metrics = eval_result.calculate_metrics()
+
+ reader_result = eval_result["Reader"]
+ retriever_result = eval_result["Retriever"]
+
+ assert reader_result[reader_result['rank'] == 1]["answer"].iloc[0] in reader_result[reader_result['rank'] == 1]["gold_answers"].iloc[0]
+ assert retriever_result[retriever_result['rank'] == 1]["id"].iloc[0] in retriever_result[retriever_result['rank'] == 1]["gold_document_ids"].iloc[0]
+ assert metrics["MatchInTop1"] == 1.0
+
+ eval_result.save(tmp_path)
+ saved_eval_result = EvaluationResult.load(tmp_path)
+ metrics = saved_eval_result.calculate_metrics()
+
+ assert reader_result[reader_result['rank'] == 1]["answer"].iloc[0] in reader_result[reader_result['rank'] == 1]["gold_answers"].iloc[0]
+ assert retriever_result[retriever_result['rank'] == 1]["id"].iloc[0] in retriever_result[retriever_result['rank'] == 1]["gold_document_ids"].iloc[0]
+ assert metrics["MatchInTop1"] == 1.0
+
+
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_path):
+ queries = ["Who lives in Berlin?", "Who lives in Munich?"]
+ labels = [
+ MultiLabel(labels=[Label(query="Who lives in Berlin?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+ document=Document(id='a0747b83aea0b60c4b114b15476dd32d', content_type="text", content='My name is Carla and I live in Berlin'),
+ is_correct_answer=True, is_correct_document=True, origin="gold-label")]),
+ MultiLabel(labels=[Label(query="Who lives in Munich?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+ document=Document(id='something_else', content_type="text", content='My name is Carla and I live in Munich'),
+ is_correct_answer=True, is_correct_document=True, origin="gold-label")])
+ ]
+
+ pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+ eval_result = pipeline.eval(
+ queries=queries,
+ labels=labels,
+ params={"Retriever": {"top_k": 5}},
+ )
+
+ metrics = eval_result.calculate_metrics()
+
+ reader_result = eval_result["Reader"]
+ retriever_result = eval_result["Retriever"]
+
+ reader_berlin = reader_result[reader_result['query'] == "Who lives in Berlin?"]
+ reader_munich = reader_result[reader_result['query'] == "Who lives in Munich?"]
+
+ retriever_berlin = retriever_result[retriever_result['query'] == "Who lives in Berlin?"]
+ retriever_munich = retriever_result[retriever_result['query'] == "Who lives in Munich?"]
+
+ assert reader_berlin[reader_berlin['rank'] == 1]["answer"].iloc[0] in reader_berlin[reader_berlin['rank'] == 1]["gold_answers"].iloc[0]
+ assert retriever_berlin[retriever_berlin['rank'] == 1]["id"].iloc[0] in retriever_berlin[retriever_berlin['rank'] == 1]["gold_document_ids"].iloc[0]
+ assert reader_munich[reader_munich['rank'] == 1]["answer"].iloc[0] not in reader_munich[reader_munich['rank'] == 1]["gold_answers"].iloc[0]
+ assert retriever_munich[retriever_munich['rank'] == 1]["id"].iloc[0] not in retriever_munich[retriever_munich['rank'] == 1]["gold_document_ids"].iloc[0]
+ assert metrics["MatchInTop1"] == 0.5
+
+ eval_result.save(tmp_path)
+ saved_eval_result = EvaluationResult.load(tmp_path)
+ metrics = saved_eval_result.calculate_metrics()
+
+ assert reader_berlin[reader_berlin['rank'] == 1]["answer"].iloc[0] in reader_berlin[reader_berlin['rank'] == 1]["gold_answers"].iloc[0]
+ assert retriever_berlin[retriever_berlin['rank'] == 1]["id"].iloc[0] in retriever_berlin[retriever_berlin['rank'] == 1]["gold_document_ids"].iloc[0]
+ assert reader_munich[reader_munich['rank'] == 1]["answer"].iloc[0] not in reader_munich[reader_munich['rank'] == 1]["gold_answers"].iloc[0]
+ assert retriever_munich[retriever_munich['rank'] == 1]["id"].iloc[0] not in retriever_munich[retriever_munich['rank'] == 1]["gold_document_ids"].iloc[0]
+ assert metrics["MatchInTop1"] == 0.5