mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-27 17:15:35 +00:00
Adjust get_type() method for pipelines (#3657)
This commit is contained in:
parent
adb580b6b7
commit
e4c3817d01
@ -3,13 +3,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import itertools
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
import itertools
|
|
||||||
from typing import Dict, List, Optional, Any, Set, Tuple, Union
|
from typing import Dict, List, Optional, Any, Set, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -48,6 +47,7 @@ from haystack.pipelines.utils import generate_code, print_eval_report
|
|||||||
from haystack.utils import DeepsetCloud, calculate_context_similarity
|
from haystack.utils import DeepsetCloud, calculate_context_similarity
|
||||||
from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span
|
from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span
|
||||||
from haystack.errors import HaystackError, PipelineError, PipelineConfigError
|
from haystack.errors import HaystackError, PipelineError, PipelineConfigError
|
||||||
|
from haystack.nodes import BaseGenerator, Docs2Answers, BaseReader, BaseSummarizer, BaseTranslator, QuestionGenerator
|
||||||
from haystack.nodes.base import BaseComponent, RootNode
|
from haystack.nodes.base import BaseComponent, RootNode
|
||||||
from haystack.nodes.retriever.base import BaseRetriever
|
from haystack.nodes.retriever.base import BaseRetriever
|
||||||
from haystack.document_stores.base import BaseDocumentStore
|
from haystack.document_stores.base import BaseDocumentStore
|
||||||
@ -2238,17 +2238,34 @@ class Pipeline:
|
|||||||
# values of the dict are functions evaluating whether components of this pipeline match the pipeline type
|
# values of the dict are functions evaluating whether components of this pipeline match the pipeline type
|
||||||
# specified by dict keys
|
# specified by dict keys
|
||||||
pipeline_types = {
|
pipeline_types = {
|
||||||
"GenerativeQAPipeline": lambda x: {"Generator", "Retriever"} <= set(x.keys()),
|
# QuestionGenerationPipeline consists only of QuestionGenerator components
|
||||||
"FAQPipeline": lambda x: {"Docs2Answers"} <= set(x.keys()),
|
"QuestionGenerationPipeline": lambda x: all(isinstance(x, QuestionGenerator) for x in x.values()),
|
||||||
"ExtractiveQAPipeline": lambda x: {"Reader", "Retriever"} <= set(x.keys()),
|
# GenerativeQAPipeline has at least BaseGenerator and BaseRetriever components
|
||||||
"SearchSummarizationPipeline": lambda x: {"Retriever", "Summarizer"} <= set(x.keys()),
|
"GenerativeQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||||
"TranslationWrapperPipeline": lambda x: {"InputTranslator", "OutputTranslator"} <= set(x.keys()),
|
and any(isinstance(x, BaseGenerator) for x in x.values()),
|
||||||
"RetrieverQuestionGenerationPipeline": lambda x: {"Retriever", "QuestionGenerator"} <= set(x.keys()),
|
# FAQPipeline has at least one Docs2Answers component
|
||||||
"QuestionAnswerGenerationPipeline": lambda x: {"QuestionGenerator", "Reader"} <= set(x.keys()),
|
"FAQPipeline": lambda x: any(isinstance(x, Docs2Answers) for x in x.values()),
|
||||||
"DocumentSearchPipeline": lambda x: {"Retriever"} <= set(x.keys()),
|
# ExtractiveQAPipeline has at least one BaseRetriever component and one BaseReader component
|
||||||
"QuestionGenerationPipeline": lambda x: {"QuestionGenerator"} <= set(x.keys()),
|
"ExtractiveQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||||
|
and any(isinstance(x, BaseReader) for x in x.values()),
|
||||||
|
# SearchSummarizationPipeline has at least one BaseRetriever component and one BaseSummarizer component
|
||||||
|
"SearchSummarizationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||||
|
and any(isinstance(x, BaseSummarizer) for x in x.values()),
|
||||||
|
# TranslationWrapperPipeline has two or more BaseTranslator components
|
||||||
|
"TranslationWrapperPipeline": lambda x: [isinstance(x, BaseTranslator) for x in x.values()].count(True)
|
||||||
|
>= 2,
|
||||||
|
# RetrieverQuestionGenerationPipeline has at least one BaseRetriever component and one
|
||||||
|
# QuestionGenerator component
|
||||||
|
"RetrieverQuestionGenerationPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values())
|
||||||
|
and any(isinstance(x, QuestionGenerator) for x in x.values()),
|
||||||
|
# QuestionAnswerGenerationPipeline has at least one BaseReader component and one QuestionGenerator component
|
||||||
|
"QuestionAnswerGenerationPipeline": lambda x: any(isinstance(x, BaseReader) for x in x.values())
|
||||||
|
and any(isinstance(x, QuestionGenerator) for x in x.values()),
|
||||||
|
# MostSimilarDocumentsPipeline has exactly one component, which is a BaseDocumentStore
|
||||||
"MostSimilarDocumentsPipeline": lambda x: len(x.values()) == 1
|
"MostSimilarDocumentsPipeline": lambda x: len(x.values()) == 1
|
||||||
and isinstance(list(x.values())[0], BaseDocumentStore),
|
and isinstance(list(x.values())[0], BaseDocumentStore),
|
||||||
|
# DocumentSearchPipeline has at least one BaseRetriever component
|
||||||
|
"DocumentSearchPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values()),
|
||||||
}
|
}
|
||||||
retrievers = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseRetriever)]
|
retrievers = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseRetriever)]
|
||||||
doc_stores = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseDocumentStore)]
|
doc_stores = [type(comp).__name__ for comp in self.components.values() if isinstance(comp, BaseDocumentStore)]
|
||||||
|
@ -36,6 +36,7 @@ from haystack.pipelines import (
|
|||||||
DocumentSearchPipeline,
|
DocumentSearchPipeline,
|
||||||
QuestionGenerationPipeline,
|
QuestionGenerationPipeline,
|
||||||
MostSimilarDocumentsPipeline,
|
MostSimilarDocumentsPipeline,
|
||||||
|
BaseStandardPipeline,
|
||||||
)
|
)
|
||||||
from haystack.pipelines.config import validate_config_strings, get_component_definitions
|
from haystack.pipelines.config import validate_config_strings, get_component_definitions
|
||||||
from haystack.pipelines.utils import generate_code
|
from haystack.pipelines.utils import generate_code
|
||||||
@ -784,7 +785,7 @@ def test_validate_pipeline_config_recursive_config(reduce_windows_recursion_limi
|
|||||||
validate_config_strings(pipeline_config)
|
validate_config_strings(pipeline_config)
|
||||||
|
|
||||||
|
|
||||||
def test_pipeline_classify_type():
|
def test_pipeline_classify_type(tmp_path):
|
||||||
|
|
||||||
pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever())
|
pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever())
|
||||||
assert pipe.get_type().startswith("GenerativeQAPipeline")
|
assert pipe.get_type().startswith("GenerativeQAPipeline")
|
||||||
@ -818,6 +819,99 @@ def test_pipeline_classify_type():
|
|||||||
pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore())
|
pipe = MostSimilarDocumentsPipeline(document_store=MockDocumentStore())
|
||||||
assert pipe.get_type().startswith("MostSimilarDocumentsPipeline")
|
assert pipe.get_type().startswith("MostSimilarDocumentsPipeline")
|
||||||
|
|
||||||
|
# previously misclassified as "UnknownPipeline"
|
||||||
|
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||||
|
tmp_file.write(
|
||||||
|
f"""
|
||||||
|
version: ignore
|
||||||
|
components:
|
||||||
|
- name: document_store
|
||||||
|
type: MockDocumentStore
|
||||||
|
- name: retriever
|
||||||
|
type: MockRetriever
|
||||||
|
- name: retriever_2
|
||||||
|
type: MockRetriever
|
||||||
|
pipelines:
|
||||||
|
- name: my_pipeline
|
||||||
|
nodes:
|
||||||
|
- name: retriever
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
- name: retriever_2
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
- name: document_store
|
||||||
|
inputs:
|
||||||
|
- retriever
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||||
|
# two retrievers but still a DocumentSearchPipeline
|
||||||
|
assert pipe.get_type().startswith("DocumentSearchPipeline")
|
||||||
|
|
||||||
|
# previously misclassified as "UnknownPipeline"
|
||||||
|
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||||
|
tmp_file.write(
|
||||||
|
f"""
|
||||||
|
version: ignore
|
||||||
|
components:
|
||||||
|
- name: document_store
|
||||||
|
type: MockDocumentStore
|
||||||
|
- name: retriever
|
||||||
|
type: MockRetriever
|
||||||
|
- name: retriever_2
|
||||||
|
type: MockRetriever
|
||||||
|
- name: retriever_3
|
||||||
|
type: MockRetriever
|
||||||
|
pipelines:
|
||||||
|
- name: my_pipeline
|
||||||
|
nodes:
|
||||||
|
- name: retriever
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
- name: retriever_2
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
- name: retriever_3
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
- name: document_store
|
||||||
|
inputs:
|
||||||
|
- retriever
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||||
|
# three retrievers but still a DocumentSearchPipeline
|
||||||
|
assert pipe.get_type().startswith("DocumentSearchPipeline")
|
||||||
|
|
||||||
|
# previously misclassified as "UnknownPipeline"
|
||||||
|
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||||
|
tmp_file.write(
|
||||||
|
f"""
|
||||||
|
version: ignore
|
||||||
|
components:
|
||||||
|
- name: document_store
|
||||||
|
type: MockDocumentStore
|
||||||
|
- name: retriever
|
||||||
|
type: BM25Retriever
|
||||||
|
pipelines:
|
||||||
|
- name: my_pipeline
|
||||||
|
nodes:
|
||||||
|
- name: retriever
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
- name: document_store
|
||||||
|
inputs:
|
||||||
|
- retriever
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
pipe = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||||
|
# BM25Retriever used - still a DocumentSearchPipeline
|
||||||
|
assert pipe.get_type().startswith("DocumentSearchPipeline")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__)
|
@pytest.mark.usefixtures(deepset_cloud_fixture.__name__)
|
||||||
@responses.activate
|
@responses.activate
|
||||||
|
Loading…
x
Reference in New Issue
Block a user