Multi query eval (#1746)
* add eval() to pipeline
* Add latest docstring and tutorial changes
* support multiple queries in eval()
* Add latest docstring and tutorial changes
* keep single query test
* fix EvaluationResult node_results default
* adjust docstrings
* Add latest docstring and tutorial changes
* minor improvements from comments
* Add latest docstring and tutorial changes
* move EvaluationResult and calculate_metrics to schema
* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Commit 59e04cba05, parent cf603042b2.
@@ -158,6 +158,25 @@ Runs the pipeline, one node at a time.
they received and the output they generated. All debug information can
then be found in the dict returned by this method under the key "_debug"

<a name="base.Pipeline.eval"></a>
#### eval

```python
| eval(queries: List[str], labels: List[MultiLabel], params: Optional[dict] = None) -> EvaluationResult
```

Evaluates the pipeline by running the pipeline once per query in debug mode
and putting together all data that is needed for evaluation, e.g. calculating metrics.

**Arguments**:

- `queries`: The queries to evaluate
- `labels`: The labels to evaluate on
- `params`: Dictionary of parameters to be dispatched to the nodes.
If you want to pass a param to all nodes, you can just use: {"top_k":10}
If you want to pass it to targeted nodes, you can do:
{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3, "debug": True}}

<a name="base.Pipeline.get_nodes_by_class"></a>
#### get\_nodes\_by\_class

@@ -520,6 +539,23 @@ Pipeline for Extractive Question Answering.
All debug information can then be found in the dict returned
by this method under the key "_debug"

<a name="standard_pipelines.ExtractiveQAPipeline.eval"></a>
#### eval

```python
| eval(queries: List[str], labels: List[MultiLabel], params: Optional[dict]) -> EvaluationResult
```

Evaluates the pipeline by running the pipeline once per query in debug mode
and putting together all data that is needed for evaluation, e.g. calculating metrics.

**Arguments**:

- `queries`: The queries to evaluate
- `labels`: The labels to evaluate on
- `params`: Params for the `retriever` and `reader`. For instance,
params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}

<a name="standard_pipelines.DocumentSearchPipeline"></a>
## DocumentSearchPipeline

@@ -230,3 +230,20 @@ underlying Labels provided a text answer and therefore demonstrates that there i
- `drop_negative_labels`: Whether to drop negative labels from that group (e.g. thumbs down feedback from UI)
- `drop_no_answers`: Whether to drop labels that specify the answer is impossible

<a name="schema.EvaluationResult"></a>
## EvaluationResult

```python
class EvaluationResult()
```

<a name="schema.EvaluationResult.calculate_metrics"></a>
#### calculate\_metrics

```python
| calculate_metrics() -> Dict[str, float]
```

First dummy implementation of metrics calculation, just to show the way it's done.
TODO: implement retriever- and reader-specific metrics that must not rely on node names.
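
To make the current metric concrete, here is a small, hedged sketch (not part of the commit) that exercises `calculate_metrics()` against a hand-built "Reader" DataFrame. The column names (`answer`, `gold_answers`, `rank`) are the ones produced by `Pipeline._build_eval_dataframe` later in this diff; the query and answer values are invented toy data.

```python
import pandas as pd
from haystack.schema import EvaluationResult

# Two queries; the rank-1 answer matches a gold answer for only one of them,
# so MatchInTop1 = 1 correct rank-1 row out of 2 = 0.5.
reader_df = pd.DataFrame([
    {"query": "Who lives in Berlin?", "answer": "Carla", "gold_answers": ["Carla"], "rank": 1, "node": "Reader"},
    {"query": "Who lives in Munich?", "answer": "Carla", "gold_answers": ["Pete"], "rank": 1, "node": "Reader"},
])

eval_result = EvaluationResult()
eval_result.append("Reader", reader_df)
print(eval_result.calculate_metrics())  # {'MatchInTop1': 0.5}
```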

@@ -1,12 +1,15 @@
from typing import List, Optional, Any
from typing import Dict, List, Optional, Any, Union

import copy
import inspect
import logging
import os
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
import networkx as nx
from pandas.core.frame import DataFrame
import yaml
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph
@@ -18,7 +21,7 @@ except:
    ray = None  # type: ignore
    serve = None  # type: ignore

from haystack.schema import MultiLabel, Document
from haystack.schema import EvaluationResult, MultiLabel, Document
from haystack.nodes.base import BaseComponent
from haystack.document_stores.base import BaseDocumentStore

@@ -360,6 +363,68 @@ class Pipeline(BasePipeline):
                i += 1  # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
        return node_output

    def eval(
        self,
        queries: List[str],
        labels: List[MultiLabel],
        params: Optional[dict] = None
    ) -> EvaluationResult:
        """
        Evaluates the pipeline by running the pipeline once per query in debug mode
        and putting together all data that is needed for evaluation, e.g. calculating metrics.

        :param queries: The queries to evaluate
        :param labels: The labels to evaluate on
        :param params: Dictionary of parameters to be dispatched to the nodes.
                       If you want to pass a param to all nodes, you can just use: {"top_k":10}
                       If you want to pass it to targeted nodes, you can do:
                       {"Retriever": {"top_k": 10}, "Reader": {"top_k": 3, "debug": True}}
        """
        if len(queries) != len(labels):
            raise ValueError("length of queries must match length of labels")

        eval_result = EvaluationResult()
        for query, label in zip(queries, labels):
            predictions = self.run(query=query, labels=label, params=params, debug=True)

            for node_name in predictions["_debug"].keys():
                node_output = predictions["_debug"][node_name]["output"]
                df = self._build_eval_dataframe(
                    query, label, node_name, node_output)
                eval_result.append(node_name, df)

        return eval_result

    def _build_eval_dataframe(self, query: str, labels: MultiLabel, node_name: str, node_output: dict) -> DataFrame:
        answer_cols = ["answer", "document_id", "offsets_in_document"]
        document_cols = ["content", "id"]

        df: DataFrame = None
        answers = node_output.get("answers", None)
        if answers is not None:
            df = pd.DataFrame(answers, columns=answer_cols)
            if labels is not None:
                df["gold_answers"] = df.apply(
                    lambda x: [label.answer.answer for label in labels.labels if label.answer is not None], axis=1)
                df["gold_offsets_in_documents"] = df.apply(
                    lambda x: [label.answer.offsets_in_document for label in labels.labels if label.answer is not None], axis=1)

        documents = node_output.get("documents", None)
        if documents is not None:
            df = pd.DataFrame(documents, columns=document_cols)
            if labels is not None:
                df["gold_document_ids"] = df.apply(
                    lambda x: [label.document.id for label in labels.labels], axis=1)
                df["gold_document_contents"] = df.apply(
                    lambda x: [label.document.content for label in labels.labels], axis=1)

        if df is not None:
            df["node"] = node_name
            df["query"] = query
            df["rank"] = np.arange(1, len(df) + 1)

        return df

    def get_next_nodes(self, node_id: str, stream_id: str):
        current_node_edges = self.graph.edges(node_id, data=True)
        next_nodes = [

@@ -5,7 +5,7 @@ from pathlib import Path
from typing import List, Optional, Dict
from functools import wraps

from haystack.schema import Document
from haystack.schema import Document, EvaluationResult, MultiLabel
from haystack.nodes.answer_generator import BaseGenerator
from haystack.nodes.other import Docs2Answers
from haystack.nodes.reader import BaseReader
@@ -101,6 +101,22 @@ class ExtractiveQAPipeline(BaseStandardPipeline):
        output = self.pipeline.run(query=query, params=params, debug=debug)
        return output

    def eval(self,
             queries: List[str],
             labels: List[MultiLabel],
             params: Optional[dict]) -> EvaluationResult:

        """
        Evaluates the pipeline by running the pipeline once per query in debug mode
        and putting together all data that is needed for evaluation, e.g. calculating metrics.

        :param queries: The queries to evaluate
        :param labels: The labels to evaluate on
        :param params: Params for the `retriever` and `reader`. For instance,
                       params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
        """
        output = self.pipeline.eval(queries=queries, labels=labels, params=params)
        return output


class DocumentSearchPipeline(BaseStandardPipeline):
    """
@@ -15,6 +15,7 @@ else:
    from pydantic.dataclasses import dataclass

from pydantic.json import pydantic_encoder
from pathlib import Path
from uuid import uuid4
import mmh3
import numpy as np
@@ -531,3 +532,54 @@ class NumpyEncoder(json.JSONEncoder):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

class EvaluationResult:
    def __init__(self, node_results: Dict[str, pd.DataFrame] = None) -> None:
        self.node_results: Dict[str, pd.DataFrame] = {} if node_results is None else node_results

    def __getitem__(self, key: str):
        return self.node_results.__getitem__(key)

    def __delitem__(self, key: str):
        self.node_results.__delitem__(key)

    def __setitem__(self, key: str, value: pd.DataFrame):
        self.node_results.__setitem__(key, value)

    def __contains__(self, key: str):
        return self.node_results.keys().__contains__(key)

    def append(self, key: str, value: pd.DataFrame):
        if key in self.node_results:
            self.node_results[key] = pd.concat([self.node_results[key], value])
        else:
            self.node_results[key] = value

    def calculate_metrics(self) -> Dict[str, float]:
        """
        First dummy implementation of metrics calculation, just to show the way it's done.
        TODO: implement retriever- and reader-specific metrics that must not rely on node names.
        """
        reader_df = self.node_results["Reader"]
        first_answers = reader_df[reader_df["rank"] == 1]
        first_correct_answers = first_answers[first_answers.apply(
            lambda x: x["answer"] in x["gold_answers"], axis=1)]

        return {
            "MatchInTop1": len(first_correct_answers) / len(first_answers) if len(first_answers) > 0 else 0.0
        }

    def save(self, out_dir: Union[str, Path]):
        out_dir = out_dir if isinstance(out_dir, Path) else Path(out_dir)
        for node_name, df in self.node_results.items():
            target_path = out_dir / f"{node_name}.csv"
            df.to_csv(target_path, index=False, header=True)

    @classmethod
    def load(cls, load_dir: Union[str, Path]):
        load_dir = load_dir if isinstance(load_dir, Path) else Path(load_dir)
        csv_files = [file for file in load_dir.iterdir() if file.is_file() and file.suffix == ".csv"]
        node_results = {file.stem: pd.read_csv(file, header=0) for file in csv_files}
        result = cls(node_results)
        return result
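
A quick, hedged usage sketch (not part of the commit) for the save/load round trip above. The directory name is hypothetical; `save()` writes one `<node_name>.csv` per entry in `node_results`, and `load()` rebuilds the dict from the CSV file stems.

```python
from pathlib import Path
import pandas as pd
from haystack.schema import EvaluationResult

out_dir = Path("eval_results")   # hypothetical output directory; must exist before save()
out_dir.mkdir(exist_ok=True)

result = EvaluationResult({"Reader": pd.DataFrame([{"answer": "Carla", "gold_answers": ["Carla"], "rank": 1}])})
result.save(out_dir)                           # writes eval_results/Reader.csv
reloaded = EvaluationResult.load(out_dir)
assert "Reader" in reloaded                    # __contains__ delegates to node_results
# Note: list-valued columns such as gold_answers come back as their string repr after the CSV round trip.
```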

@@ -4,8 +4,9 @@ from haystack.pipeline import (
    TranslationWrapperPipeline,
    ExtractiveQAPipeline
)
from haystack.pipelines.base import EvaluationResult

from haystack.schema import Answer
from haystack.schema import Answer, Document, Label, MultiLabel, Span


@pytest.mark.slow
@@ -98,3 +99,86 @@ def test_extractive_qa_answers_with_translator(
    assert (
        prediction["answers"][0].context == "My name is Carla and I live in Berlin"
    )


@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
    queries = ["Who lives in Berlin?"]
    labels = [
        MultiLabel(labels=[Label(query="Who lives in Berlin?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
                                 document=Document(id='a0747b83aea0b60c4b114b15476dd32d', content_type="text", content='My name is Carla and I live in Berlin'),
                                 is_correct_answer=True, is_correct_document=True, origin="gold-label")])
    ]

    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    eval_result = pipeline.eval(
        queries=queries,
        labels=labels,
        params={"Retriever": {"top_k": 5}},
    )

    metrics = eval_result.calculate_metrics()

    reader_result = eval_result["Reader"]
    retriever_result = eval_result["Retriever"]

    assert reader_result[reader_result['rank'] == 1]["answer"].iloc[0] in reader_result[reader_result['rank'] == 1]["gold_answers"].iloc[0]
    assert retriever_result[retriever_result['rank'] == 1]["id"].iloc[0] in retriever_result[retriever_result['rank'] == 1]["gold_document_ids"].iloc[0]
    assert metrics["MatchInTop1"] == 1.0

    eval_result.save(tmp_path)
    saved_eval_result = EvaluationResult.load(tmp_path)
    metrics = saved_eval_result.calculate_metrics()

    assert reader_result[reader_result['rank'] == 1]["answer"].iloc[0] in reader_result[reader_result['rank'] == 1]["gold_answers"].iloc[0]
    assert retriever_result[retriever_result['rank'] == 1]["id"].iloc[0] in retriever_result[retriever_result['rank'] == 1]["gold_document_ids"].iloc[0]
    assert metrics["MatchInTop1"] == 1.0

@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_path):
|
||||
queries = ["Who lives in Berlin?", "Who lives in Munich?"]
|
||||
labels = [
|
||||
MultiLabel(labels=[Label(query="Who lives in Berlin?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
|
||||
document=Document(id='a0747b83aea0b60c4b114b15476dd32d', content_type="text", content='My name is Carla and I live in Berlin'),
|
||||
is_correct_answer=True, is_correct_document=True, origin="gold-label")]),
|
||||
MultiLabel(labels=[Label(query="Who lives in Munich?", answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
|
||||
document=Document(id='something_else', content_type="text", content='My name is Carla and I live in Munich'),
|
||||
is_correct_answer=True, is_correct_document=True, origin="gold-label")])
|
||||
]
|
||||
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
eval_result = pipeline.eval(
|
||||
queries=queries,
|
||||
labels=labels,
|
||||
params={"Retriever": {"top_k": 5}},
|
||||
)
|
||||
|
||||
metrics = eval_result.calculate_metrics()
|
||||
|
||||
reader_result = eval_result["Reader"]
|
||||
retriever_result = eval_result["Retriever"]
|
||||
|
||||
reader_berlin = reader_result[reader_result['query'] == "Who lives in Berlin?"]
|
||||
reader_munich = reader_result[reader_result['query'] == "Who lives in Munich?"]
|
||||
|
||||
retriever_berlin = retriever_result[retriever_result['query'] == "Who lives in Berlin?"]
|
||||
retriever_munich = retriever_result[retriever_result['query'] == "Who lives in Munich?"]
|
||||
|
||||
assert reader_berlin[reader_berlin['rank'] == 1]["answer"].iloc[0] in reader_berlin[reader_berlin['rank'] == 1]["gold_answers"].iloc[0]
|
||||
assert retriever_berlin[retriever_berlin['rank'] == 1]["id"].iloc[0] in retriever_berlin[retriever_berlin['rank'] == 1]["gold_document_ids"].iloc[0]
|
||||
assert reader_munich[reader_munich['rank'] == 1]["answer"].iloc[0] not in reader_munich[reader_munich['rank'] == 1]["gold_answers"].iloc[0]
|
||||
assert retriever_munich[retriever_munich['rank'] == 1]["id"].iloc[0] not in retriever_munich[retriever_munich['rank'] == 1]["gold_document_ids"].iloc[0]
|
||||
assert metrics["MatchInTop1"] == 0.5
|
||||
|
||||
eval_result.save(tmp_path)
|
||||
saved_eval_result = EvaluationResult.load(tmp_path)
|
||||
metrics = saved_eval_result.calculate_metrics()
|
||||
|
||||
assert reader_berlin[reader_berlin['rank'] == 1]["answer"].iloc[0] in reader_berlin[reader_berlin['rank'] == 1]["gold_answers"].iloc[0]
|
||||
assert retriever_berlin[retriever_berlin['rank'] == 1]["id"].iloc[0] in retriever_berlin[retriever_berlin['rank'] == 1]["gold_document_ids"].iloc[0]
|
||||
assert reader_munich[reader_munich['rank'] == 1]["answer"].iloc[0] not in reader_munich[reader_munich['rank'] == 1]["gold_answers"].iloc[0]
|
||||
assert retriever_munich[retriever_munich['rank'] == 1]["id"].iloc[0] not in retriever_munich[retriever_munich['rank'] == 1]["gold_document_ids"].iloc[0]
|
||||
assert metrics["MatchInTop1"] == 0.5