From e4c3817d01c7c81c0accbdb543a0ab5cc9e86ade Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Fri, 2 Dec 2022 14:48:47 +0100
Subject: [PATCH] Adjust get_type() method for pipelines (#3657)

---
 haystack/pipelines/base.py      | 39 ++++++++++----
 test/pipelines/test_pipeline.py | 96 ++++++++++++++++++++++++++++++++-
 2 files changed, 123 insertions(+), 12 deletions(-)

diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py
index a21237a0d..6432e418f 100644
--- a/haystack/pipelines/base.py
+++ b/haystack/pipelines/base.py
@@ -3,13 +3,12 @@ from __future__ import annotations
 import datetime
+import itertools
 from datetime import timedelta
 from functools import partial
 from hashlib import sha1
-import itertools
 from typing import Dict, List, Optional, Any, Set, Tuple, Union
-
 try:
     from typing import Literal
 except ImportError:
@@ -48,6 +47,7 @@ from haystack.pipelines.utils import generate_code, print_eval_report
 from haystack.utils import DeepsetCloud, calculate_context_similarity
 from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span
 from haystack.errors import HaystackError, PipelineError, PipelineConfigError
+from haystack.nodes import BaseGenerator, Docs2Answers, BaseReader, BaseSummarizer, BaseTranslator, QuestionGenerator
 from haystack.nodes.base import BaseComponent, RootNode
 from haystack.nodes.retriever.base import BaseRetriever
 from haystack.document_stores.base import BaseDocumentStore
@@ -2238,17 +2238,34 @@ class Pipeline:
         # values of the dict are functions evaluating whether components of this pipeline match the pipeline type
         # specified by dict keys
         pipeline_types = {
-            "GenerativeQAPipeline": lambda x: {"Generator", "Retriever"} <= set(x.keys()),
-            "FAQPipeline": lambda x: {"Docs2Answers"} <= set(x.keys()),
-            "ExtractiveQAPipeline": lambda x: {"Reader", "Retriever"} <= set(x.keys()),
-            "SearchSummarizationPipeline": lambda x: {"Retriever", "Summarizer"} <= set(x.keys()),
-            "TranslationWrapperPipeline": lambda x: {"InputTranslator", "OutputTranslator"} <= set(x.keys()),
-            "RetrieverQuestionGenerationPipeline": lambda x: {"Retriever", "QuestionGenerator"} <= set(x.keys()),
-            "QuestionAnswerGenerationPipeline": lambda x: {"QuestionGenerator", "Reader"} <= set(x.keys()),
-            "DocumentSearchPipeline": lambda x: {"Retriever"} <= set(x.keys()),
-            "QuestionGenerationPipeline": lambda x: {"QuestionGenerator"} <= set(x.keys()),
+            # QuestionGenerationPipeline has only one component, which is a QuestionGenerator
+            "QuestionGenerationPipeline": lambda x: all(isinstance(x, QuestionGenerator) for x in x.values()),
+            # GenerativeQAPipeline has at least BaseGenerator and BaseRetriever components
+            "GenerativeQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
+            and any(isinstance(x, BaseGenerator) for x in x.values()),
+            # FAQPipeline has at least one Docs2Answers component
+            "FAQPipeline": lambda x: any(isinstance(x, Docs2Answers) for x in x.values()),
+            # ExtractiveQAPipeline has at least one BaseRetriever component and one BaseReader component
+            "ExtractiveQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
+            and any(isinstance(x, BaseReader) for x in x.values()),
+            # SearchSummarizationPipeline has at least one BaseSummarizer component and one BaseRetriever component
+            "SearchSummarizationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
+            and any(isinstance(x, BaseSummarizer) for x in x.values()),
+            # TranslationWrapperPipeline has two or more BaseTranslator components
"TranslationWrapperPipeline": lambda x: [isinstance(x, BaseTranslator) for x in x.values()].count(True) + >= 2, + # RetrieverQuestionGenerationPipeline has at least one BaseRetriever component and one + # QuestionGenerator component + "RetrieverQuestionGenerationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values()) + and any(isinstance(x, QuestionGenerator) for x in x.values()), + # QuestionAnswerGenerationPipeline has at least one BaseReader component and one QuestionGenerator component + "QuestionAnswerGenerationPipeline": lambda x: any(isinstance(x, BaseReader) for x in x.values()) + and any(isinstance(x, QuestionGenerator) for x in x.values()), + # MostSimilarDocumentsPipeline has only BaseDocumentStore component "MostSimilarDocumentsPipeline": lambda x: len(x.values()) == 1 and isinstance(list(x.values())[0], BaseDocumentStore), + # DocumentSearchPipeline has at least one BaseRetriever component + "DocumentSearchPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values()), } retrievers = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseRetriever)] doc_stores = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseDocumentStore)] diff --git a/test/pipelines/test_pipeline.py b/test/pipelines/test_pipeline.py index b64da8680..9b122ea62 100644 --- a/test/pipelines/test_pipeline.py +++ b/test/pipelines/test_pipeline.py @@ -36,6 +36,7 @@ from haystack.pipelines import ( DocumentSearchPipeline, QuestionGenerationPipeline, MostSimilarDocumentsPipeline, + BaseStandardPipeline, ) from haystack.pipelines.config import validate_config_strings, get_component_definitions from haystack.pipelines.utils import generate_code @@ -784,7 +785,7 @@ def test_validate_pipeline_config_recursive_config(reduce_windows_recursion_limi validate_config_strings(pipeline_config) -def test_pipeline_classify_type(): +def test_pipeline_classify_type(tmp_path): pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever()) assert pipe.get_type().startswith("GenerativeQAPipeline") @@ -818,6 +819,99 @@ def test_pipeline_classify_type(): pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore()) assert pipe.get_type().startswith("MostSimilarDocumentsPipeline") + # previously misclassified as "UnknownPipeline" + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: document_store + type: MockDocumentStore + - name: retriever + type: MockRetriever + - name: retriever_2 + type: MockRetriever + pipelines: + - name: my_pipeline + nodes: + - name: retriever + inputs: + - Query + - name: retriever_2 + inputs: + - Query + - name: document_store + inputs: + - retriever + + """ + ) + pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + # two retrievers but still a DocumentSearchPipeline + assert pipe.get_type().startswith("DocumentSearchPipeline") + + # previously misclassified as "UnknownPipeline" + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: ignore + components: + - name: document_store + type: MockDocumentStore + - name: retriever + type: MockRetriever + - name: retriever_2 + type: MockRetriever + - name: retriever_3 + type: MockRetriever + pipelines: + - name: my_pipeline + nodes: + - name: retriever + inputs: + - Query + - name: retriever_2 + inputs: + - Query + - name: retriever_3 + inputs: + - Query + - name: document_store + inputs: + - retriever + + """ + ) + 
+    pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    # three retrievers but still a DocumentSearchPipeline
+    assert pipe.get_type().startswith("DocumentSearchPipeline")
+
+    # previously misclassified as "UnknownPipeline"
+    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: ignore
+            components:
+              - name: document_store
+                type: MockDocumentStore
+              - name: retriever
+                type: BM25Retriever
+            pipelines:
+              - name: my_pipeline
+                nodes:
+                  - name: retriever
+                    inputs:
+                      - Query
+                  - name: document_store
+                    inputs:
+                      - retriever
+
+        """
+        )
+    pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    # BM25Retriever used - still a DocumentSearchPipeline
+    assert pipe.get_type().startswith("DocumentSearchPipeline")
+
 
 
 @pytest.mark.usefixtures(deepset_cloud_fixture.__name__)
 @responses.activate
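Note (not part of the patch): a minimal sketch of what the isinstance-based checks change in
practice. It assumes a Haystack v1.x environment with this patch applied; MyRetriever is a
hypothetical stub used only for illustration, not a class shipped with the library.

    from haystack import Pipeline
    from haystack.nodes.retriever.base import BaseRetriever


    class MyRetriever(BaseRetriever):
        """Stub retriever: any BaseRetriever subclass should now be recognized."""

        def retrieve(self, query, **kwargs):
            return []

        def retrieve_batch(self, queries, **kwargs):
            return [[] for _ in queries]


    pipeline = Pipeline()
    # Node names no longer have to match the literal strings ("Retriever", "Reader", ...)
    # that the old name-based set checks looked for.
    pipeline.add_node(component=MyRetriever(), name="my_retriever", inputs=["Query"])
    pipeline.add_node(component=MyRetriever(), name="my_retriever_2", inputs=["Query"])

    # Before this patch the name-based checks reported "UnknownPipeline" here; with the
    # isinstance-based checks the returned string should start with "DocumentSearchPipeline".
    print(pipeline.get_type())

The YAML-loaded pipelines in the new tests exercise the same behaviour: classification now
depends on the component types in the pipeline, not on the node names used in the graph.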