From b042dd9c82f359882ac2b003816e3b3fb0298e3c Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Thu, 4 Aug 2022 10:27:50 +0200 Subject: [PATCH] Fix validation for dynamic outgoing edges (#2850) * fix validation for dynamic outgoing edges * Update Documentation & Code Style * use class outgoing_edges as fallback if no instance is provided * implement classmethod approach * readd comment * fix mypy * fix tests * set outgoing_edges for all components * set outgoing_edges for mocks too * set document store outgoing_edges to 1 * set last missing outgoing_edges * enforce BaseComponent subclasses to define outgoing_edges * override _calculate_outgoing_edges for FileTypeClassifier * remove superfluous test * set rest_api's custom component's outgoing_edges * Update docstring Co-authored-by: Sara Zan * remove unnecessary else Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Sara Zan --- haystack/document_stores/base.py | 2 + haystack/nodes/base.py | 39 ++++++- haystack/nodes/file_classifier/file_type.py | 11 +- .../label_generator/pseudo_label_generator.py | 2 + haystack/nodes/other/join.py | 3 + haystack/nodes/other/route_documents.py | 18 ++- haystack/pipelines/config.py | 103 +++++++++--------- rest_api/pipeline/custom_component.py | 2 + test/nodes/test_filetype_classifier.py | 5 - test/nodes/test_other.py | 1 + test/pipelines/test_pipeline.py | 2 + 11 files changed, 119 insertions(+), 69 deletions(-) diff --git a/haystack/document_stores/base.py b/haystack/document_stores/base.py index 486aa450f..c17553a27 100644 --- a/haystack/document_stores/base.py +++ b/haystack/document_stores/base.py @@ -58,6 +58,8 @@ class BaseDocumentStore(BaseComponent): Base class for implementing Document Stores. """ + outgoing_edges: int = 1 + index: Optional[str] label_index: Optional[str] similarity: Optional[str] diff --git a/haystack/nodes/base.py b/haystack/nodes/base.py index 4978e8319..e21fc8a8e 100644 --- a/haystack/nodes/base.py +++ b/haystack/nodes/base.py @@ -27,9 +27,6 @@ def exportable_to_yaml(init_func): @wraps(init_func) def wrapper_exportable_to_yaml(self, *args, **kwargs): - # Call the actuall __init__ function with all the arguments - init_func(self, *args, **kwargs) - # Create the configuration dictionary if it doesn't exist yet if not self._component_config: self._component_config = {"params": {}, "type": type(self).__name__} @@ -47,6 +44,9 @@ def exportable_to_yaml(init_func): for k, v in params.items(): self._component_config["params"][k] = v + # Call the actuall __init__ function with all the arguments + init_func(self, *args, **kwargs) + return wrapper_exportable_to_yaml @@ -61,7 +61,9 @@ class BaseComponent(ABC): def __init__(self): # a small subset of the component's parameters is sent in an event after applying filters defined in haystack.telemetry.NonPrivateParameters - send_custom_event(event=f"{type(self).__name__} initialized", payload=self._component_config.get("params", {})) + component_params = self._component_config.get("params", {}) + send_custom_event(event=f"{type(self).__name__} initialized", payload=component_params) + self.outgoing_edges = self._calculate_outgoing_edges(component_params=component_params) # __init_subclass__ is invoked when a subclass of BaseComponent is _imported_ # (not instantiated). It works approximately as a metaclass. @@ -69,6 +71,15 @@ class BaseComponent(ABC): super().__init_subclass__(**kwargs) + # Each component must specify the number of outgoing edges (= different outputs). + # During pipeline validation this number is compared to the requested number of output edges. + if not hasattr(cls, "outgoing_edges"): + raise ValueError( + "BaseComponent subclasses must define the outgoing_edges class attribute. " + "If this number depends on the component's parameters, make sure to override the _calculate_outgoing_edges() method. " + "See https://haystack.deepset.ai/pipeline_nodes/custom-nodes for more information." + ) + # Automatically registers all the init parameters in # an instance attribute called `_component_config`, # used to save this component to YAML. See exportable_to_yaml() @@ -116,13 +127,31 @@ class BaseComponent(ABC): subclass = cls._subclasses[component_type] return subclass + @classmethod + def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: + """ + Returns the number of outgoing edges for an instance of the component class given its component params. + + In some cases (e.g. RouteDocuments) the number of outgoing edges is not static but rather depends on its component params. + Setting the number of outgoing edges inside the constructor would not be sufficient, since it is already required for validating the pipeline when there is no instance yet. + Hence, this method is responsible for calculating the number of outgoing edges + - during pipeline validation + - to set the effective instance value of `outgoing_edges`. + + Override this method if the number of outgoing edges depends on the component params. + If not overridden, returns the number of outgoing edges as defined in the component class. + + :param component_params: parameters to pass to the __init__() of the component. + """ + return cls.outgoing_edges + @classmethod def _create_instance(cls, component_type: str, component_params: Dict[str, Any], name: Optional[str] = None): """ Returns an instance of the given subclass of BaseComponent. :param component_type: name of the component class to load. - :param component_params: parameters to pass to the __init__() for the component. + :param component_params: parameters to pass to the __init__() of the component. :param name: name of the component instance """ subclass = cls.get_subclass(component_type) diff --git a/haystack/nodes/file_classifier/file_type.py b/haystack/nodes/file_classifier/file_type.py index f27dbe72c..c19114958 100644 --- a/haystack/nodes/file_classifier/file_type.py +++ b/haystack/nodes/file_classifier/file_type.py @@ -1,5 +1,5 @@ import mimetypes -from typing import List, Union +from typing import Any, Dict, List, Union import logging from pathlib import Path @@ -27,7 +27,7 @@ class FileTypeClassifier(BaseComponent): Route files in an Indexing Pipeline to corresponding file converters. """ - outgoing_edges = 10 + outgoing_edges = len(DEFAULT_TYPES) def __init__(self, supported_types: List[str] = DEFAULT_TYPES): """ @@ -40,8 +40,6 @@ class FileTypeClassifier(BaseComponent): elements will not be allowed. Lists with duplicate elements will also be rejected. """ - if len(supported_types) > 10: - raise ValueError("supported_types can't have more than 10 values.") if len(set(supported_types)) != len(supported_types): duplicates = supported_types for item in set(supported_types): @@ -52,6 +50,11 @@ class FileTypeClassifier(BaseComponent): self.supported_types = supported_types + @classmethod + def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: + supported_types = component_params.get("supported_types", DEFAULT_TYPES) + return len(supported_types) + def _estimate_extension(self, file_path: Path) -> str: """ Return the extension found based on the contents of the given file diff --git a/haystack/nodes/label_generator/pseudo_label_generator.py b/haystack/nodes/label_generator/pseudo_label_generator.py index 2f15d7c7d..5429c0412 100644 --- a/haystack/nodes/label_generator/pseudo_label_generator.py +++ b/haystack/nodes/label_generator/pseudo_label_generator.py @@ -50,6 +50,8 @@ class PseudoLabelGenerator(BaseComponent): """ + outgoing_edges: int = 1 + def __init__( self, question_producer: Union[QuestionGenerator, List[Dict[str, str]]], diff --git a/haystack/nodes/other/join.py b/haystack/nodes/other/join.py index c722c0f3e..7caf0f8f4 100644 --- a/haystack/nodes/other/join.py +++ b/haystack/nodes/other/join.py @@ -7,6 +7,9 @@ from haystack.nodes.base import BaseComponent class JoinNode(BaseComponent): + + outgoing_edges: int = 1 + def run( # type: ignore self, inputs: Optional[List[dict]] = None, diff --git a/haystack/nodes/other/route_documents.py b/haystack/nodes/other/route_documents.py index fb4b0dc85..e230c6ac4 100644 --- a/haystack/nodes/other/route_documents.py +++ b/haystack/nodes/other/route_documents.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Dict, Optional, Union +from typing import Any, List, Tuple, Dict, Optional, Union from collections import defaultdict from haystack.nodes.base import BaseComponent @@ -28,19 +28,25 @@ class RouteDocuments(BaseComponent): value of the provided list will be routed to `"output_2"`, etc. """ - assert split_by == "content_type" or metadata_values is not None, ( - "If split_by is set to the name of a metadata field, you must provide metadata_values " - "to group the documents to." - ) + if split_by != "content_type" and metadata_values is None: + raise ValueError( + "If split_by is set to the name of a metadata field, you must provide metadata_values " + "to group the documents to." + ) super().__init__() self.split_by = split_by self.metadata_values = metadata_values + @classmethod + def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: + split_by = component_params.get("split_by", "content_type") + metadata_values = component_params.get("metadata_values", None) # If we split list of Documents by a metadata field, number of outgoing edges might change if split_by != "content_type" and metadata_values is not None: - self.outgoing_edges = len(metadata_values) + return len(metadata_values) + return 2 def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore if self.split_by == "content_type": diff --git a/haystack/pipelines/config.py b/haystack/pipelines/config.py index 591a36f88..99ea05e6e 100644 --- a/haystack/pipelines/config.py +++ b/haystack/pipelines/config.py @@ -56,7 +56,7 @@ def get_pipeline_definition(pipeline_config: Dict[str, Any], pipeline_name: Opti def get_component_definitions( pipeline_config: Dict[str, Any], overwrite_with_env_variables: bool = True -) -> Dict[str, Any]: +) -> Dict[str, Dict[str, Any]]: """ Returns the definitions of all components from a given pipeline config. @@ -393,7 +393,7 @@ def _init_pipeline_graph(root_node_name: Optional[str]) -> nx.DiGraph: def _add_node_to_pipeline_graph( - graph: nx.DiGraph, components: Dict[str, Dict[str, str]], node: Dict[str, Any], instance: BaseComponent = None + graph: nx.DiGraph, components: Dict[str, Dict[str, Any]], node: Dict[str, Any], instance: BaseComponent = None ) -> nx.DiGraph: """ Adds a single node to the provided graph, performing all necessary validation steps. @@ -449,64 +449,69 @@ def _add_node_to_pipeline_graph( graph.add_node(node["name"], component=instance, inputs=node["inputs"]) - for input_node in node["inputs"]: + try: + for input_node in node["inputs"]: - # Separate node and edge name, if specified - input_node_name, input_edge_name = input_node, None - if "." in input_node: - input_node_name, input_edge_name = input_node.split(".") + # Separate node and edge name, if specified + input_node_name, input_edge_name = input_node, None + if "." in input_node: + input_node_name, input_edge_name = input_node.split(".") - root_node_name = list(graph.nodes)[0] - if input_node == root_node_name: - input_edge_name = "output_1" - - elif input_node in VALID_ROOT_NODES: - raise PipelineConfigError( - f"This pipeline seems to contain two root nodes. " - f"You can only use one root node (nodes named {' or '.join(VALID_ROOT_NODES)} per pipeline." - ) - - else: - # Validate node definition and edge name - input_node_type = _get_defined_node_class(node_name=input_node_name, components=components) - input_node_edges_count = input_node_type.outgoing_edges - - if not input_edge_name: - if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs - raise PipelineConfigError( - f"Can't connect {input_node_name} to {node['name']}: " - f"{input_node_name} has {input_node_edges_count} outgoing edges. " - "Please specify the output edge explicitly (like 'filetype_classifier.output_2')." - ) + root_node_name = list(graph.nodes)[0] + if input_node == root_node_name: input_edge_name = "output_1" - if not input_edge_name.startswith("output_"): + elif input_node in VALID_ROOT_NODES: raise PipelineConfigError( - f"'{input_edge_name}' is not a valid edge name. " - "It must start with 'output_' and must contain no dots." + f"This pipeline seems to contain two root nodes. " + f"You can only use one root node (nodes named {' or '.join(VALID_ROOT_NODES)} per pipeline." ) - requested_edge_name = input_edge_name.split("_")[1] + else: + # Validate node definition and edge name + input_node_type = _get_defined_node_class(node_name=input_node_name, components=components) + component_params: Dict[str, Any] = components[input_node_name].get("params", {}) + input_node_edges_count = input_node_type._calculate_outgoing_edges(component_params=component_params) - try: - requested_edge = int(requested_edge_name) - except ValueError: - raise PipelineConfigError( - f"You must specified a numbered edge, like filetype_classifier.output_2, not {input_node}" - ) + if not input_edge_name: + if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs + raise PipelineConfigError( + f"Can't connect {input_node_name} to {node['name']}: " + f"{input_node_name} has {input_node_edges_count} outgoing edges. " + "Please specify the output edge explicitly (like 'filetype_classifier.output_2')." + ) + input_edge_name = "output_1" - if not requested_edge <= input_node_edges_count: - raise PipelineConfigError( - f"Cannot connect '{node['name']}' to '{input_node}', as {input_node_name} has only " - f"{input_node_edges_count} outgoing edge(s)." - ) + if not input_edge_name.startswith("output_"): + raise PipelineConfigError( + f"'{input_edge_name}' is not a valid edge name. " + "It must start with 'output_' and must contain no dots." + ) - graph.add_edge(input_node_name, node["name"], label=input_edge_name) + requested_edge_name = input_edge_name.split("_")[1] - # Check if adding this edge created a loop in the pipeline graph - if not nx.is_directed_acyclic_graph(graph): - graph.remove_node(node["name"]) - raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.") + try: + requested_edge = int(requested_edge_name) + except ValueError: + raise PipelineConfigError( + f"You must specified a numbered edge, like filetype_classifier.output_2, not {input_node}" + ) + + if not requested_edge <= input_node_edges_count: + raise PipelineConfigError( + f"Cannot connect '{node['name']}' to '{input_node}', as {input_node_name} has only " + f"{input_node_edges_count} outgoing edge(s)." + ) + + graph.add_edge(input_node_name, node["name"], label=input_edge_name) + + # Check if adding this edge created a loop in the pipeline graph + if not nx.is_directed_acyclic_graph(graph): + raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.") + + except PipelineConfigError: + graph.remove_node(node["name"]) + raise return graph diff --git a/rest_api/pipeline/custom_component.py b/rest_api/pipeline/custom_component.py index 479ad1bc9..9bf032b00 100644 --- a/rest_api/pipeline/custom_component.py +++ b/rest_api/pipeline/custom_component.py @@ -12,5 +12,7 @@ from haystack.nodes.base import BaseComponent class SampleComponent(BaseComponent): + outgoing_edges: int = 1 + def run(self, **kwargs): raise NotImplementedError diff --git a/test/nodes/test_filetype_classifier.py b/test/nodes/test_filetype_classifier.py index 99db2752e..b3e4f42a8 100644 --- a/test/nodes/test_filetype_classifier.py +++ b/test/nodes/test_filetype_classifier.py @@ -53,11 +53,6 @@ def test_filetype_classifier_custom_extensions(tmp_path): assert output == {"file_paths": [test_file]} -def test_filetype_classifier_too_many_custom_extensions(): - with pytest.raises(ValueError): - FileTypeClassifier(supported_types=[f"my_extension_{idx}" for idx in range(20)]) - - def test_filetype_classifier_duplicate_custom_extensions(): with pytest.raises(ValueError): FileTypeClassifier(supported_types=[f"my_extension", "my_extension"]) diff --git a/test/nodes/test_other.py b/test/nodes/test_other.py index 866962028..d5686c8db 100644 --- a/test/nodes/test_other.py +++ b/test/nodes/test_other.py @@ -29,6 +29,7 @@ def test_routedocuments_by_content_type(): def test_routedocuments_by_metafield(docs): route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"]) + assert route_documents.outgoing_edges == 3 result, _ = route_documents.run(docs) assert len(result["output_1"]) == 1 assert len(result["output_2"]) == 1 diff --git a/test/pipelines/test_pipeline.py b/test/pipelines/test_pipeline.py index def204f26..0d3e09eab 100644 --- a/test/pipelines/test_pipeline.py +++ b/test/pipelines/test_pipeline.py @@ -91,6 +91,8 @@ class ParentComponent2(BaseComponent): class ChildComponent(BaseComponent): + outgoing_edges = 0 + def __init__(self, some_key: str = None) -> None: super().__init__()