Adjust get_type() method for pipelines (#3657)

This commit is contained in:
Vladimir Blagojevic 2022-12-02 14:48:47 +01:00 committed by GitHub
parent adb580b6b7
commit e4c3817d01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 123 additions and 12 deletions

View File

@ -3,13 +3,12 @@
from __future__ import annotations
import datetime
import itertools
from datetime import timedelta
from functools import partial
from hashlib import sha1
import itertools
from typing import Dict, List, Optional, Any, Set, Tuple, Union
try:
from typing import Literal
except ImportError:
@ -48,6 +47,7 @@ from haystack.pipelines.utils import generate_code, print_eval_report
from haystack.utils import DeepsetCloud, calculate_context_similarity
from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span
from haystack.errors import HaystackError, PipelineError, PipelineConfigError
from haystack.nodes import BaseGenerator, Docs2Answers, BaseReader, BaseSummarizer, BaseTranslator, QuestionGenerator
from haystack.nodes.base import BaseComponent, RootNode
from haystack.nodes.retriever.base import BaseRetriever
from haystack.document_stores.base import BaseDocumentStore
@ -2238,17 +2238,34 @@ class Pipeline:
# values of the dict are functions evaluating whether components of this pipeline match the pipeline type
# specified by dict keys
pipeline_types = {
"GenerativeQAPipeline": lambda x: {"Generator", "Retriever"} <= set(x.keys()),
"FAQPipeline": lambda x: {"Docs2Answers"} <= set(x.keys()),
"ExtractiveQAPipeline": lambda x: {"Reader", "Retriever"} <= set(x.keys()),
"SearchSummarizationPipeline": lambda x: {"Retriever", "Summarizer"} <= set(x.keys()),
"TranslationWrapperPipeline": lambda x: {"InputTranslator", "OutputTranslator"} <= set(x.keys()),
"RetrieverQuestionGenerationPipeline": lambda x: {"Retriever", "QuestionGenerator"} <= set(x.keys()),
"QuestionAnswerGenerationPipeline": lambda x: {"QuestionGenerator", "Reader"} <= set(x.keys()),
"DocumentSearchPipeline": lambda x: {"Retriever"} <= set(x.keys()),
"QuestionGenerationPipeline": lambda x: {"QuestionGenerator"} <= set(x.keys()),
# QuestionGenerationPipeline consists solely of QuestionGenerator components
"QuestionGenerationPipeline": lambda x: all(isinstance(x, QuestionGenerator) for x in x.values()),
# GenerativeQAPipeline has at least BaseGenerator and BaseRetriever components
"GenerativeQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
and any(isinstance(x, BaseGenerator) for x in x.values()),
# FAQPipeline has at least one Docs2Answers component
"FAQPipeline": lambda x: any(isinstance(x, Docs2Answers) for x in x.values()),
# ExtractiveQAPipeline has at least one BaseRetriever component and one BaseReader component
"ExtractiveQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
and any(isinstance(x, BaseReader) for x in x.values()),
# SearchSummarizationPipeline has at least one BaseRetriever component and one BaseSummarizer component
"SearchSummarizationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
and any(isinstance(x, BaseSummarizer) for x in x.values()),
# TranslationWrapperPipeline has two or more BaseTranslator components
"TranslationWrapperPipeline": lambda x: [isinstance(x, BaseTranslator) for x in x.values()].count(True)
>= 2,
# RetrieverQuestionGenerationPipeline has at least one BaseRetriever component and one
# QuestionGenerator component
"RetrieverQuestionGenerationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
and any(isinstance(x, QuestionGenerator) for x in x.values()),
# QuestionAnswerGenerationPipeline has at least one BaseReader component and one QuestionGenerator component
"QuestionAnswerGenerationPipeline": lambda x: any(isinstance(x, BaseReader) for x in x.values())
and any(isinstance(x, QuestionGenerator) for x in x.values()),
# MostSimilarDocumentsPipeline has exactly one component, which is a BaseDocumentStore
"MostSimilarDocumentsPipeline": lambda x: len(x.values()) == 1
and isinstance(list(x.values())[0], BaseDocumentStore),
# DocumentSearchPipeline has at least one BaseRetriever component
"DocumentSearchPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values()),
}
retrievers = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseRetriever)]
doc_stores = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseDocumentStore)]

View File

@ -36,6 +36,7 @@ from haystack.pipelines import (
DocumentSearchPipeline,
QuestionGenerationPipeline,
MostSimilarDocumentsPipeline,
BaseStandardPipeline,
)
from haystack.pipelines.config import validate_config_strings, get_component_definitions
from haystack.pipelines.utils import generate_code
@ -784,7 +785,7 @@ def test_validate_pipeline_config_recursive_config(reduce_windows_recursion_limi
validate_config_strings(pipeline_config)
def test_pipeline_classify_type():
def test_pipeline_classify_type(tmp_path):
pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever())
assert pipe.get_type().startswith("GenerativeQAPipeline")
@ -818,6 +819,99 @@ def test_pipeline_classify_type():
pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore())
assert pipe.get_type().startswith("MostSimilarDocumentsPipeline")
# previously misclassified as "UnknownPipeline"
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: document_store
type: MockDocumentStore
- name: retriever
type: MockRetriever
- name: retriever_2
type: MockRetriever
pipelines:
- name: my_pipeline
nodes:
- name: retriever
inputs:
- Query
- name: retriever_2
inputs:
- Query
- name: document_store
inputs:
- retriever
"""
)
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
# two retrievers but still a DocumentSearchPipeline
assert pipe.get_type().startswith("DocumentSearchPipeline")
# previously misclassified as "UnknownPipeline"
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: document_store
type: MockDocumentStore
- name: retriever
type: MockRetriever
- name: retriever_2
type: MockRetriever
- name: retriever_3
type: MockRetriever
pipelines:
- name: my_pipeline
nodes:
- name: retriever
inputs:
- Query
- name: retriever_2
inputs:
- Query
- name: retriever_3
inputs:
- Query
- name: document_store
inputs:
- retriever
"""
)
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
# three retrievers but still a DocumentSearchPipeline
assert pipe.get_type().startswith("DocumentSearchPipeline")
# previously misclassified as "UnknownPipeline"
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: ignore
components:
- name: document_store
type: MockDocumentStore
- name: retriever
type: BM25Retriever
pipelines:
- name: my_pipeline
nodes:
- name: retriever
inputs:
- Query
- name: document_store
inputs:
- retriever
"""
)
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
# BM25Retriever used - still a DocumentSearchPipeline
assert pipe.get_type().startswith("DocumentSearchPipeline")
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__)
@responses.activate