Fix validation for dynamic outgoing edges (#2850)

* fix validation for dynamic outgoing edges

* Update Documentation & Code Style

* use class outgoing_edges as fallback if no instance is provided

* implement classmethod approach

* readd comment

* fix mypy

* fix tests

* set outgoing_edges for all components

* set outgoing_edges for mocks too

* set document store outgoing_edges to 1

* set last missing outgoing_edges

* enforce BaseComponent subclasses to define outgoing_edges

* override _calculate_outgoing_edges for FileTypeClassifier

* remove superfluous test

* set rest_api's custom component's outgoing_edges

* Update docstring

Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>

* remove unnecessary else

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
tstadel authored 2022-08-04 10:27:50 +02:00, committed by GitHub
parent 40d07c2038
commit b042dd9c82
11 changed files with 119 additions and 69 deletions
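
In practice, the change establishes a small contract for custom nodes: every `BaseComponent` subclass declares an `outgoing_edges` class attribute, and nodes whose number of outputs depends on their init parameters additionally override the `_calculate_outgoing_edges()` classmethod, so that pipeline validation can work out the edge count from the YAML params before any instance exists. A minimal sketch of that contract, using a hypothetical `LanguageRouter` node and `languages` parameter that are not part of this commit:

    from typing import Any, Dict, List, Optional

    from haystack.nodes.base import BaseComponent


    class LanguageRouter(BaseComponent):
        # Mandatory: subclasses without this class attribute now fail when the class is defined.
        outgoing_edges: int = 1

        def __init__(self, languages: Optional[List[str]] = None):
            super().__init__()  # sets self.outgoing_edges via _calculate_outgoing_edges()
            self.languages = languages or ["en"]

        @classmethod
        def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
            # Called during pipeline validation (before any instance exists) and again in __init__.
            languages = component_params.get("languages") or ["en"]
            return len(languages)

        def run(self, documents):  # type: ignore
            raise NotImplementedError

        def run_batch(self, documents):  # type: ignore
            raise NotImplementedError


    # One edge per configured language, computed without instantiating the node:
    assert LanguageRouter._calculate_outgoing_edges(component_params={"languages": ["en", "de", "fr"]}) == 3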

View File

@@ -58,6 +58,8 @@ class BaseDocumentStore(BaseComponent):
Base class for implementing Document Stores.
"""
outgoing_edges: int = 1
index: Optional[str]
label_index: Optional[str]
similarity: Optional[str]

View File

@@ -27,9 +27,6 @@ def exportable_to_yaml(init_func):
@wraps(init_func)
def wrapper_exportable_to_yaml(self, *args, **kwargs):
# Call the actual __init__ function with all the arguments
init_func(self, *args, **kwargs)
# Create the configuration dictionary if it doesn't exist yet
if not self._component_config:
self._component_config = {"params": {}, "type": type(self).__name__}
@@ -47,6 +44,9 @@ def exportable_to_yaml(init_func):
for k, v in params.items():
self._component_config["params"][k] = v
# Call the actual __init__ function with all the arguments
init_func(self, *args, **kwargs)
return wrapper_exportable_to_yaml
@@ -61,7 +61,9 @@ class BaseComponent(ABC):
def __init__(self):
# a small subset of the component's parameters is sent in an event after applying filters defined in haystack.telemetry.NonPrivateParameters
send_custom_event(event=f"{type(self).__name__} initialized", payload=self._component_config.get("params", {}))
component_params = self._component_config.get("params", {})
send_custom_event(event=f"{type(self).__name__} initialized", payload=component_params)
self.outgoing_edges = self._calculate_outgoing_edges(component_params=component_params)
# __init_subclass__ is invoked when a subclass of BaseComponent is _imported_
# (not instantiated). It works approximately as a metaclass.
@@ -69,6 +71,15 @@ class BaseComponent(ABC):
super().__init_subclass__(**kwargs)
# Each component must specify the number of outgoing edges (= different outputs).
# During pipeline validation this number is compared to the requested number of output edges.
if not hasattr(cls, "outgoing_edges"):
raise ValueError(
"BaseComponent subclasses must define the outgoing_edges class attribute. "
"If this number depends on the component's parameters, make sure to override the _calculate_outgoing_edges() method. "
"See https://haystack.deepset.ai/pipeline_nodes/custom-nodes for more information."
)
# Automatically registers all the init parameters in
# an instance attribute called `_component_config`,
# used to save this component to YAML. See exportable_to_yaml()
@@ -116,13 +127,31 @@ class BaseComponent(ABC):
subclass = cls._subclasses[component_type]
return subclass
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
"""
Returns the number of outgoing edges for an instance of the component class given its component params.
In some cases (e.g. RouteDocuments) the number of outgoing edges is not static but rather depends on its component params.
Setting the number of outgoing edges inside the constructor would not be sufficient, since it is already required for validating the pipeline when there is no instance yet.
Hence, this method is responsible for calculating the number of outgoing edges
- during pipeline validation
- to set the effective instance value of `outgoing_edges`.
Override this method if the number of outgoing edges depends on the component params.
If not overridden, returns the number of outgoing edges as defined in the component class.
:param component_params: parameters to pass to the __init__() of the component.
"""
return cls.outgoing_edges
@classmethod
def _create_instance(cls, component_type: str, component_params: Dict[str, Any], name: Optional[str] = None):
"""
Returns an instance of the given subclass of BaseComponent.
:param component_type: name of the component class to load.
:param component_params: parameters to pass to the __init__() for the component.
:param component_params: parameters to pass to the __init__() of the component.
:param name: name of the component instance
"""
subclass = cls.get_subclass(component_type)

View File

@@ -1,5 +1,5 @@
import mimetypes
from typing import List, Union
from typing import Any, Dict, List, Union
import logging
from pathlib import Path
@@ -27,7 +27,7 @@ class FileTypeClassifier(BaseComponent):
Route files in an Indexing Pipeline to corresponding file converters.
"""
outgoing_edges = 10
outgoing_edges = len(DEFAULT_TYPES)
def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
"""
@@ -40,8 +40,6 @@ class FileTypeClassifier(BaseComponent):
elements will not be allowed. Lists with duplicate elements will
also be rejected.
"""
if len(supported_types) > 10:
raise ValueError("supported_types can't have more than 10 values.")
if len(set(supported_types)) != len(supported_types):
duplicates = supported_types
for item in set(supported_types):
@@ -52,6 +50,11 @@ class FileTypeClassifier(BaseComponent):
self.supported_types = supported_types
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
supported_types = component_params.get("supported_types", DEFAULT_TYPES)
return len(supported_types)
def _estimate_extension(self, file_path: Path) -> str:
"""
Return the extension found based on the contents of the given file
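
With the hard-coded limit of ten types gone, the classifier's edge count simply follows `supported_types`. A short illustration (the `ext_*` type names are made up):

    from haystack.nodes import FileTypeClassifier

    custom_types = [f"ext_{idx}" for idx in range(12)]  # more than 10 types is no longer rejected

    # Resolved from params alone during validation ...
    assert FileTypeClassifier._calculate_outgoing_edges(component_params={"supported_types": custom_types}) == 12

    # ... and set on the instance at construction time.
    node = FileTypeClassifier(supported_types=custom_types)
    assert node.outgoing_edges == 12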

View File

@@ -50,6 +50,8 @@ class PseudoLabelGenerator(BaseComponent):
"""
outgoing_edges: int = 1
def __init__(
self,
question_producer: Union[QuestionGenerator, List[Dict[str, str]]],

View File

@@ -7,6 +7,9 @@ from haystack.nodes.base import BaseComponent
class JoinNode(BaseComponent):
outgoing_edges: int = 1
def run( # type: ignore
self,
inputs: Optional[List[dict]] = None,

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple, Dict, Optional, Union
from typing import Any, List, Tuple, Dict, Optional, Union
from collections import defaultdict
from haystack.nodes.base import BaseComponent
@@ -28,19 +28,25 @@ class RouteDocuments(BaseComponent):
value of the provided list will be routed to `"output_2"`, etc.
"""
assert split_by == "content_type" or metadata_values is not None, (
"If split_by is set to the name of a metadata field, you must provide metadata_values "
"to group the documents to."
)
if split_by != "content_type" and metadata_values is None:
raise ValueError(
"If split_by is set to the name of a metadata field, you must provide metadata_values "
"to group the documents to."
)
super().__init__()
self.split_by = split_by
self.metadata_values = metadata_values
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
split_by = component_params.get("split_by", "content_type")
metadata_values = component_params.get("metadata_values", None)
# If we split list of Documents by a metadata field, number of outgoing edges might change
if split_by != "content_type" and metadata_values is not None:
self.outgoing_edges = len(metadata_values)
return len(metadata_values)
return 2
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
if self.split_by == "content_type":

View File

@@ -56,7 +56,7 @@ def get_pipeline_definition(pipeline_config: Dict[str, Any], pipeline_name: Opti
def get_component_definitions(
pipeline_config: Dict[str, Any], overwrite_with_env_variables: bool = True
) -> Dict[str, Any]:
) -> Dict[str, Dict[str, Any]]:
"""
Returns the definitions of all components from a given pipeline config.
@@ -393,7 +393,7 @@ def _init_pipeline_graph(root_node_name: Optional[str]) -> nx.DiGraph:
def _add_node_to_pipeline_graph(
graph: nx.DiGraph, components: Dict[str, Dict[str, str]], node: Dict[str, Any], instance: BaseComponent = None
graph: nx.DiGraph, components: Dict[str, Dict[str, Any]], node: Dict[str, Any], instance: BaseComponent = None
) -> nx.DiGraph:
"""
Adds a single node to the provided graph, performing all necessary validation steps.
@@ -449,64 +449,69 @@ def _add_node_to_pipeline_graph(
graph.add_node(node["name"], component=instance, inputs=node["inputs"])
for input_node in node["inputs"]:
try:
for input_node in node["inputs"]:
# Separate node and edge name, if specified
input_node_name, input_edge_name = input_node, None
if "." in input_node:
input_node_name, input_edge_name = input_node.split(".")
# Separate node and edge name, if specified
input_node_name, input_edge_name = input_node, None
if "." in input_node:
input_node_name, input_edge_name = input_node.split(".")
root_node_name = list(graph.nodes)[0]
if input_node == root_node_name:
input_edge_name = "output_1"
elif input_node in VALID_ROOT_NODES:
raise PipelineConfigError(
f"This pipeline seems to contain two root nodes. "
f"You can only use one root node (nodes named {' or '.join(VALID_ROOT_NODES)} per pipeline."
)
else:
# Validate node definition and edge name
input_node_type = _get_defined_node_class(node_name=input_node_name, components=components)
input_node_edges_count = input_node_type.outgoing_edges
if not input_edge_name:
if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs
raise PipelineConfigError(
f"Can't connect {input_node_name} to {node['name']}: "
f"{input_node_name} has {input_node_edges_count} outgoing edges. "
"Please specify the output edge explicitly (like 'filetype_classifier.output_2')."
)
root_node_name = list(graph.nodes)[0]
if input_node == root_node_name:
input_edge_name = "output_1"
if not input_edge_name.startswith("output_"):
elif input_node in VALID_ROOT_NODES:
raise PipelineConfigError(
f"'{input_edge_name}' is not a valid edge name. "
"It must start with 'output_' and must contain no dots."
f"This pipeline seems to contain two root nodes. "
f"You can only use one root node (nodes named {' or '.join(VALID_ROOT_NODES)} per pipeline."
)
requested_edge_name = input_edge_name.split("_")[1]
else:
# Validate node definition and edge name
input_node_type = _get_defined_node_class(node_name=input_node_name, components=components)
component_params: Dict[str, Any] = components[input_node_name].get("params", {})
input_node_edges_count = input_node_type._calculate_outgoing_edges(component_params=component_params)
try:
requested_edge = int(requested_edge_name)
except ValueError:
raise PipelineConfigError(
f"You must specified a numbered edge, like filetype_classifier.output_2, not {input_node}"
)
if not input_edge_name:
if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs
raise PipelineConfigError(
f"Can't connect {input_node_name} to {node['name']}: "
f"{input_node_name} has {input_node_edges_count} outgoing edges. "
"Please specify the output edge explicitly (like 'filetype_classifier.output_2')."
)
input_edge_name = "output_1"
if not requested_edge <= input_node_edges_count:
raise PipelineConfigError(
f"Cannot connect '{node['name']}' to '{input_node}', as {input_node_name} has only "
f"{input_node_edges_count} outgoing edge(s)."
)
if not input_edge_name.startswith("output_"):
raise PipelineConfigError(
f"'{input_edge_name}' is not a valid edge name. "
"It must start with 'output_' and must contain no dots."
)
graph.add_edge(input_node_name, node["name"], label=input_edge_name)
requested_edge_name = input_edge_name.split("_")[1]
# Check if adding this edge created a loop in the pipeline graph
if not nx.is_directed_acyclic_graph(graph):
graph.remove_node(node["name"])
raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.")
try:
requested_edge = int(requested_edge_name)
except ValueError:
raise PipelineConfigError(
f"You must specified a numbered edge, like filetype_classifier.output_2, not {input_node}"
)
if not requested_edge <= input_node_edges_count:
raise PipelineConfigError(
f"Cannot connect '{node['name']}' to '{input_node}', as {input_node_name} has only "
f"{input_node_edges_count} outgoing edge(s)."
)
graph.add_edge(input_node_name, node["name"], label=input_edge_name)
# Check if adding this edge created a loop in the pipeline graph
if not nx.is_directed_acyclic_graph(graph):
raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.")
except PipelineConfigError:
graph.remove_node(node["name"])
raise
return graph
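
The net effect on validation: instead of reading a static `outgoing_edges` class attribute, `_add_node_to_pipeline_graph()` now asks the component class for its edge count given the params from the pipeline config, so dynamic nodes validate correctly without being instantiated. Roughly what that amounts to for a `RouteDocuments` node (the `language` field and its values are illustrative):

    from typing import Any, Dict

    from haystack.nodes import RouteDocuments

    # Params as they would appear under the component's "params" key in the pipeline YAML.
    component_params: Dict[str, Any] = {"split_by": "language", "metadata_values": ["en", "de", "fr"]}

    edges = RouteDocuments._calculate_outgoing_edges(component_params=component_params)
    assert edges == 3  # "route_documents.output_3" is now accepted, "output_4" is still rejected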

View File

@@ -12,5 +12,7 @@ from haystack.nodes.base import BaseComponent
class SampleComponent(BaseComponent):
outgoing_edges: int = 1
def run(self, **kwargs):
raise NotImplementedError

View File

@@ -53,11 +53,6 @@ def test_filetype_classifier_custom_extensions(tmp_path):
assert output == {"file_paths": [test_file]}
def test_filetype_classifier_too_many_custom_extensions():
with pytest.raises(ValueError):
FileTypeClassifier(supported_types=[f"my_extension_{idx}" for idx in range(20)])
def test_filetype_classifier_duplicate_custom_extensions():
with pytest.raises(ValueError):
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])

View File

@@ -29,6 +29,7 @@ def test_routedocuments_by_content_type():
def test_routedocuments_by_metafield(docs):
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
assert route_documents.outgoing_edges == 3
result, _ = route_documents.run(docs)
assert len(result["output_1"]) == 1
assert len(result["output_2"]) == 1

View File

@@ -91,6 +91,8 @@ class ParentComponent2(BaseComponent):
class ChildComponent(BaseComponent):
outgoing_edges = 0
def __init__(self, some_key: str = None) -> None:
super().__init__()