mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 00:24:14 +00:00
Adjust get_type() method for pipelines (#3657)
This commit is contained in:
parent
adb580b6b7
commit
e4c3817d01
@ -3,13 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import itertools
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from hashlib import sha1
|
||||
import itertools
|
||||
from typing import Dict, List, Optional, Any, Set, Tuple, Union
|
||||
|
||||
|
||||
try:
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
@ -48,6 +47,7 @@ from haystack.pipelines.utils import generate_code, print_eval_report
|
||||
from haystack.utils import DeepsetCloud, calculate_context_similarity
|
||||
from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span
|
||||
from haystack.errors import HaystackError, PipelineError, PipelineConfigError
|
||||
from haystack.nodes import BaseGenerator, Docs2Answers, BaseReader, BaseSummarizer, BaseTranslator, QuestionGenerator
|
||||
from haystack.nodes.base import BaseComponent, RootNode
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
@ -2238,17 +2238,34 @@ class Pipeline:
|
||||
# values of the dict are functions evaluating whether components of this pipeline match the pipeline type
|
||||
# specified by dict keys
|
||||
pipeline_types = {
|
||||
"GenerativeQAPipeline": lambda x: {"Generator", "Retriever"} <= set(x.keys()),
|
||||
"FAQPipeline": lambda x: {"Docs2Answers"} <= set(x.keys()),
|
||||
"ExtractiveQAPipeline": lambda x: {"Reader", "Retriever"} <= set(x.keys()),
|
||||
"SearchSummarizationPipeline": lambda x: {"Retriever", "Summarizer"} <= set(x.keys()),
|
||||
"TranslationWrapperPipeline": lambda x: {"InputTranslator", "OutputTranslator"} <= set(x.keys()),
|
||||
"RetrieverQuestionGenerationPipeline": lambda x: {"Retriever", "QuestionGenerator"} <= set(x.keys()),
|
||||
"QuestionAnswerGenerationPipeline": lambda x: {"QuestionGenerator", "Reader"} <= set(x.keys()),
|
||||
"DocumentSearchPipeline": lambda x: {"Retriever"} <= set(x.keys()),
|
||||
"QuestionGenerationPipeline": lambda x: {"QuestionGenerator"} <= set(x.keys()),
|
||||
# QuestionGenerationPipeline consists solely of QuestionGenerator components
|
||||
"QuestionGenerationPipeline": lambda x: all(isinstance(x, QuestionGenerator) for x in x.values()),
|
||||
# GenerativeQAPipeline has at least BaseGenerator and BaseRetriever components
|
||||
"GenerativeQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||
and any(isinstance(x, BaseGenerator) for x in x.values()),
|
||||
# FAQPipeline has at least one Docs2Answers component
|
||||
"FAQPipeline": lambda x: any(isinstance(x, Docs2Answers) for x in x.values()),
|
||||
# ExtractiveQAPipeline has at least one BaseRetriever component and one BaseReader component
|
||||
"ExtractiveQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||
and any(isinstance(x, BaseReader) for x in x.values()),
|
||||
# SearchSummarizationPipeline has at least one BaseRetriever component and one BaseSummarizer component
|
||||
"SearchSummarizationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||
and any(isinstance(x, BaseSummarizer) for x in x.values()),
|
||||
# TranslationWrapperPipeline has two or more BaseTranslator components
|
||||
"TranslationWrapperPipeline": lambda x: [isinstance(x, BaseTranslator) for x in x.values()].count(True)
|
||||
>= 2,
|
||||
# RetrieverQuestionGenerationPipeline has at least one BaseRetriever component and one
|
||||
# QuestionGenerator component
|
||||
"RetrieverQuestionGenerationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||
and any(isinstance(x, QuestionGenerator) for x in x.values()),
|
||||
# QuestionAnswerGenerationPipeline has at least one BaseReader component and one QuestionGenerator component
|
||||
"QuestionAnswerGenerationPipeline": lambda x: any(isinstance(x, BaseReader) for x in x.values())
|
||||
and any(isinstance(x, QuestionGenerator) for x in x.values()),
|
||||
# MostSimilarDocumentsPipeline has exactly one component, which is a BaseDocumentStore
|
||||
"MostSimilarDocumentsPipeline": lambda x: len(x.values()) == 1
|
||||
and isinstance(list(x.values())[0], BaseDocumentStore),
|
||||
# DocumentSearchPipeline has at least one BaseRetriever component
|
||||
"DocumentSearchPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values()),
|
||||
}
|
||||
retrievers = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseRetriever)]
|
||||
doc_stores = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseDocumentStore)]
|
||||
|
@ -36,6 +36,7 @@ from haystack.pipelines import (
|
||||
DocumentSearchPipeline,
|
||||
QuestionGenerationPipeline,
|
||||
MostSimilarDocumentsPipeline,
|
||||
BaseStandardPipeline,
|
||||
)
|
||||
from haystack.pipelines.config import validate_config_strings, get_component_definitions
|
||||
from haystack.pipelines.utils import generate_code
|
||||
@ -784,7 +785,7 @@ def test_validate_pipeline_config_recursive_config(reduce_windows_recursion_limi
|
||||
validate_config_strings(pipeline_config)
|
||||
|
||||
|
||||
def test_pipeline_classify_type():
|
||||
def test_pipeline_classify_type(tmp_path):
|
||||
|
||||
pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever())
|
||||
assert pipe.get_type().startswith("GenerativeQAPipeline")
|
||||
@ -818,6 +819,99 @@ def test_pipeline_classify_type():
|
||||
pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore())
|
||||
assert pipe.get_type().startswith("MostSimilarDocumentsPipeline")
|
||||
|
||||
# previously misclassified as "UnknownPipeline"
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: document_store
|
||||
type: MockDocumentStore
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
- name: retriever_2
|
||||
type: MockRetriever
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- Query
|
||||
- name: retriever_2
|
||||
inputs:
|
||||
- Query
|
||||
- name: document_store
|
||||
inputs:
|
||||
- retriever
|
||||
|
||||
"""
|
||||
)
|
||||
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
# two retrievers but still a DocumentSearchPipeline
|
||||
assert pipe.get_type().startswith("DocumentSearchPipeline")
|
||||
|
||||
# previously misclassified as "UnknownPipeline"
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: document_store
|
||||
type: MockDocumentStore
|
||||
- name: retriever
|
||||
type: MockRetriever
|
||||
- name: retriever_2
|
||||
type: MockRetriever
|
||||
- name: retriever_3
|
||||
type: MockRetriever
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- Query
|
||||
- name: retriever_2
|
||||
inputs:
|
||||
- Query
|
||||
- name: retriever_3
|
||||
inputs:
|
||||
- Query
|
||||
- name: document_store
|
||||
inputs:
|
||||
- retriever
|
||||
|
||||
"""
|
||||
)
|
||||
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
# three retrievers but still a DocumentSearchPipeline
|
||||
assert pipe.get_type().startswith("DocumentSearchPipeline")
|
||||
|
||||
# previously misclassified as "UnknownPipeline"
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: ignore
|
||||
components:
|
||||
- name: document_store
|
||||
type: MockDocumentStore
|
||||
- name: retriever
|
||||
type: BM25Retriever
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: retriever
|
||||
inputs:
|
||||
- Query
|
||||
- name: document_store
|
||||
inputs:
|
||||
- retriever
|
||||
|
||||
"""
|
||||
)
|
||||
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
# BM25Retriever used - still a DocumentSearchPipeline
|
||||
assert pipe.get_type().startswith("DocumentSearchPipeline")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__)
|
||||
@responses.activate
|
||||
|
Loading…
x
Reference in New Issue
Block a user