From 296e31c182ac12a22e4ea243e89d2af7873de06c Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 3 Mar 2025 16:00:22 +0100 Subject: [PATCH] feat: Add Type Validation parameter for Pipeline Connections (#8875) * Starting to refactor type util tests to be more systematic * refactoring * Expand tests * Update to type utils * Add missing subclass check * Expand and refactor tests, introduce type_validation Literal * More test refactoring * Test refactoring, adding type validation variable to pipeline base * Update relaxed version of type checking to pass all newly added tests * trim whitespace * Add tests * cleanup * Updates docstrings * Add reno * docs * Fix mypy and add docstrings * Changes based on advice from Tobi * Remove unused imports * Doc strings * Add connection type validation to to_dict and from_dict * Update tests * Fix test * Also save connection_type_validation at global pipeline level * Fix tests * Remove connection type validation from the connect level, only keep at pipeline level * Formatting * Fix tests * formatting --- README.md | 2 +- haystack/core/pipeline/base.py | 53 +++++++++++++------ haystack/core/type_utils.py | 41 +++++++++----- ...onnection-validation-6ca8b2d9741c225b.yaml | 5 ++ .../connectors/test_openapi_connector.py | 1 + .../connectors/test_openapi_service.py | 1 + test/components/generators/chat/test_azure.py | 1 + .../generators/chat/test_hugging_face_api.py | 1 + test/components/routers/test_file_router.py | 1 + test/components/tools/test_tool_invoker.py | 1 + test/core/pipeline/test_pipeline_base.py | 1 + test/core/pipeline/test_type_utils.py | 37 ++++++------- test/test_files/yaml/test_pipeline.yaml | 1 + 13 files changed, 97 insertions(+), 49 deletions(-) create mode 100644 releasenotes/notes/add-relaxed-and-disabled-pipeline-connection-validation-6ca8b2d9741c225b.yaml diff --git a/README.md b/README.md index f1e9c9a03..51952278f 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ Some examples of what you can 
do with Haystack: > [!TIP] > -> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/). +> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/). ## 🆕 deepset Studio: Your Development Environment for Haystack diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 2ab1d172e..8068d4e61 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -68,7 +68,12 @@ class PipelineBase: Builds a graph of components and orchestrates their execution according to the execution graph. """ - def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_component: int = 100): + def __init__( + self, + metadata: Optional[Dict[str, Any]] = None, + max_runs_per_component: int = 100, + connection_type_validation: bool = True, + ): """ Creates the Pipeline. @@ -79,12 +84,15 @@ class PipelineBase: How many times the `Pipeline` can run the same Component. If this limit is reached a `PipelineMaxComponentRuns` exception is raised. If not set defaults to 100 runs per Component. + :param connection_type_validation: Whether the pipeline will validate the types of the connections. + Defaults to True. 
""" self._telemetry_runs = 0 self._last_telemetry_sent: Optional[datetime] = None self.metadata = metadata or {} self.graph = networkx.MultiDiGraph() self._max_runs_per_component = max_runs_per_component + self._connection_type_validation = connection_type_validation def __eq__(self, other) -> bool: """ @@ -142,6 +150,7 @@ class PipelineBase: "max_runs_per_component": self._max_runs_per_component, "components": components, "connections": connections, + "connection_type_validation": self._connection_type_validation, } @classmethod @@ -164,7 +173,12 @@ class PipelineBase: data_copy = deepcopy(data) # to prevent modification of original data metadata = data_copy.get("metadata", {}) max_runs_per_component = data_copy.get("max_runs_per_component", 100) - pipe = cls(metadata=metadata, max_runs_per_component=max_runs_per_component) + connection_type_validation = data_copy.get("connection_type_validation", True) + pipe = cls( + metadata=metadata, + max_runs_per_component=max_runs_per_component, + connection_type_validation=connection_type_validation, + ) components_to_reuse = kwargs.get("components", {}) for name, component_data in data_copy.get("components", {}).items(): if name in components_to_reuse: @@ -402,6 +416,8 @@ class PipelineBase: :param receiver: The component that receives the value. This can be either just a component name or can be in the format `component_name.connection_name` if the component has multiple inputs. + :param connection_type_validation: Whether the pipeline will validate the types of the connections. + Defaults to the value set in the pipeline. :returns: The Pipeline instance. @@ -418,48 +434,51 @@ class PipelineBase: # Get the nodes data. 
try: - from_sockets = self.graph.nodes[sender_component_name]["output_sockets"] + sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"] except KeyError as exc: raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc try: - to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"] + receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"] except KeyError as exc: raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc # If the name of either socket is given, get the socket sender_socket: Optional[OutputSocket] = None if sender_socket_name: - sender_socket = from_sockets.get(sender_socket_name) + sender_socket = sender_sockets.get(sender_socket_name) if not sender_socket: raise PipelineConnectError( f"'{sender} does not exist. " f"Output connections of {sender_component_name} are: " - + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()]) + + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()]) ) receiver_socket: Optional[InputSocket] = None if receiver_socket_name: - receiver_socket = to_sockets.get(receiver_socket_name) + receiver_socket = receiver_sockets.get(receiver_socket_name) if not receiver_socket: raise PipelineConnectError( f"'{receiver} does not exist. " f"Input connections of {receiver_component_name} are: " - + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()]) + + ", ".join( + [f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()] + ) ) # Look for a matching connection among the possible ones. # Note that if there is more than one possible connection but two sockets match by name, they're paired. 
- sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values()) + sender_socket_candidates: List[OutputSocket] = ( + [sender_socket] if sender_socket else list(sender_sockets.values()) + ) receiver_socket_candidates: List[InputSocket] = ( - [receiver_socket] if receiver_socket else list(to_sockets.values()) + [receiver_socket] if receiver_socket else list(receiver_sockets.values()) ) # Find all possible connections between these two components - possible_connections = [ - (sender_sock, receiver_sock) - for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates) - if _types_are_compatible(sender_sock.type, receiver_sock.type) - ] + possible_connections = [] + for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates): + if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation): + possible_connections.append((sender_sock, receiver_sock)) # We need this status for error messages, since we might need it in multiple places we calculate it here status = _connections_status( @@ -860,7 +879,7 @@ class PipelineBase: def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]: """ - Utility function to find all Components that receive input form `component_name`. + Utility function to find all Components that receive input from `component_name`. :param component_name: Name of the sender Component @@ -1179,7 +1198,7 @@ class PipelineBase: def _connections_status( sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] -): +) -> str: """ Lists the status of the sockets, for error messages. 
""" diff --git a/haystack/core/type_utils.py b/haystack/core/type_utils.py index 06651cd56..4b84c74dc 100644 --- a/haystack/core/type_utils.py +++ b/haystack/core/type_utils.py @@ -2,29 +2,42 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Union, get_args, get_origin +from typing import Any, TypeVar, Union, get_args, get_origin from haystack import logging logger = logging.getLogger(__name__) - -def _is_optional(type_: type) -> bool: - """ - Utility method that returns whether a type is Optional. - """ - return get_origin(type_) is Union and type(None) in get_args(type_) +T = TypeVar("T") -def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements +def _types_are_compatible(sender, receiver, type_validation: bool = True) -> bool: """ - Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections. + Determines if two types are compatible based on the specified validation mode. + + :param sender: The sender type. + :param receiver: The receiver type. + :param type_validation: Whether to perform strict type validation. + :return: True if the types are compatible, False otherwise. + """ + if type_validation: + return _strict_types_are_compatible(sender, receiver) + else: + return True + + +def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements + """ + Checks whether the sender type is equal to or a subtype of the receiver type under strict validation. Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same. - Consider simplifying the typing of your components if you observe unexpected errors during component connection. 
+ + :param sender: The sender type. + :param receiver: The receiver type. + :return: True if the sender type is strictly compatible with the receiver type, False otherwise. """ if sender == receiver or receiver is Any: return True @@ -42,17 +55,19 @@ def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return- receiver_origin = get_origin(receiver) if sender_origin is not Union and receiver_origin is Union: - return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver)) + return any(_strict_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver)) - if not sender_origin or not receiver_origin or sender_origin != receiver_origin: + # Both must have origins and they must be equal + if not (sender_origin and receiver_origin and sender_origin == receiver_origin): return False + # Compare generic type arguments sender_args = get_args(sender) receiver_args = get_args(receiver) if len(sender_args) > len(receiver_args): return False - return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args)) + return all(_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args)) def _type_name(type_): diff --git a/releasenotes/notes/add-relaxed-and-disabled-pipeline-connection-validation-6ca8b2d9741c225b.yaml b/releasenotes/notes/add-relaxed-and-disabled-pipeline-connection-validation-6ca8b2d9741c225b.yaml new file mode 100644 index 000000000..6cc819ce1 --- /dev/null +++ b/releasenotes/notes/add-relaxed-and-disabled-pipeline-connection-validation-6ca8b2d9741c225b.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + We've introduced a new `connection_type_validation` parameter to control type compatibility checks in pipeline connections. + It can be set to True (default) or False, which means no type checks will be done and everything is allowed. 
diff --git a/test/components/connectors/test_openapi_connector.py b/test/components/connectors/test_openapi_connector.py index 48cdfade6..76cfb3960 100644 --- a/test/components/connectors/test_openapi_connector.py +++ b/test/components/connectors/test_openapi_connector.py @@ -151,6 +151,7 @@ class TestOpenAPIConnector: assert pipeline_dict == { "metadata": {}, "max_runs_per_component": 100, + "connection_type_validation": True, "components": { "api": { "type": "haystack.components.connectors.openapi.OpenAPIConnector", diff --git a/test/components/connectors/test_openapi_service.py b/test/components/connectors/test_openapi_service.py index 8e4b8ab86..a0fe4f9e6 100644 --- a/test/components/connectors/test_openapi_service.py +++ b/test/components/connectors/test_openapi_service.py @@ -218,6 +218,7 @@ class TestOpenAPIServiceConnector: assert pipeline_dict == { "metadata": {}, "max_runs_per_component": 100, + "connection_type_validation": True, "components": { "connector": { "type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector", diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index 4f73febde..83ff4bf77 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -171,6 +171,7 @@ class TestAzureOpenAIChatGenerator: assert p.to_dict() == { "metadata": {}, "max_runs_per_component": 100, + "connection_type_validation": True, "components": { "generator": { "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator", diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index fd81d574d..13fe00990 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -270,6 +270,7 @@ class TestHuggingFaceAPIChatGenerator: assert pipeline_dict == { "metadata": {}, 
"max_runs_per_component": 100, + "connection_type_validation": True, "components": { "generator": { "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", diff --git a/test/components/routers/test_file_router.py b/test/components/routers/test_file_router.py index 28f99b0a5..2c1bb51b3 100644 --- a/test/components/routers/test_file_router.py +++ b/test/components/routers/test_file_router.py @@ -354,6 +354,7 @@ class TestFileTypeRouter: assert pipeline_dict == { "metadata": {}, "max_runs_per_component": 100, + "connection_type_validation": True, "components": { "file_type_router": { "type": "haystack.components.routers.file_type_router.FileTypeRouter", diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index 6c3aab0a8..4ed4126e6 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -232,6 +232,7 @@ class TestToolInvoker: assert pipeline_dict == { "metadata": {}, "max_runs_per_component": 100, + "connection_type_validation": True, "components": { "invoker": { "type": "haystack.components.tools.tool_invoker.ToolInvoker", diff --git a/test/core/pipeline/test_pipeline_base.py b/test/core/pipeline/test_pipeline_base.py index d0cfbd1fe..3a6910299 100644 --- a/test/core/pipeline/test_pipeline_base.py +++ b/test/core/pipeline/test_pipeline_base.py @@ -308,6 +308,7 @@ class TestPipelineBase: expected = { "metadata": {"test": "test"}, "max_runs_per_component": 42, + "connection_type_validation": True, "components": { "add_two": { "type": "haystack.testing.sample_components.add_value.AddFixedValue", diff --git a/test/core/pipeline/test_type_utils.py b/test/core/pipeline/test_type_utils.py index 6173e4fdf..28e000155 100644 --- a/test/core/pipeline/test_type_utils.py +++ b/test/core/pipeline/test_type_utils.py @@ -76,7 +76,7 @@ def generate_symmetric_cases(): ] -def generate_asymmetric_cases(): +def generate_strict_asymmetric_cases(): 
"""Generate asymmetric test cases with different sender and receiver types.""" cases = [] @@ -194,6 +194,11 @@ def generate_asymmetric_cases(): (Dict[(str, Mapping[(Any, Dict[(Any, Any)])])]), id="nested-mapping-of-classes-to-nested-mapping-of-any-keys-and-values", ), + pytest.param( + (Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]), + (Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]), + id="deeply-nested-complex-type", + ), ] ) @@ -202,7 +207,7 @@ def generate_asymmetric_cases(): # Precompute test cases for reuse symmetric_cases = generate_symmetric_cases() -asymmetric_cases = generate_asymmetric_cases() +asymmetric_cases = generate_strict_asymmetric_cases() @pytest.mark.parametrize( @@ -261,25 +266,28 @@ def test_type_name(type_, repr_): @pytest.mark.parametrize("sender_type, receiver_type", symmetric_cases) -def test_same_types_are_compatible(sender_type, receiver_type): - assert _types_are_compatible(sender_type, receiver_type) +def test_same_types_are_compatible_strict(sender_type, receiver_type): + assert _types_are_compatible(sender_type, receiver_type, "strict") @pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases) -def test_asymmetric_types_are_compatible(sender_type, receiver_type): - assert _types_are_compatible(sender_type, receiver_type) +def test_asymmetric_types_are_compatible_strict(sender_type, receiver_type): + assert _types_are_compatible(sender_type, receiver_type, "strict") @pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases) -def test_asymmetric_types_are_not_compatible(sender_type, receiver_type): - assert not _types_are_compatible(receiver_type, sender_type) +def test_asymmetric_types_are_not_compatible_strict(sender_type, receiver_type): + assert not _types_are_compatible(receiver_type, sender_type, "strict") incompatible_type_cases = [ + pytest.param(Tuple[int, str], Tuple[Any], id="tuple-of-primitive-to-tuple-of-any-different-lengths"), 
pytest.param(int, str, id="different-primitives"), pytest.param(Class1, Class2, id="different-classes"), pytest.param((List[int]), (List[str]), id="different-lists-of-primitives"), pytest.param((List[Class1]), (List[Class2]), id="different-lists-of-classes"), + pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"), + pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"), pytest.param( (List[Set[Sequence[str]]]), (List[Set[Sequence[bool]]]), id="nested-sequences-of-different-primitives" ), @@ -327,19 +335,11 @@ incompatible_type_cases = [ (Dict[(str, Mapping[(str, Dict[(str, Class2)])])]), id="same-nested-mappings-of-class-to-subclass-values", ), - pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"), - pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"), ] @pytest.mark.parametrize("sender_type, receiver_type", incompatible_type_cases) -def test_types_are_always_not_compatible(sender_type, receiver_type): - assert not _types_are_compatible(sender_type, receiver_type) - - -def test_deeply_nested_type_is_compatible_but_cannot_be_checked(): - sender_type = Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]] - receiver_type = Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]] +def test_types_are_always_not_compatible_strict(sender_type, receiver_type): assert not _types_are_compatible(sender_type, receiver_type) @@ -350,7 +350,7 @@ def test_deeply_nested_type_is_compatible_but_cannot_be_checked(): pytest.param((Union[(int, Class1)]), (Union[(int, Class2)]), id="partially-overlapping-unions-with-classes"), ], ) -def test_partially_overlapping_unions(sender_type, receiver_type): +def test_partially_overlapping_unions_are_not_compatible_strict(sender_type, receiver_type): assert not _types_are_compatible(sender_type, receiver_type) @@ -362,4 +362,5 @@ 
def test_partially_overlapping_unions(sender_type, receiver_type): ], ) def test_list_of_primitive_to_list(sender_type, receiver_type): + """This currently doesn't work because we don't handle bare types without arguments.""" assert not _types_are_compatible(sender_type, receiver_type) diff --git a/test/test_files/yaml/test_pipeline.yaml b/test/test_files/yaml/test_pipeline.yaml index 53c281d30..6afabd45b 100644 --- a/test/test_files/yaml/test_pipeline.yaml +++ b/test/test_files/yaml/test_pipeline.yaml @@ -7,6 +7,7 @@ components: init_parameters: an_init_param: null type: test.core.pipeline.test_pipeline_base.FakeComponent +connection_type_validation: true connections: - receiver: Comp2.input_ sender: Comp1.value