feat: Add Type Validation parameter for Pipeline Connections (#8875)

* Starting to refactor type util tests to be more systematic

* refactoring

* Expand tests

* Update to type utils

* Add missing subclass check

* Expand and refactor tests, introduce type_validation Literal

* More test refactoring

* Test refactoring, adding type validation variable to pipeline base

* Update relaxed version of type checking to pass all newly added tests

* trim whitespace

* Add tests

* cleanup

* Updates docstrings

* Add reno

* docs

* Fix mypy and add docstrings

* Changes based on advice from Tobi

* Remove unused imports

* Doc strings

* Add connection type validation to to_dict and from_dict

* Update tests

* Fix test

* Also save connection_type_validation at global pipeline level

* Fix tests

* Remove connection type validation from the connect level, only keep at pipeline level

* Formatting

* Fix tests

* formatting
This commit is contained in:
Sebastian Husch Lee 2025-03-03 16:00:22 +01:00 committed by GitHub
parent 00fe4d157d
commit 296e31c182
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 97 additions and 49 deletions

View File

@ -68,7 +68,7 @@ Some examples of what you can do with Haystack:
> [!TIP] > [!TIP]
> >
> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/). > Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/).
## 🆕 deepset Studio: Your Development Environment for Haystack ## 🆕 deepset Studio: Your Development Environment for Haystack

View File

@ -68,7 +68,12 @@ class PipelineBase:
Builds a graph of components and orchestrates their execution according to the execution graph. Builds a graph of components and orchestrates their execution according to the execution graph.
""" """
def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_component: int = 100): def __init__(
self,
metadata: Optional[Dict[str, Any]] = None,
max_runs_per_component: int = 100,
connection_type_validation: bool = True,
):
""" """
Creates the Pipeline. Creates the Pipeline.
@ -79,12 +84,15 @@ class PipelineBase:
How many times the `Pipeline` can run the same Component. How many times the `Pipeline` can run the same Component.
If this limit is reached a `PipelineMaxComponentRuns` exception is raised. If this limit is reached a `PipelineMaxComponentRuns` exception is raised.
If not set defaults to 100 runs per Component. If not set defaults to 100 runs per Component.
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
Defaults to True.
""" """
self._telemetry_runs = 0 self._telemetry_runs = 0
self._last_telemetry_sent: Optional[datetime] = None self._last_telemetry_sent: Optional[datetime] = None
self.metadata = metadata or {} self.metadata = metadata or {}
self.graph = networkx.MultiDiGraph() self.graph = networkx.MultiDiGraph()
self._max_runs_per_component = max_runs_per_component self._max_runs_per_component = max_runs_per_component
self._connection_type_validation = connection_type_validation
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
""" """
@ -142,6 +150,7 @@ class PipelineBase:
"max_runs_per_component": self._max_runs_per_component, "max_runs_per_component": self._max_runs_per_component,
"components": components, "components": components,
"connections": connections, "connections": connections,
"connection_type_validation": self._connection_type_validation,
} }
@classmethod @classmethod
@ -164,7 +173,12 @@ class PipelineBase:
data_copy = deepcopy(data) # to prevent modification of original data data_copy = deepcopy(data) # to prevent modification of original data
metadata = data_copy.get("metadata", {}) metadata = data_copy.get("metadata", {})
max_runs_per_component = data_copy.get("max_runs_per_component", 100) max_runs_per_component = data_copy.get("max_runs_per_component", 100)
pipe = cls(metadata=metadata, max_runs_per_component=max_runs_per_component) connection_type_validation = data_copy.get("connection_type_validation", True)
pipe = cls(
metadata=metadata,
max_runs_per_component=max_runs_per_component,
connection_type_validation=connection_type_validation,
)
components_to_reuse = kwargs.get("components", {}) components_to_reuse = kwargs.get("components", {})
for name, component_data in data_copy.get("components", {}).items(): for name, component_data in data_copy.get("components", {}).items():
if name in components_to_reuse: if name in components_to_reuse:
@ -402,6 +416,8 @@ class PipelineBase:
:param receiver: :param receiver:
The component that receives the value. This can be either just a component name or can be The component that receives the value. This can be either just a component name or can be
in the format `component_name.connection_name` if the component has multiple inputs. in the format `component_name.connection_name` if the component has multiple inputs.
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
Defaults to the value set in the pipeline.
:returns: :returns:
The Pipeline instance. The Pipeline instance.
@ -418,48 +434,51 @@ class PipelineBase:
# Get the nodes data. # Get the nodes data.
try: try:
from_sockets = self.graph.nodes[sender_component_name]["output_sockets"] sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
except KeyError as exc: except KeyError as exc:
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
try: try:
to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"] receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
except KeyError as exc: except KeyError as exc:
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
# If the name of either socket is given, get the socket # If the name of either socket is given, get the socket
sender_socket: Optional[OutputSocket] = None sender_socket: Optional[OutputSocket] = None
if sender_socket_name: if sender_socket_name:
sender_socket = from_sockets.get(sender_socket_name) sender_socket = sender_sockets.get(sender_socket_name)
if not sender_socket: if not sender_socket:
raise PipelineConnectError( raise PipelineConnectError(
f"'{sender} does not exist. " f"'{sender} does not exist. "
f"Output connections of {sender_component_name} are: " f"Output connections of {sender_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()]) + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
) )
receiver_socket: Optional[InputSocket] = None receiver_socket: Optional[InputSocket] = None
if receiver_socket_name: if receiver_socket_name:
receiver_socket = to_sockets.get(receiver_socket_name) receiver_socket = receiver_sockets.get(receiver_socket_name)
if not receiver_socket: if not receiver_socket:
raise PipelineConnectError( raise PipelineConnectError(
f"'{receiver} does not exist. " f"'{receiver} does not exist. "
f"Input connections of {receiver_component_name} are: " f"Input connections of {receiver_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()]) + ", ".join(
[f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
)
) )
# Look for a matching connection among the possible ones. # Look for a matching connection among the possible ones.
# Note that if there is more than one possible connection but two sockets match by name, they're paired. # Note that if there is more than one possible connection but two sockets match by name, they're paired.
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values()) sender_socket_candidates: List[OutputSocket] = (
[sender_socket] if sender_socket else list(sender_sockets.values())
)
receiver_socket_candidates: List[InputSocket] = ( receiver_socket_candidates: List[InputSocket] = (
[receiver_socket] if receiver_socket else list(to_sockets.values()) [receiver_socket] if receiver_socket else list(receiver_sockets.values())
) )
# Find all possible connections between these two components # Find all possible connections between these two components
possible_connections = [ possible_connections = []
(sender_sock, receiver_sock) for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates) if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
if _types_are_compatible(sender_sock.type, receiver_sock.type) possible_connections.append((sender_sock, receiver_sock))
]
# We need this status for error messages, since we might need it in multiple places we calculate it here # We need this status for error messages, since we might need it in multiple places we calculate it here
status = _connections_status( status = _connections_status(
@ -860,7 +879,7 @@ class PipelineBase:
def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]: def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
""" """
Utility function to find all Components that receive input form `component_name`. Utility function to find all Components that receive input from `component_name`.
:param component_name: :param component_name:
Name of the sender Component Name of the sender Component
@ -1179,7 +1198,7 @@ class PipelineBase:
def _connections_status( def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
): ) -> str:
""" """
Lists the status of the sockets, for error messages. Lists the status of the sockets, for error messages.
""" """

View File

@ -2,29 +2,42 @@
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Union, get_args, get_origin from typing import Any, TypeVar, Union, get_args, get_origin
from haystack import logging from haystack import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T")
def _is_optional(type_: type) -> bool:
"""
Utility method that returns whether a type is Optional.
"""
return get_origin(type_) is Union and type(None) in get_args(type_)
def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements def _types_are_compatible(sender, receiver, type_validation: bool = True) -> bool:
""" """
Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections. Determines if two types are compatible based on the specified validation mode.
:param sender: The sender type.
:param receiver: The receiver type.
:param type_validation: Whether to perform strict type validation.
:return: True if the types are compatible, False otherwise.
"""
if type_validation:
return _strict_types_are_compatible(sender, receiver)
else:
return True
def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
"""
Checks whether the sender type is equal to or a subtype of the receiver type under strict validation.
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same. with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
Consider simplifying the typing of your components if you observe unexpected errors during component connection. Consider simplifying the typing of your components if you observe unexpected errors during component connection.
:param sender: The sender type.
:param receiver: The receiver type.
:return: True if the sender type is strictly compatible with the receiver type, False otherwise.
""" """
if sender == receiver or receiver is Any: if sender == receiver or receiver is Any:
return True return True
@ -42,17 +55,19 @@ def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-
receiver_origin = get_origin(receiver) receiver_origin = get_origin(receiver)
if sender_origin is not Union and receiver_origin is Union: if sender_origin is not Union and receiver_origin is Union:
return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver)) return any(_strict_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
if not sender_origin or not receiver_origin or sender_origin != receiver_origin: # Both must have origins and they must be equal
if not (sender_origin and receiver_origin and sender_origin == receiver_origin):
return False return False
# Compare generic type arguments
sender_args = get_args(sender) sender_args = get_args(sender)
receiver_args = get_args(receiver) receiver_args = get_args(receiver)
if len(sender_args) > len(receiver_args): if len(sender_args) > len(receiver_args):
return False return False
return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args)) return all(_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
def _type_name(type_): def _type_name(type_):

View File

@ -0,0 +1,5 @@
---
features:
- |
We've introduced a new connection_type_validation parameter to control type compatibility checks in pipeline connections.
Set it to True (the default) to enforce type compatibility checks, or to False to skip type checking entirely so that all connections are allowed.

View File

@ -151,6 +151,7 @@ class TestOpenAPIConnector:
assert pipeline_dict == { assert pipeline_dict == {
"metadata": {}, "metadata": {},
"max_runs_per_component": 100, "max_runs_per_component": 100,
"connection_type_validation": True,
"components": { "components": {
"api": { "api": {
"type": "haystack.components.connectors.openapi.OpenAPIConnector", "type": "haystack.components.connectors.openapi.OpenAPIConnector",

View File

@ -218,6 +218,7 @@ class TestOpenAPIServiceConnector:
assert pipeline_dict == { assert pipeline_dict == {
"metadata": {}, "metadata": {},
"max_runs_per_component": 100, "max_runs_per_component": 100,
"connection_type_validation": True,
"components": { "components": {
"connector": { "connector": {
"type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector", "type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector",

View File

@ -171,6 +171,7 @@ class TestAzureOpenAIChatGenerator:
assert p.to_dict() == { assert p.to_dict() == {
"metadata": {}, "metadata": {},
"max_runs_per_component": 100, "max_runs_per_component": 100,
"connection_type_validation": True,
"components": { "components": {
"generator": { "generator": {
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator", "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",

View File

@ -270,6 +270,7 @@ class TestHuggingFaceAPIChatGenerator:
assert pipeline_dict == { assert pipeline_dict == {
"metadata": {}, "metadata": {},
"max_runs_per_component": 100, "max_runs_per_component": 100,
"connection_type_validation": True,
"components": { "components": {
"generator": { "generator": {
"type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",

View File

@ -354,6 +354,7 @@ class TestFileTypeRouter:
assert pipeline_dict == { assert pipeline_dict == {
"metadata": {}, "metadata": {},
"max_runs_per_component": 100, "max_runs_per_component": 100,
"connection_type_validation": True,
"components": { "components": {
"file_type_router": { "file_type_router": {
"type": "haystack.components.routers.file_type_router.FileTypeRouter", "type": "haystack.components.routers.file_type_router.FileTypeRouter",

View File

@ -232,6 +232,7 @@ class TestToolInvoker:
assert pipeline_dict == { assert pipeline_dict == {
"metadata": {}, "metadata": {},
"max_runs_per_component": 100, "max_runs_per_component": 100,
"connection_type_validation": True,
"components": { "components": {
"invoker": { "invoker": {
"type": "haystack.components.tools.tool_invoker.ToolInvoker", "type": "haystack.components.tools.tool_invoker.ToolInvoker",

View File

@ -308,6 +308,7 @@ class TestPipelineBase:
expected = { expected = {
"metadata": {"test": "test"}, "metadata": {"test": "test"},
"max_runs_per_component": 42, "max_runs_per_component": 42,
"connection_type_validation": True,
"components": { "components": {
"add_two": { "add_two": {
"type": "haystack.testing.sample_components.add_value.AddFixedValue", "type": "haystack.testing.sample_components.add_value.AddFixedValue",

View File

@ -76,7 +76,7 @@ def generate_symmetric_cases():
] ]
def generate_asymmetric_cases(): def generate_strict_asymmetric_cases():
"""Generate asymmetric test cases with different sender and receiver types.""" """Generate asymmetric test cases with different sender and receiver types."""
cases = [] cases = []
@ -194,6 +194,11 @@ def generate_asymmetric_cases():
(Dict[(str, Mapping[(Any, Dict[(Any, Any)])])]), (Dict[(str, Mapping[(Any, Dict[(Any, Any)])])]),
id="nested-mapping-of-classes-to-nested-mapping-of-any-keys-and-values", id="nested-mapping-of-classes-to-nested-mapping-of-any-keys-and-values",
), ),
pytest.param(
(Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]),
(Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]),
id="deeply-nested-complex-type",
),
] ]
) )
@ -202,7 +207,7 @@ def generate_asymmetric_cases():
# Precompute test cases for reuse # Precompute test cases for reuse
symmetric_cases = generate_symmetric_cases() symmetric_cases = generate_symmetric_cases()
asymmetric_cases = generate_asymmetric_cases() asymmetric_cases = generate_strict_asymmetric_cases()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -261,25 +266,28 @@ def test_type_name(type_, repr_):
@pytest.mark.parametrize("sender_type, receiver_type", symmetric_cases) @pytest.mark.parametrize("sender_type, receiver_type", symmetric_cases)
def test_same_types_are_compatible(sender_type, receiver_type): def test_same_types_are_compatible_strict(sender_type, receiver_type):
assert _types_are_compatible(sender_type, receiver_type) assert _types_are_compatible(sender_type, receiver_type, "strict")
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases) @pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
def test_asymmetric_types_are_compatible(sender_type, receiver_type): def test_asymmetric_types_are_compatible_strict(sender_type, receiver_type):
assert _types_are_compatible(sender_type, receiver_type) assert _types_are_compatible(sender_type, receiver_type, "strict")
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases) @pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
def test_asymmetric_types_are_not_compatible(sender_type, receiver_type): def test_asymmetric_types_are_not_compatible_strict(sender_type, receiver_type):
assert not _types_are_compatible(receiver_type, sender_type) assert not _types_are_compatible(receiver_type, sender_type, "strict")
incompatible_type_cases = [ incompatible_type_cases = [
pytest.param(Tuple[int, str], Tuple[Any], id="tuple-of-primitive-to-tuple-of-any-different-lengths"),
pytest.param(int, str, id="different-primitives"), pytest.param(int, str, id="different-primitives"),
pytest.param(Class1, Class2, id="different-classes"), pytest.param(Class1, Class2, id="different-classes"),
pytest.param((List[int]), (List[str]), id="different-lists-of-primitives"), pytest.param((List[int]), (List[str]), id="different-lists-of-primitives"),
pytest.param((List[Class1]), (List[Class2]), id="different-lists-of-classes"), pytest.param((List[Class1]), (List[Class2]), id="different-lists-of-classes"),
pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"),
pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"),
pytest.param( pytest.param(
(List[Set[Sequence[str]]]), (List[Set[Sequence[bool]]]), id="nested-sequences-of-different-primitives" (List[Set[Sequence[str]]]), (List[Set[Sequence[bool]]]), id="nested-sequences-of-different-primitives"
), ),
@ -327,19 +335,11 @@ incompatible_type_cases = [
(Dict[(str, Mapping[(str, Dict[(str, Class2)])])]), (Dict[(str, Mapping[(str, Dict[(str, Class2)])])]),
id="same-nested-mappings-of-class-to-subclass-values", id="same-nested-mappings-of-class-to-subclass-values",
), ),
pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"),
pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"),
] ]
@pytest.mark.parametrize("sender_type, receiver_type", incompatible_type_cases) @pytest.mark.parametrize("sender_type, receiver_type", incompatible_type_cases)
def test_types_are_always_not_compatible(sender_type, receiver_type): def test_types_are_always_not_compatible_strict(sender_type, receiver_type):
assert not _types_are_compatible(sender_type, receiver_type)
def test_deeply_nested_type_is_compatible_but_cannot_be_checked():
sender_type = Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]
receiver_type = Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]
assert not _types_are_compatible(sender_type, receiver_type) assert not _types_are_compatible(sender_type, receiver_type)
@ -350,7 +350,7 @@ def test_deeply_nested_type_is_compatible_but_cannot_be_checked():
pytest.param((Union[(int, Class1)]), (Union[(int, Class2)]), id="partially-overlapping-unions-with-classes"), pytest.param((Union[(int, Class1)]), (Union[(int, Class2)]), id="partially-overlapping-unions-with-classes"),
], ],
) )
def test_partially_overlapping_unions(sender_type, receiver_type): def test_partially_overlapping_unions_are_not_compatible_strict(sender_type, receiver_type):
assert not _types_are_compatible(sender_type, receiver_type) assert not _types_are_compatible(sender_type, receiver_type)
@ -362,4 +362,5 @@ def test_partially_overlapping_unions(sender_type, receiver_type):
], ],
) )
def test_list_of_primitive_to_list(sender_type, receiver_type): def test_list_of_primitive_to_list(sender_type, receiver_type):
"""This currently doesn't work because we don't handle bare types without arguments."""
assert not _types_are_compatible(sender_type, receiver_type) assert not _types_are_compatible(sender_type, receiver_type)

View File

@ -7,6 +7,7 @@ components:
init_parameters: init_parameters:
an_init_param: null an_init_param: null
type: test.core.pipeline.test_pipeline_base.FakeComponent type: test.core.pipeline.test_pipeline_base.FakeComponent
connection_type_validation: true
connections: connections:
- receiver: Comp2.input_ - receiver: Comp2.input_
sender: Comp1.value sender: Comp1.value