Multi query eval (#1746)

* add eval() to pipeline

* Add latest docstring and tutorial changes

* support multiple queries in eval()

* Add latest docstring and tutorial changes

* keep single query test

* fix EvaluationResult node_results default

* adjust docstrings

* Add latest docstring and tutorial changes

* minor improvements from comments

* Add latest docstring and tutorial changes

* move EvaluationResult and calculate_metrics to schema

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
tstadel 2021-11-15 14:51:11 +01:00 committed by GitHub
parent cf603042b2
commit 59e04cba05
6 changed files with 274 additions and 4 deletions


@@ -158,6 +158,25 @@ Runs the pipeline, one node at a time.
they received and the output they generated. All debug information can
then be found in the dict returned by this method under the key "_debug"
<a name="base.Pipeline.eval"></a>
#### eval
```python
| eval(queries: List[str], labels: List[MultiLabel], params: Optional[dict] = None) -> EvaluationResult
```
Evaluates the pipeline by running it once per query in debug mode
and collecting all data needed for evaluation, e.g. for calculating metrics.
**Arguments**:
- `queries`: The queries to evaluate
- `labels`: The labels to evaluate on
- `params`: Dictionary of parameters to be dispatched to the nodes.
If you want to pass a param to all nodes, you can just use: {"top_k":10}
If you want to pass it to targeted nodes, you can do:
{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3, "debug": True}}
<a name="base.Pipeline.get_nodes_by_class"></a>
#### get\_nodes\_by\_class
@@ -520,6 +539,23 @@ Pipeline for Extractive Question Answering.
All debug information can then be found in the dict returned
by this method under the key "_debug"
<a name="standard_pipelines.ExtractiveQAPipeline.eval"></a>
#### eval
```python
| eval(queries: List[str], labels: List[MultiLabel], params: Optional[dict]) -> EvaluationResult
```
Evaluates the pipeline by running it once per query in debug mode
and collecting all data needed for evaluation, e.g. for calculating metrics.
**Arguments**:
- `queries`: The queries to evaluate
- `labels`: The labels to evaluate on
- `params`: Params for the `retriever` and `reader`. For instance,
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
<a name="standard_pipelines.DocumentSearchPipeline"></a>
## DocumentSearchPipeline


@@ -230,3 +230,20 @@ underlying Labels provided a text answer and therefore demonstrates that there i
- `drop_negative_labels`: Whether to drop negative labels from that group (e.g. thumbs down feedback from UI)
- `drop_no_answers`: Whether to drop labels that specify the answer is impossible
<a name="schema.EvaluationResult"></a>
## EvaluationResult
```python
class EvaluationResult()
```
<a name="schema.EvaluationResult.calculate_metrics"></a>
#### calculate\_metrics
```python
| calculate_metrics() -> Dict[str, float]
```
First dummy implementation of metrics calculation, just to show how it is done.
TODO: implement retriever- and reader-specific metrics that do not rely on node names.
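For illustration, a sketch of how the result object is typically used after a `Pipeline.eval()` call (`eval_result` is assumed to come from such a call):
```python
metrics = eval_result.calculate_metrics()
print(metrics["MatchInTop1"])
reader_df = eval_result["Reader"]        # DataFrame with answers, gold_answers, rank, ...
retriever_df = eval_result["Retriever"]  # DataFrame with retrieved ids, gold_document_ids, ...
```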


@@ -1,12 +1,15 @@
from typing import List, Optional, Any
from typing import Dict, List, Optional, Any, Union
import copy
import inspect
import logging
import os
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
import networkx as nx
from pandas.core.frame import DataFrame
import yaml
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph
@@ -18,7 +21,7 @@ except:
ray = None # type: ignore
serve = None # type: ignore
from haystack.schema import MultiLabel, Document
from haystack.schema import EvaluationResult, MultiLabel, Document
from haystack.nodes.base import BaseComponent
from haystack.document_stores.base import BaseDocumentStore
@@ -360,6 +363,68 @@ class Pipeline(BasePipeline):
i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
return node_output

    def eval(
        self,
        queries: List[str],
        labels: List[MultiLabel],
        params: Optional[dict] = None
    ) -> EvaluationResult:
        """
        Evaluates the pipeline by running it once per query in debug mode
        and collecting all data needed for evaluation, e.g. for calculating metrics.

        :param queries: The queries to evaluate
        :param labels: The labels to evaluate on
        :param params: Dictionary of parameters to be dispatched to the nodes.
                       If you want to pass a param to all nodes, you can just use: {"top_k": 10}
                       If you want to pass it to targeted nodes, you can do:
                       {"Retriever": {"top_k": 10}, "Reader": {"top_k": 3, "debug": True}}
        """
        if len(queries) != len(labels):
            raise ValueError("length of queries must match length of labels")

        eval_result = EvaluationResult()
        for query, label in zip(queries, labels):
            predictions = self.run(query=query, labels=label, params=params, debug=True)
            for node_name in predictions["_debug"].keys():
                node_output = predictions["_debug"][node_name]["output"]
                df = self._build_eval_dataframe(query, label, node_name, node_output)
                eval_result.append(node_name, df)
        return eval_result

    def _build_eval_dataframe(self, query: str, labels: MultiLabel, node_name: str, node_output: dict) -> DataFrame:
        answer_cols = ["answer", "document_id", "offsets_in_document"]
        document_cols = ["content", "id"]

        df: Optional[DataFrame] = None
        answers = node_output.get("answers", None)
        if answers is not None:
            df = pd.DataFrame(answers, columns=answer_cols)
            if labels is not None:
                df["gold_answers"] = df.apply(
                    lambda x: [label.answer.answer for label in labels.labels if label.answer is not None], axis=1)
                df["gold_offsets_in_documents"] = df.apply(
                    lambda x: [label.answer.offsets_in_document for label in labels.labels if label.answer is not None], axis=1)

        documents = node_output.get("documents", None)
        if documents is not None:
            df = pd.DataFrame(documents, columns=document_cols)
            if labels is not None:
                df["gold_document_ids"] = df.apply(
                    lambda x: [label.document.id for label in labels.labels], axis=1)
                df["gold_document_contents"] = df.apply(
                    lambda x: [label.document.content for label in labels.labels], axis=1)

        if df is not None:
            df["node"] = node_name
            df["query"] = query
            df["rank"] = np.arange(1, len(df) + 1)
        return df
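For clarity, a hypothetical sketch of the kind of frame this produces for an answer node (values are purely illustrative; the column names are the ones set above):
```python
import pandas as pd

# Illustrative only: one row per predicted answer, ranked by position in the node output.
reader_df = pd.DataFrame([{
    "answer": "Carla",
    "document_id": "a0747b83aea0b60c4b114b15476dd32d",
    "offsets_in_document": [(11, 16)],
    "gold_answers": ["Carla"],
    "gold_offsets_in_documents": [[(11, 16)]],
    "node": "Reader",
    "query": "Who lives in Berlin?",
    "rank": 1,
}])
```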
def get_next_nodes(self, node_id: str, stream_id: str):
current_node_edges = self.graph.edges(node_id, data=True)
next_nodes = [


@@ -5,7 +5,7 @@ from pathlib import Path
from typing import List, Optional, Dict
from functools import wraps
from haystack.schema import Document
from haystack.schema import Document, EvaluationResult, MultiLabel
from haystack.nodes.answer_generator import BaseGenerator
from haystack.nodes.other import Docs2Answers
from haystack.nodes.reader import BaseReader
@@ -101,6 +101,22 @@ class ExtractiveQAPipeline(BaseStandardPipeline):
output = self.pipeline.run(query=query, params=params, debug=debug)
return output

    def eval(self,
             queries: List[str],
             labels: List[MultiLabel],
             params: Optional[dict]) -> EvaluationResult:
        """
        Evaluates the pipeline by running it once per query in debug mode
        and collecting all data needed for evaluation, e.g. for calculating metrics.

        :param queries: The queries to evaluate
        :param labels: The labels to evaluate on
        :param params: Params for the `retriever` and `reader`. For instance,
                       params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
        """
        output = self.pipeline.eval(queries=queries, labels=labels, params=params)
        return output

class DocumentSearchPipeline(BaseStandardPipeline):
"""


@@ -15,6 +15,7 @@ else:
from pydantic.dataclasses import dataclass
from pydantic.json import pydantic_encoder
from pathlib import Path
from uuid import uuid4
import mmh3
import numpy as np
@@ -531,3 +532,54 @@ class NumpyEncoder(json.JSONEncoder):
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

class EvaluationResult:
    def __init__(self, node_results: Optional[Dict[str, pd.DataFrame]] = None) -> None:
        self.node_results: Dict[str, pd.DataFrame] = {} if node_results is None else node_results

    def __getitem__(self, key: str):
        return self.node_results.__getitem__(key)

    def __delitem__(self, key: str):
        self.node_results.__delitem__(key)

    def __setitem__(self, key: str, value: pd.DataFrame):
        self.node_results.__setitem__(key, value)

    def __contains__(self, key: str):
        return key in self.node_results

    def append(self, key: str, value: pd.DataFrame):
        if key in self.node_results:
            self.node_results[key] = pd.concat([self.node_results[key], value])
        else:
            self.node_results[key] = value

    def calculate_metrics(self) -> Dict[str, float]:
        """
        First dummy implementation of metrics calculation, just to show how it is done.
        TODO: implement retriever- and reader-specific metrics that do not rely on node names.
        """
        reader_df = self.node_results["Reader"]
        first_answers = reader_df[reader_df["rank"] == 1]
        first_correct_answers = first_answers[first_answers.apply(
            lambda x: x["answer"] in x["gold_answers"], axis=1)]
        return {
            "MatchInTop1": len(first_correct_answers) / len(first_answers) if len(first_answers) > 0 else 0.0
        }

    def save(self, out_dir: Union[str, Path]):
        out_dir = out_dir if isinstance(out_dir, Path) else Path(out_dir)
        for node_name, df in self.node_results.items():
            target_path = out_dir / f"{node_name}.csv"
            df.to_csv(target_path, index=False, header=True)

    @classmethod
    def load(cls, load_dir: Union[str, Path]):
        load_dir = load_dir if isinstance(load_dir, Path) else Path(load_dir)
        csv_files = [file for file in load_dir.iterdir() if file.is_file() and file.suffix == ".csv"]
        node_results = {file.stem: pd.read_csv(file, header=0) for file in csv_files}
        result = cls(node_results)
        return result
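As a usage sketch of the persistence helpers above (`eval_result` is assumed to be an existing `EvaluationResult` from an extractive QA eval; the output directory must already exist, as `save()` does not create it):
```python
from pathlib import Path

out_dir = Path("eval_results")
out_dir.mkdir(exist_ok=True)

eval_result.save(out_dir)                  # writes one <node_name>.csv per node
restored = EvaluationResult.load(out_dir)
assert "Reader" in restored                # __contains__ delegates to node_results
print(restored.calculate_metrics())
```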


@@ -4,8 +4,9 @@ from haystack.pipeline import (
TranslationWrapperPipeline,
ExtractiveQAPipeline
)
from haystack.pipelines.base import EvaluationResult
from haystack.schema import Answer
from haystack.schema import Answer, Document, Label, MultiLabel, Span
@pytest.mark.slow
@@ -98,3 +99,86 @@ def test_extractive_qa_answers_with_translator(
assert (
prediction["answers"][0].context == "My name is Carla and I live in Berlin"
)
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
queries = ["Who lives in Berlin?"]
labels = [
MultiLabel(labels=[Label(query="Who lives in Berlin?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
document=Document(id='a0747b83aea0b60c4b114b15476dd32d', content_type="text", content='My name is Carla and I live in Berlin'),
is_correct_answer=True, is_correct_document=True, origin="gold-label")])
]
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
eval_result = pipeline.eval(
queries=queries,
labels=labels,
params={"Retriever": {"top_k": 5}},
)
metrics = eval_result.calculate_metrics()
reader_result = eval_result["Reader"]
retriever_result = eval_result["Retriever"]
assert reader_result[reader_result['rank'] == 1]["answer"].iloc[0] in reader_result[reader_result['rank'] == 1]["gold_answers"].iloc[0]
assert retriever_result[retriever_result['rank'] == 1]["id"].iloc[0] in retriever_result[retriever_result['rank'] == 1]["gold_document_ids"].iloc[0]
assert metrics["MatchInTop1"] == 1.0
eval_result.save(tmp_path)
saved_eval_result = EvaluationResult.load(tmp_path)
metrics = saved_eval_result.calculate_metrics()
assert reader_result[reader_result['rank'] == 1]["answer"].iloc[0] in reader_result[reader_result['rank'] == 1]["gold_answers"].iloc[0]
assert retriever_result[retriever_result['rank'] == 1]["id"].iloc[0] in retriever_result[retriever_result['rank'] == 1]["gold_document_ids"].iloc[0]
assert metrics["MatchInTop1"] == 1.0
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_path):
queries = ["Who lives in Berlin?", "Who lives in Munich?"]
labels = [
MultiLabel(labels=[Label(query="Who lives in Berlin?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
document=Document(id='a0747b83aea0b60c4b114b15476dd32d', content_type="text", content='My name is Carla and I live in Berlin'),
is_correct_answer=True, is_correct_document=True, origin="gold-label")]),
MultiLabel(labels=[Label(query="Who lives in Munich?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
document=Document(id='something_else', content_type="text", content='My name is Carla and I live in Munich'),
is_correct_answer=True, is_correct_document=True, origin="gold-label")])
]
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
eval_result = pipeline.eval(
queries=queries,
labels=labels,
params={"Retriever": {"top_k": 5}},
)
metrics = eval_result.calculate_metrics()
reader_result = eval_result["Reader"]
retriever_result = eval_result["Retriever"]
reader_berlin = reader_result[reader_result['query'] == "Who lives in Berlin?"]
reader_munich = reader_result[reader_result['query'] == "Who lives in Munich?"]
retriever_berlin = retriever_result[retriever_result['query'] == "Who lives in Berlin?"]
retriever_munich = retriever_result[retriever_result['query'] == "Who lives in Munich?"]
assert reader_berlin[reader_berlin['rank'] == 1]["answer"].iloc[0] in reader_berlin[reader_berlin['rank'] == 1]["gold_answers"].iloc[0]
assert retriever_berlin[retriever_berlin['rank'] == 1]["id"].iloc[0] in retriever_berlin[retriever_berlin['rank'] == 1]["gold_document_ids"].iloc[0]
assert reader_munich[reader_munich['rank'] == 1]["answer"].iloc[0] not in reader_munich[reader_munich['rank'] == 1]["gold_answers"].iloc[0]
assert retriever_munich[retriever_munich['rank'] == 1]["id"].iloc[0] not in retriever_munich[retriever_munich['rank'] == 1]["gold_document_ids"].iloc[0]
assert metrics["MatchInTop1"] == 0.5
eval_result.save(tmp_path)
saved_eval_result = EvaluationResult.load(tmp_path)
metrics = saved_eval_result.calculate_metrics()
assert reader_berlin[reader_berlin['rank'] == 1]["answer"].iloc[0] in reader_berlin[reader_berlin['rank'] == 1]["gold_answers"].iloc[0]
assert retriever_berlin[retriever_berlin['rank'] == 1]["id"].iloc[0] in retriever_berlin[retriever_berlin['rank'] == 1]["gold_document_ids"].iloc[0]
assert reader_munich[reader_munich['rank'] == 1]["answer"].iloc[0] not in reader_munich[reader_munich['rank'] == 1]["gold_answers"].iloc[0]
assert retriever_munich[retriever_munich['rank'] == 1]["id"].iloc[0] not in retriever_munich[retriever_munich['rank'] == 1]["gold_document_ids"].iloc[0]
assert metrics["MatchInTop1"] == 0.5