diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md
index b8199b99a..6d9366d91 100644
--- a/docs/_src/api/api/pipelines.md
+++ b/docs/_src/api/api/pipelines.md
@@ -242,7 +242,7 @@ then be found in the dict returned by this method under the key "_debug"
 #### eval
 
 ```python
-def eval(labels: List[MultiLabel], params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult
+def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult
 ```
 
 Evaluates the pipeline by running the pipeline once per query in debug mode
@@ -251,6 +251,7 @@ and putting together all data that is needed for evaluation, e.g. calculating me
 **Arguments**:
 
 - `labels`: The labels to evaluate on
+- `documents`: List of List of Document that the first node in the pipeline should get as input per multilabel. Can be used to evaluate a pipeline that consists of a reader without a retriever.
 - `params`: Dictionary of parameters to be dispatched to the nodes.
 If you want to pass a param to all nodes, you can just use: {"top_k":10}
 If you want to pass it to targeted nodes, you can do:
diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py
index 97191d57d..a02424c26 100644
--- a/haystack/pipelines/base.py
+++ b/haystack/pipelines/base.py
@@ -480,6 +480,7 @@ class Pipeline(BasePipeline):
     def eval(
         self,
         labels: List[MultiLabel],
+        documents: Optional[List[List[Document]]] = None,
         params: Optional[dict] = None,
         sas_model_name_or_path: str = None,
         add_isolated_node_eval: bool = False,
@@ -489,6 +490,7 @@ class Pipeline(BasePipeline):
         and putting together all data that is needed for evaluation, e.g. calculating metrics.
 
         :param labels: The labels to evaluate on
+        :param documents: List of List of Document that the first node in the pipeline should get as input per multilabel. Can be used to evaluate a pipeline that consists of a reader without a retriever.
         :param params: Dictionary of parameters to be dispatched to the nodes.
                        If you want to pass a param to all nodes, you can just use: {"top_k":10}
                        If you want to pass it to targeted nodes, you can do:
@@ -518,7 +520,9 @@ class Pipeline(BasePipeline):
         if params is None:
             params = {}
         params["add_isolated_node_eval"] = True
-        for label in labels:
+
+        # if documents is None, set docs_per_label to None for each label
+        for docs_per_label, label in zip(documents or [None] * len(labels), labels):
             params_per_label = copy.deepcopy(params)
             if label.filters is not None:
                 if params_per_label is None:
@@ -526,7 +530,9 @@
                 else:
                     # join both filters and overwrite filters in params with filters in labels
                     params_per_label["filters"] = {**params_per_label.get("filters", {}), **label.filters}
-            predictions = self.run(query=label.query, labels=label, params=params_per_label, debug=True)
+            predictions = self.run(
+                query=label.query, labels=label, documents=docs_per_label, params=params_per_label, debug=True
+            )
 
             for node_name in predictions["_debug"].keys():
                 node_output = predictions["_debug"][node_name]["output"]
diff --git a/test/test_eval.py b/test/test_eval.py
index 742d8912c..295fd91b8 100644
--- a/test/test_eval.py
+++ b/test/test_eval.py
@@ -519,6 +519,21 @@ def test_extractive_qa_eval_sas(reader, retriever_with_docs):
     assert metrics["Reader"]["sas"] == pytest.approx(1.0)
 
 
+def test_reader_eval_in_pipeline(reader):
+    pipeline = Pipeline()
+    pipeline.add_node(component=reader, name="Reader", inputs=["Query"])
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=EVAL_LABELS,
+        documents=[[label.document for label in multilabel.labels] for multilabel in EVAL_LABELS],
+        params={},
+    )
+
+    metrics = eval_result.calculate_metrics()
+
+    assert metrics["Reader"]["exact_match"] == 1.0
+    assert metrics["Reader"]["f1"] == 1.0
+
+
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 def test_extractive_qa_eval_doc_relevance_col(reader, retriever_with_docs):
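
Taken together, the patch lets a pipeline whose first node is a reader be evaluated with no retriever in front of it: `eval()` now accepts one list of `Document`s per `MultiLabel` and forwards it to `run()`. A minimal usage sketch, assuming a Haystack v1.x install; the reader checkpoint, query, document, and label below are illustrative stand-ins, not fixtures from this patch:

```python
from haystack import Answer, Document, Label, MultiLabel
from haystack.nodes import FARMReader
from haystack.pipelines import Pipeline

# A pipeline that consists of a reader without a retriever.
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")  # assumed checkpoint
pipeline = Pipeline()
pipeline.add_node(component=reader, name="Reader", inputs=["Query"])

# One gold label; its document doubles as the reader's input further down.
# Query, document, and answer are made up for illustration.
doc = Document(content="Berlin is the capital of Germany.")
labels = [
    MultiLabel(
        labels=[
            Label(
                query="What is the capital of Germany?",
                document=doc,
                answer=Answer(answer="Berlin"),
                is_correct_answer=True,
                is_correct_document=True,
                origin="gold-label",
            )
        ]
    )
]

# With no retriever to fetch documents, tell eval() what the first node
# should receive: one List[Document] per MultiLabel, aligned by position.
eval_result = pipeline.eval(
    labels=labels,
    documents=[[label.document for label in multilabel.labels] for multilabel in labels],
    params={},
)
print(eval_result.calculate_metrics()["Reader"])  # exact_match, f1, ...
```

If `documents` is omitted, the `zip(documents or [None] * len(labels), labels)` in the patched loop pairs every label with `None`, so retriever-based pipelines keep their previous behavior unchanged.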