mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-20 12:28:43 +00:00
Fix validation for dynamic outgoing edges (#2850)
* fix validation for dynamic outgoing edges * Update Documentation & Code Style * use class outgoing_edges as fallback if no instance is provided * implement classmethod approach * readd comment * fix mypy * fix tests * set outgoing_edges for all components * set outgoing_edges for mocks too * set document store outgoing_edges to 1 * set last missing outgoing_edges * enforce BaseComponent subclasses to define outgoing_edges * override _calculate_outgoing_edges for FileTypeClassifier * remove superfluous test * set rest_api's custom component's outgoing_edges * Update docstring Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai> * remove unnecessary else Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
40d07c2038
commit
b042dd9c82
@ -58,6 +58,8 @@ class BaseDocumentStore(BaseComponent):
|
|||||||
Base class for implementing Document Stores.
|
Base class for implementing Document Stores.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
outgoing_edges: int = 1
|
||||||
|
|
||||||
index: Optional[str]
|
index: Optional[str]
|
||||||
label_index: Optional[str]
|
label_index: Optional[str]
|
||||||
similarity: Optional[str]
|
similarity: Optional[str]
|
||||||
|
@ -27,9 +27,6 @@ def exportable_to_yaml(init_func):
|
|||||||
@wraps(init_func)
|
@wraps(init_func)
|
||||||
def wrapper_exportable_to_yaml(self, *args, **kwargs):
|
def wrapper_exportable_to_yaml(self, *args, **kwargs):
|
||||||
|
|
||||||
# Call the actuall __init__ function with all the arguments
|
|
||||||
init_func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
# Create the configuration dictionary if it doesn't exist yet
|
# Create the configuration dictionary if it doesn't exist yet
|
||||||
if not self._component_config:
|
if not self._component_config:
|
||||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||||
@ -47,6 +44,9 @@ def exportable_to_yaml(init_func):
|
|||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
self._component_config["params"][k] = v
|
self._component_config["params"][k] = v
|
||||||
|
|
||||||
|
# Call the actuall __init__ function with all the arguments
|
||||||
|
init_func(self, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper_exportable_to_yaml
|
return wrapper_exportable_to_yaml
|
||||||
|
|
||||||
|
|
||||||
@ -61,7 +61,9 @@ class BaseComponent(ABC):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# a small subset of the component's parameters is sent in an event after applying filters defined in haystack.telemetry.NonPrivateParameters
|
# a small subset of the component's parameters is sent in an event after applying filters defined in haystack.telemetry.NonPrivateParameters
|
||||||
send_custom_event(event=f"{type(self).__name__} initialized", payload=self._component_config.get("params", {}))
|
component_params = self._component_config.get("params", {})
|
||||||
|
send_custom_event(event=f"{type(self).__name__} initialized", payload=component_params)
|
||||||
|
self.outgoing_edges = self._calculate_outgoing_edges(component_params=component_params)
|
||||||
|
|
||||||
# __init_subclass__ is invoked when a subclass of BaseComponent is _imported_
|
# __init_subclass__ is invoked when a subclass of BaseComponent is _imported_
|
||||||
# (not instantiated). It works approximately as a metaclass.
|
# (not instantiated). It works approximately as a metaclass.
|
||||||
@ -69,6 +71,15 @@ class BaseComponent(ABC):
|
|||||||
|
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
# Each component must specify the number of outgoing edges (= different outputs).
|
||||||
|
# During pipeline validation this number is compared to the requested number of output edges.
|
||||||
|
if not hasattr(cls, "outgoing_edges"):
|
||||||
|
raise ValueError(
|
||||||
|
"BaseComponent subclasses must define the outgoing_edges class attribute. "
|
||||||
|
"If this number depends on the component's parameters, make sure to override the _calculate_outgoing_edges() method. "
|
||||||
|
"See https://haystack.deepset.ai/pipeline_nodes/custom-nodes for more information."
|
||||||
|
)
|
||||||
|
|
||||||
# Automatically registers all the init parameters in
|
# Automatically registers all the init parameters in
|
||||||
# an instance attribute called `_component_config`,
|
# an instance attribute called `_component_config`,
|
||||||
# used to save this component to YAML. See exportable_to_yaml()
|
# used to save this component to YAML. See exportable_to_yaml()
|
||||||
@ -116,13 +127,31 @@ class BaseComponent(ABC):
|
|||||||
subclass = cls._subclasses[component_type]
|
subclass = cls._subclasses[component_type]
|
||||||
return subclass
|
return subclass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||||
|
"""
|
||||||
|
Returns the number of outgoing edges for an instance of the component class given its component params.
|
||||||
|
|
||||||
|
In some cases (e.g. RouteDocuments) the number of outgoing edges is not static but rather depends on its component params.
|
||||||
|
Setting the number of outgoing edges inside the constructor would not be sufficient, since it is already required for validating the pipeline when there is no instance yet.
|
||||||
|
Hence, this method is responsible for calculating the number of outgoing edges
|
||||||
|
- during pipeline validation
|
||||||
|
- to set the effective instance value of `outgoing_edges`.
|
||||||
|
|
||||||
|
Override this method if the number of outgoing edges depends on the component params.
|
||||||
|
If not overridden, returns the number of outgoing edges as defined in the component class.
|
||||||
|
|
||||||
|
:param component_params: parameters to pass to the __init__() of the component.
|
||||||
|
"""
|
||||||
|
return cls.outgoing_edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _create_instance(cls, component_type: str, component_params: Dict[str, Any], name: Optional[str] = None):
|
def _create_instance(cls, component_type: str, component_params: Dict[str, Any], name: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Returns an instance of the given subclass of BaseComponent.
|
Returns an instance of the given subclass of BaseComponent.
|
||||||
|
|
||||||
:param component_type: name of the component class to load.
|
:param component_type: name of the component class to load.
|
||||||
:param component_params: parameters to pass to the __init__() for the component.
|
:param component_params: parameters to pass to the __init__() of the component.
|
||||||
:param name: name of the component instance
|
:param name: name of the component instance
|
||||||
"""
|
"""
|
||||||
subclass = cls.get_subclass(component_type)
|
subclass = cls.get_subclass(component_type)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
from typing import List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -27,7 +27,7 @@ class FileTypeClassifier(BaseComponent):
|
|||||||
Route files in an Indexing Pipeline to corresponding file converters.
|
Route files in an Indexing Pipeline to corresponding file converters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
outgoing_edges = 10
|
outgoing_edges = len(DEFAULT_TYPES)
|
||||||
|
|
||||||
def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
|
def __init__(self, supported_types: List[str] = DEFAULT_TYPES):
|
||||||
"""
|
"""
|
||||||
@ -40,8 +40,6 @@ class FileTypeClassifier(BaseComponent):
|
|||||||
elements will not be allowed. Lists with duplicate elements will
|
elements will not be allowed. Lists with duplicate elements will
|
||||||
also be rejected.
|
also be rejected.
|
||||||
"""
|
"""
|
||||||
if len(supported_types) > 10:
|
|
||||||
raise ValueError("supported_types can't have more than 10 values.")
|
|
||||||
if len(set(supported_types)) != len(supported_types):
|
if len(set(supported_types)) != len(supported_types):
|
||||||
duplicates = supported_types
|
duplicates = supported_types
|
||||||
for item in set(supported_types):
|
for item in set(supported_types):
|
||||||
@ -52,6 +50,11 @@ class FileTypeClassifier(BaseComponent):
|
|||||||
|
|
||||||
self.supported_types = supported_types
|
self.supported_types = supported_types
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||||
|
supported_types = component_params.get("supported_types", DEFAULT_TYPES)
|
||||||
|
return len(supported_types)
|
||||||
|
|
||||||
def _estimate_extension(self, file_path: Path) -> str:
|
def _estimate_extension(self, file_path: Path) -> str:
|
||||||
"""
|
"""
|
||||||
Return the extension found based on the contents of the given file
|
Return the extension found based on the contents of the given file
|
||||||
|
@ -50,6 +50,8 @@ class PseudoLabelGenerator(BaseComponent):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
outgoing_edges: int = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
|
question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
|
||||||
|
@ -7,6 +7,9 @@ from haystack.nodes.base import BaseComponent
|
|||||||
|
|
||||||
|
|
||||||
class JoinNode(BaseComponent):
|
class JoinNode(BaseComponent):
|
||||||
|
|
||||||
|
outgoing_edges: int = 1
|
||||||
|
|
||||||
def run( # type: ignore
|
def run( # type: ignore
|
||||||
self,
|
self,
|
||||||
inputs: Optional[List[dict]] = None,
|
inputs: Optional[List[dict]] = None,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Tuple, Dict, Optional, Union
|
from typing import Any, List, Tuple, Dict, Optional, Union
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.base import BaseComponent
|
||||||
@ -28,7 +28,8 @@ class RouteDocuments(BaseComponent):
|
|||||||
value of the provided list will be routed to `"output_2"`, etc.
|
value of the provided list will be routed to `"output_2"`, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert split_by == "content_type" or metadata_values is not None, (
|
if split_by != "content_type" and metadata_values is None:
|
||||||
|
raise ValueError(
|
||||||
"If split_by is set to the name of a metadata field, you must provide metadata_values "
|
"If split_by is set to the name of a metadata field, you must provide metadata_values "
|
||||||
"to group the documents to."
|
"to group the documents to."
|
||||||
)
|
)
|
||||||
@ -38,9 +39,14 @@ class RouteDocuments(BaseComponent):
|
|||||||
self.split_by = split_by
|
self.split_by = split_by
|
||||||
self.metadata_values = metadata_values
|
self.metadata_values = metadata_values
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
|
||||||
|
split_by = component_params.get("split_by", "content_type")
|
||||||
|
metadata_values = component_params.get("metadata_values", None)
|
||||||
# If we split list of Documents by a metadata field, number of outgoing edges might change
|
# If we split list of Documents by a metadata field, number of outgoing edges might change
|
||||||
if split_by != "content_type" and metadata_values is not None:
|
if split_by != "content_type" and metadata_values is not None:
|
||||||
self.outgoing_edges = len(metadata_values)
|
return len(metadata_values)
|
||||||
|
return 2
|
||||||
|
|
||||||
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
|
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
|
||||||
if self.split_by == "content_type":
|
if self.split_by == "content_type":
|
||||||
|
@ -56,7 +56,7 @@ def get_pipeline_definition(pipeline_config: Dict[str, Any], pipeline_name: Opti
|
|||||||
|
|
||||||
def get_component_definitions(
|
def get_component_definitions(
|
||||||
pipeline_config: Dict[str, Any], overwrite_with_env_variables: bool = True
|
pipeline_config: Dict[str, Any], overwrite_with_env_variables: bool = True
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Returns the definitions of all components from a given pipeline config.
|
Returns the definitions of all components from a given pipeline config.
|
||||||
|
|
||||||
@ -393,7 +393,7 @@ def _init_pipeline_graph(root_node_name: Optional[str]) -> nx.DiGraph:
|
|||||||
|
|
||||||
|
|
||||||
def _add_node_to_pipeline_graph(
|
def _add_node_to_pipeline_graph(
|
||||||
graph: nx.DiGraph, components: Dict[str, Dict[str, str]], node: Dict[str, Any], instance: BaseComponent = None
|
graph: nx.DiGraph, components: Dict[str, Dict[str, Any]], node: Dict[str, Any], instance: BaseComponent = None
|
||||||
) -> nx.DiGraph:
|
) -> nx.DiGraph:
|
||||||
"""
|
"""
|
||||||
Adds a single node to the provided graph, performing all necessary validation steps.
|
Adds a single node to the provided graph, performing all necessary validation steps.
|
||||||
@ -449,6 +449,7 @@ def _add_node_to_pipeline_graph(
|
|||||||
|
|
||||||
graph.add_node(node["name"], component=instance, inputs=node["inputs"])
|
graph.add_node(node["name"], component=instance, inputs=node["inputs"])
|
||||||
|
|
||||||
|
try:
|
||||||
for input_node in node["inputs"]:
|
for input_node in node["inputs"]:
|
||||||
|
|
||||||
# Separate node and edge name, if specified
|
# Separate node and edge name, if specified
|
||||||
@ -469,7 +470,8 @@ def _add_node_to_pipeline_graph(
|
|||||||
else:
|
else:
|
||||||
# Validate node definition and edge name
|
# Validate node definition and edge name
|
||||||
input_node_type = _get_defined_node_class(node_name=input_node_name, components=components)
|
input_node_type = _get_defined_node_class(node_name=input_node_name, components=components)
|
||||||
input_node_edges_count = input_node_type.outgoing_edges
|
component_params: Dict[str, Any] = components[input_node_name].get("params", {})
|
||||||
|
input_node_edges_count = input_node_type._calculate_outgoing_edges(component_params=component_params)
|
||||||
|
|
||||||
if not input_edge_name:
|
if not input_edge_name:
|
||||||
if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs
|
if input_node_edges_count != 1: # Edge was not specified, but input node has many outputs
|
||||||
@ -505,9 +507,12 @@ def _add_node_to_pipeline_graph(
|
|||||||
|
|
||||||
# Check if adding this edge created a loop in the pipeline graph
|
# Check if adding this edge created a loop in the pipeline graph
|
||||||
if not nx.is_directed_acyclic_graph(graph):
|
if not nx.is_directed_acyclic_graph(graph):
|
||||||
graph.remove_node(node["name"])
|
|
||||||
raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.")
|
raise PipelineConfigError(f"Cannot add '{node['name']}': it will create a loop in the pipeline.")
|
||||||
|
|
||||||
|
except PipelineConfigError:
|
||||||
|
graph.remove_node(node["name"])
|
||||||
|
raise
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,5 +12,7 @@ from haystack.nodes.base import BaseComponent
|
|||||||
|
|
||||||
|
|
||||||
class SampleComponent(BaseComponent):
|
class SampleComponent(BaseComponent):
|
||||||
|
outgoing_edges: int = 1
|
||||||
|
|
||||||
def run(self, **kwargs):
|
def run(self, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -53,11 +53,6 @@ def test_filetype_classifier_custom_extensions(tmp_path):
|
|||||||
assert output == {"file_paths": [test_file]}
|
assert output == {"file_paths": [test_file]}
|
||||||
|
|
||||||
|
|
||||||
def test_filetype_classifier_too_many_custom_extensions():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
FileTypeClassifier(supported_types=[f"my_extension_{idx}" for idx in range(20)])
|
|
||||||
|
|
||||||
|
|
||||||
def test_filetype_classifier_duplicate_custom_extensions():
|
def test_filetype_classifier_duplicate_custom_extensions():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])
|
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])
|
||||||
|
@ -29,6 +29,7 @@ def test_routedocuments_by_content_type():
|
|||||||
|
|
||||||
def test_routedocuments_by_metafield(docs):
|
def test_routedocuments_by_metafield(docs):
|
||||||
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
|
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
|
||||||
|
assert route_documents.outgoing_edges == 3
|
||||||
result, _ = route_documents.run(docs)
|
result, _ = route_documents.run(docs)
|
||||||
assert len(result["output_1"]) == 1
|
assert len(result["output_1"]) == 1
|
||||||
assert len(result["output_2"]) == 1
|
assert len(result["output_2"]) == 1
|
||||||
|
@ -91,6 +91,8 @@ class ParentComponent2(BaseComponent):
|
|||||||
|
|
||||||
|
|
||||||
class ChildComponent(BaseComponent):
|
class ChildComponent(BaseComponent):
|
||||||
|
outgoing_edges = 0
|
||||||
|
|
||||||
def __init__(self, some_key: str = None) -> None:
|
def __init__(self, some_key: str = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user