mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-15 18:08:40 +00:00
Fix validation for dynamic outgoing edges (#2850)
* fix validation for dynamic outgoing edges
* Update Documentation & Code Style
* use class outgoing_edges as fallback if no instance is provided
* implement classmethod approach
* readd comment
* fix mypy
* fix tests
* set outgoing_edges for all components
* set outgoing_edges for mocks too
* set document store outgoing_edges to 1
* set last missing outgoing_edges
* enforce BaseComponent subclasses to define outgoing_edges
* override _calculate_outgoing_edges for FileTypeClassifier
* remove superfluous test
* set rest_api's custom component's outgoing_edges
* Update docstring Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
* remove unnecessary else

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
40d07c2038
commit
b042dd9c82
@ -58,6 +58,8 @@ class BaseDocumentStore(BaseComponent):
|
||||
Base class for implementing Document Stores.
|
||||
"""
|
||||
|
||||
outgoing_edges: int = 1
|
||||
|
||||
index: Optional[str]
|
||||
label_index: Optional[str]
|
||||
similarity: Optional[str]
|
||||
|
@ -27,9 +27,6 @@ def exportable_to_yaml(init_func):
|
||||
@wraps(init_func)
|
||||
def wrapper_exportable_to_yaml(self, *args, **kwargs):
|
||||
|
||||
# Call the actual __init__ function with all the arguments
|
||||
init_func(self, *args, **kwargs)
|
||||
|
||||
# Create the configuration dictionary if it doesn't exist yet
|
||||
if not self._component_config:
|
||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||
@ -47,6 +44,9 @@ def exportable_to_yaml(init_func):
|
||||
for k, v in params.items():
|
||||
self._component_config["params"][k] = v
|
||||
|
||||
# Call the actual __init__ function with all the arguments
|
||||
init_func(self, *args, **kwargs)
|
||||
|
||||
return wrapper_exportable_to_yaml
|
||||
|
||||
|
||||
@ -61,7 +61,9 @@ class BaseComponent(ABC):
|
||||
|
||||
def __init__(self):
|
||||
# a small subset of the component's parameters is sent in an event after applying filters defined in haystack.telemetry.NonPrivateParameters
|
||||
send_custom_event(event=f"{type(self).__name__} initialized", payload=self._component_config.get("params", {}))
|
||||
component_params = self._component_config.get("params", {})
|
||||
send_custom_event(event=f"{type(self).__name__} initialized", payload=component_params)
|
||||
self.outgoing_edges = self._calculate_outgoing_edges(component_params=component_params)
|
||||
|
||||
# __init_subclass__ is invoked when a subclass of BaseComponent is _imported_
|
||||
# (not instantiated). It works approximately as a metaclass.
|
||||
@ -69,6 +71,15 @@ class BaseComponent(ABC):
|
||||
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# Each component must specify the number of outgoing edges (= different outputs).
|
||||
# During pipeline validation this number is compared to the requested number of output edges.
|
||||
if not hasattr(cls, "outgoing_edges"):
|
||||
raise ValueError(
|
||||
"BaseComponent subclasses must define the outgoing_edges class attribute. "
|
||||
"If this number depends on the component's parameters, make sure to override the _calculate_outgoing_edges() method. "
|
||||
"See https://haystack.deepset.ai/pipeline_nodes/custom-nodes for more information."
|
||||
)
|
||||
|
||||
# Automatically registers all the init parameters in
|
||||
# an instance attribute called `_component_config`,
|
||||
# used to save this component to YAML. See exportable_to_yaml()
|
||||
@ -116,13 +127,31 @@ class BaseComponent(ABC):
|
||||
subclass = cls._subclasses[component_type]
|
||||
return subclass
|
||||
|
||||
@classmethod
|
||||
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||
"""
|
||||
Returns the number of outgoing edges for an instance of the component class given its component params.
|
||||
|
||||
In some cases (e.g. RouteDocuments) the number of outgoing edges is not static but rather depends on its component params.
|
||||
Setting the number of outgoing edges inside the constructor would not be sufficient, since it is already required for validating the pipeline when there is no instance yet.
|
||||
Hence, this method is responsible for calculating the number of outgoing edges
|
||||
- during pipeline validation
|
||||
- to set the effective instance value of `outgoing_edges`.
|
||||
|
||||
Override this method if the number of outgoing edges depends on the component params.
|
||||
If not overridden, returns the number of outgoing edges as defined in the component class.
|
||||
|
||||
:param component_params: parameters to pass to the __init__() of the component.
|
||||
"""
|
||||
return cls.outgoing_edges
|
||||
|
||||
@classmethod
|
||||
def _create_instance(cls, component_type: str, component_params: Dict[str, Any], name: Optional[str] = None):
|
||||
"""
|
||||
Returns an instance of the given subclass of BaseComponent.
|
||||
|
||||
:param component_type: name of the component class to load.
|
||||
:param component_params: parameters to pass to the __init__() for the component.
|
||||
:param component_params: parameters to pass to the __init__() of the component.
|
||||
:param name: name of the component instance
|
||||
"""
|
||||
subclass = cls.get_subclass(component_type)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import mimetypes
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@ -27,7 +27,7 @@ class FileTypeClassifier(BaseComponent):
|
||||
Route files in an Indexing Pipeline to corresponding file converters.
|
||||
"""
|
||||
|
||||
outgoing_edges = 10
|
||||
outgoing_edges = len(DEFAULT_TYPES)
|
||||
|
||||
def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
|
||||
"""
|
||||
@ -40,8 +40,6 @@ class FileTypeClassifier(BaseComponent):
|
||||
elements will not be allowed. Lists with duplicate elements will
|
||||
also be rejected.
|
||||
"""
|
||||
if len(supported_types) > 10:
|
||||
raise ValueError("supported_types can't have more than 10 values.")
|
||||
if len(set(supported_types)) != len(supported_types):
|
||||
duplicates = supported_types
|
||||
for item in set(supported_types):
|
||||
@ -52,6 +50,11 @@ class FileTypeClassifier(BaseComponent):
|
||||
|
||||
self.supported_types = supported_types
|
||||
|
||||
@classmethod
|
||||
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||
supported_types = component_params.get("supported_types", DEFAULT_TYPES)
|
||||
return len(supported_types)
|
||||
|
||||
def _estimate_extension(self, file_path: Path) -> str:
|
||||
"""
|
||||
Return the extension found based on the contents of the given file
|
||||
|
@ -50,6 +50,8 @@ class PseudoLabelGenerator(BaseComponent):
|
||||
|
||||
"""
|
||||
|
||||
outgoing_edges: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
|
||||
|
@ -7,6 +7,9 @@ from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
class JoinNode(BaseComponent):
|
||||
|
||||
outgoing_edges: int = 1
|
||||
|
||||
def run( # type: ignore
|
||||
self,
|
||||
inputs: Optional[List[dict]] = None,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple, Dict, Optional, Union
|
||||
from typing import Any, List, Tuple, Dict, Optional, Union
|
||||
from collections import defaultdict
|
||||
|
||||
from haystack.nodes.base import BaseComponent
|
||||
@ -28,19 +28,25 @@ class RouteDocuments(BaseComponent):
|
||||
value of the provided list will be routed to `"output_2"`, etc.
|
||||
"""
|
||||
|
||||
assert split_by == "content_type" or metadata_values is not None, (
|
||||
"If split_by is set to the name of a metadata field, you must provide metadata_values "
|
||||
"to group the documents to."
|
||||
)
|
||||
if split_by != "content_type" and metadata_values is None:
|
||||
raise ValueError(
|
||||
"If split_by is set to the name of a metadata field, you must provide metadata_values "
|
||||
"to group the documents to."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.split_by = split_by
|
||||
self.metadata_values = metadata_values
|
||||
|
||||
@classmethod
|
||||
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||
split_by = component_params.get("split_by", "content_type")
|
||||
metadata_values = component_params.get("metadata_values", None)
|
||||
# If we split list of Documents by a metadata field, number of outgoing edges might change
|
||||
if split_by != "content_type" and metadata_values is not None:
|
||||
self.outgoing_edges = len(metadata_values)
|
||||
return len(metadata_values)
|
||||
return 2
|
||||
|
||||
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
|
||||
if self.split_by == "content_type":
|
||||
|
@ -56,7 +56,7 @@ def get_pipeline_definition(pipeline_config: Dict[str, Any], pipeline_name: Opti
|
||||
|
||||
def get_component_definitions(
|
||||
pipeline_config: Dict[str, Any], overwrite_with_env_variables: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Returns the definitions of all components from a given pipeline config.
|
||||
|
||||
@ -393,7 +393,7 @@ def _init_pipeline_graph(root_node_name: Optional[str]) -> nx.DiGraph:
|
||||
|
||||
|
||||
def _add_node_to_pipeline_graph(
|
||||
graph: nx.DiGraph, components: Dict[str, Dict[str, str]], node: Dict[str, Any], instance: BaseComponent = None
|
||||
graph: nx.DiGraph, components: Dict[str, Dict[str, Any]], node: Dict[str, Any], instance: BaseComponent = None
|
||||
) -> nx.DiGraph:
|
||||
"""
|
||||
Adds a single node to the provided graph, performing all necessary validation steps.
|
||||
@ -449,64 +449,69 @@ def _add_node_to_pipeline_graph(
|
||||
|
||||
graph.add_node(node["name"], component=instance, inputs=node["inputs"])
|
||||
|
||||
for input_node in node["inputs"]:
|
||||
try:
|
||||
for input_node in node["inputs"]:
|
||||
|
||||
# Separate node and edge name, if specified
|
||||
input_node_name, input_edge_name = input_node, None
|
||||
if "." in input_node:
|
||||
input_node_name, input_edge_name = input_node.split(".")
|
||||
# Separate node and edge name, if specified
|
||||
input_node_name, input_edge_name = input_node, None
|
||||
if "." in input_node:
|
||||
input_node_name, input_edge_name = input_node.split(".")
|
||||
|
||||
root_node_name = list(graph.nodes)[0]
|
||||
if input_node == root_node_name:
|
||||
input_edge_name = "output_1"
|
||||
|
||||
elif input_node in VALID_ROOT_NODES:
|
||||
raise PipelineConfigError(
|
||||
f"This pipeline seems to contain two root nodes. "
|
||||
f"You can only use one root node (nodes named {' or '.join(VALID_ROOT_NODES)} per pipeline."
|
||||
)
|
||||
|
||||
else:
|
||||
# Validate node definition and edge name
|
||||
input_node_type = _get_defined_node_class(node_name=input_node_name, components=components)
|
||||
input_node_edges_count = input_node_type.outgoing_edges
|
||||
|
||||
if not input_edge_name:
|
||||
if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs
|
||||
raise PipelineConfigError(
|
||||
f"Can't connect {input_node_name} to {node['name']}: "
|
||||
f"{input_node_name} has {input_node_edges_count} outgoing edges. "
|
||||
"Please specify the output edge explicitly (like 'filetype_classifier.output_2')."
|
||||
)
|
||||
root_node_name = list(graph.nodes)[0]
|
||||
if input_node == root_node_name:
|
||||
input_edge_name = "output_1"
|
||||
|
||||
if not input_edge_name.startswith("output_"):
|
||||
elif input_node in VALID_ROOT_NODES:
|
||||
raise PipelineConfigError(
|
||||
f"'{input_edge_name}' is not a valid edge name. "
|
||||
"It must start with 'output_' and must contain no dots."
|
||||
f"This pipeline seems to contain two root nodes. "
|
||||
f"You can only use one root node (nodes named {' or '.join(VALID_ROOT_NODES)} per pipeline."
|
||||
)
|
||||
|
||||
requested_edge_name = input_edge_name.split("_")[1]
|
||||
else:
|
||||
# Validate node definition and edge name
|
||||
input_node_type = _get_defined_node_class(node_name=input_node_name, components=components)
|
||||
component_params: Dict[str, Any] = components[input_node_name].get("params", {})
|
||||
input_node_edges_count = input_node_type._calculate_outgoing_edges(component_params=component_params)
|
||||
|
||||
try:
|
||||
requested_edge = int(requested_edge_name)
|
||||
except ValueError:
|
||||
raise PipelineConfigError(
|
||||
f"You must specified a numbered edge, like filetype_classifier.output_2, not {input_node}"
|
||||
)
|
||||
if not input_edge_name:
|
||||
if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs
|
||||
raise PipelineConfigError(
|
||||
f"Can't connect {input_node_name} to {node['name']}: "
|
||||
f"{input_node_name} has {input_node_edges_count} outgoing edges. "
|
||||
"Please specify the output edge explicitly (like 'filetype_classifier.output_2')."
|
||||
)
|
||||
input_edge_name = "output_1"
|
||||
|
||||
if not requested_edge <= input_node_edges_count:
|
||||
raise PipelineConfigError(
|
||||
f"Cannot connect '{node['name']}' to '{input_node}', as {input_node_name} has only "
|
||||
f"{input_node_edges_count} outgoing edge(s)."
|
||||
)
|
||||
if not input_edge_name.startswith("output_"):
|
||||
raise PipelineConfigError(
|
||||
f"'{input_edge_name}' is not a valid edge name. "
|
||||
"It must start with 'output_' and must contain no dots."
|
||||
)
|
||||
|
||||
graph.add_edge(input_node_name, node["name"], label=input_edge_name)
|
||||
requested_edge_name = input_edge_name.split("_")[1]
|
||||
|
||||
# Check if adding this edge created a loop in the pipeline graph
|
||||
if not nx.is_directed_acyclic_graph(graph):
|
||||
graph.remove_node(node["name"])
|
||||
raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.")
|
||||
try:
|
||||
requested_edge = int(requested_edge_name)
|
||||
except ValueError:
|
||||
raise PipelineConfigError(
|
||||
f"You must specified a numbered edge, like filetype_classifier.output_2, not {input_node}"
|
||||
)
|
||||
|
||||
if not requested_edge <= input_node_edges_count:
|
||||
raise PipelineConfigError(
|
||||
f"Cannot connect '{node['name']}' to '{input_node}', as {input_node_name} has only "
|
||||
f"{input_node_edges_count} outgoing edge(s)."
|
||||
)
|
||||
|
||||
graph.add_edge(input_node_name, node["name"], label=input_edge_name)
|
||||
|
||||
# Check if adding this edge created a loop in the pipeline graph
|
||||
if not nx.is_directed_acyclic_graph(graph):
|
||||
raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.")
|
||||
|
||||
except PipelineConfigError:
|
||||
graph.remove_node(node["name"])
|
||||
raise
|
||||
|
||||
return graph
|
||||
|
||||
|
@ -12,5 +12,7 @@ from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
class SampleComponent(BaseComponent):
|
||||
outgoing_edges: int = 1
|
||||
|
||||
def run(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
@ -53,11 +53,6 @@ def test_filetype_classifier_custom_extensions(tmp_path):
|
||||
assert output == {"file_paths": [test_file]}
|
||||
|
||||
|
||||
def test_filetype_classifier_too_many_custom_extensions():
|
||||
with pytest.raises(ValueError):
|
||||
FileTypeClassifier(supported_types=[f"my_extension_{idx}" for idx in range(20)])
|
||||
|
||||
|
||||
def test_filetype_classifier_duplicate_custom_extensions():
|
||||
with pytest.raises(ValueError):
|
||||
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])
|
||||
|
@ -29,6 +29,7 @@ def test_routedocuments_by_content_type():
|
||||
|
||||
def test_routedocuments_by_metafield(docs):
|
||||
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
|
||||
assert route_documents.outgoing_edges == 3
|
||||
result, _ = route_documents.run(docs)
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
|
@ -91,6 +91,8 @@ class ParentComponent2(BaseComponent):
|
||||
|
||||
|
||||
class ChildComponent(BaseComponent):
|
||||
outgoing_edges = 0
|
||||
|
||||
def __init__(self, some_key: str = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user