mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
feat: Add Type Validation parameter for Pipeline Connections (#8875)
* Starting to refactor type util tests to be more systematic * refactoring * Expand tests * Update to type utils * Add missing subclass check * Expand and refactor tests, introduce type_validation Literal * More test refactoring * Test refactoring, adding type validation variable to pipeline base * Update relaxed version of type checking to pass all newly added tests * trim whitespace * Add tests * cleanup * Updates docstrings * Add reno * docs * Fix mypy and add docstrings * Changes based on advice from Tobi * Remove unused imports * Doc strings * Add connection type validation to to_dict and from_dict * Update tests * Fix test * Also save connection_type_validation at global pipeline level * Fix tests * Remove connection type validation from the connect level, only keep at pipeline level * Formatting * Fix tests * formatting
This commit is contained in:
parent
00fe4d157d
commit
296e31c182
@ -68,7 +68,12 @@ class PipelineBase:
|
||||
Builds a graph of components and orchestrates their execution according to the execution graph.
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_component: int = 100):
|
||||
def __init__(
|
||||
self,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
max_runs_per_component: int = 100,
|
||||
connection_type_validation: bool = True,
|
||||
):
|
||||
"""
|
||||
Creates the Pipeline.
|
||||
|
||||
@ -79,12 +84,15 @@ class PipelineBase:
|
||||
How many times the `Pipeline` can run the same Component.
|
||||
If this limit is reached a `PipelineMaxComponentRuns` exception is raised.
|
||||
If not set defaults to 100 runs per Component.
|
||||
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
|
||||
Defaults to True.
|
||||
"""
|
||||
self._telemetry_runs = 0
|
||||
self._last_telemetry_sent: Optional[datetime] = None
|
||||
self.metadata = metadata or {}
|
||||
self.graph = networkx.MultiDiGraph()
|
||||
self._max_runs_per_component = max_runs_per_component
|
||||
self._connection_type_validation = connection_type_validation
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""
|
||||
@ -142,6 +150,7 @@ class PipelineBase:
|
||||
"max_runs_per_component": self._max_runs_per_component,
|
||||
"components": components,
|
||||
"connections": connections,
|
||||
"connection_type_validation": self._connection_type_validation,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -164,7 +173,12 @@ class PipelineBase:
|
||||
data_copy = deepcopy(data) # to prevent modification of original data
|
||||
metadata = data_copy.get("metadata", {})
|
||||
max_runs_per_component = data_copy.get("max_runs_per_component", 100)
|
||||
pipe = cls(metadata=metadata, max_runs_per_component=max_runs_per_component)
|
||||
connection_type_validation = data_copy.get("connection_type_validation", True)
|
||||
pipe = cls(
|
||||
metadata=metadata,
|
||||
max_runs_per_component=max_runs_per_component,
|
||||
connection_type_validation=connection_type_validation,
|
||||
)
|
||||
components_to_reuse = kwargs.get("components", {})
|
||||
for name, component_data in data_copy.get("components", {}).items():
|
||||
if name in components_to_reuse:
|
||||
@ -402,6 +416,8 @@ class PipelineBase:
|
||||
:param receiver:
|
||||
The component that receives the value. This can be either just a component name or can be
|
||||
in the format `component_name.connection_name` if the component has multiple inputs.
|
||||
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
|
||||
Defaults to the value set in the pipeline.
|
||||
:returns:
|
||||
The Pipeline instance.
|
||||
|
||||
@ -418,48 +434,51 @@ class PipelineBase:
|
||||
|
||||
# Get the nodes data.
|
||||
try:
|
||||
from_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
|
||||
sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
|
||||
try:
|
||||
to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
|
||||
receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
|
||||
|
||||
# If the name of either socket is given, get the socket
|
||||
sender_socket: Optional[OutputSocket] = None
|
||||
if sender_socket_name:
|
||||
sender_socket = from_sockets.get(sender_socket_name)
|
||||
sender_socket = sender_sockets.get(sender_socket_name)
|
||||
if not sender_socket:
|
||||
raise PipelineConnectError(
|
||||
f"'{sender} does not exist. "
|
||||
f"Output connections of {sender_component_name} are: "
|
||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
|
||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
|
||||
)
|
||||
|
||||
receiver_socket: Optional[InputSocket] = None
|
||||
if receiver_socket_name:
|
||||
receiver_socket = to_sockets.get(receiver_socket_name)
|
||||
receiver_socket = receiver_sockets.get(receiver_socket_name)
|
||||
if not receiver_socket:
|
||||
raise PipelineConnectError(
|
||||
f"'{receiver} does not exist. "
|
||||
f"Input connections of {receiver_component_name} are: "
|
||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
|
||||
+ ", ".join(
|
||||
[f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
|
||||
)
|
||||
)
|
||||
|
||||
# Look for a matching connection among the possible ones.
|
||||
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
|
||||
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
|
||||
sender_socket_candidates: List[OutputSocket] = (
|
||||
[sender_socket] if sender_socket else list(sender_sockets.values())
|
||||
)
|
||||
receiver_socket_candidates: List[InputSocket] = (
|
||||
[receiver_socket] if receiver_socket else list(to_sockets.values())
|
||||
[receiver_socket] if receiver_socket else list(receiver_sockets.values())
|
||||
)
|
||||
|
||||
# Find all possible connections between these two components
|
||||
possible_connections = [
|
||||
(sender_sock, receiver_sock)
|
||||
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
|
||||
if _types_are_compatible(sender_sock.type, receiver_sock.type)
|
||||
]
|
||||
possible_connections = []
|
||||
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
|
||||
if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
|
||||
possible_connections.append((sender_sock, receiver_sock))
|
||||
|
||||
# We need this status for error messages, since we might need it in multiple places we calculate it here
|
||||
status = _connections_status(
|
||||
@ -860,7 +879,7 @@ class PipelineBase:
|
||||
|
||||
def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
|
||||
"""
|
||||
Utility function to find all Components that receive input form `component_name`.
|
||||
Utility function to find all Components that receive input from `component_name`.
|
||||
|
||||
:param component_name:
|
||||
Name of the sender Component
|
||||
@ -1179,7 +1198,7 @@ class PipelineBase:
|
||||
|
||||
def _connections_status(
|
||||
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
|
||||
):
|
||||
) -> str:
|
||||
"""
|
||||
Lists the status of the sockets, for error messages.
|
||||
"""
|
||||
|
@ -2,29 +2,42 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
from typing import Any, TypeVar, Union, get_args, get_origin
|
||||
|
||||
from haystack import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_optional(type_: type) -> bool:
|
||||
"""
|
||||
Utility method that returns whether a type is Optional.
|
||||
"""
|
||||
return get_origin(type_) is Union and type(None) in get_args(type_)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
|
||||
def _types_are_compatible(sender, receiver, type_validation: bool = True) -> bool:
|
||||
"""
|
||||
Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections.
|
||||
Determines if two types are compatible based on the specified validation mode.
|
||||
|
||||
:param sender: The sender type.
|
||||
:param receiver: The receiver type.
|
||||
:param type_validation: Whether to perform strict type validation.
|
||||
:return: True if the types are compatible, False otherwise.
|
||||
"""
|
||||
if type_validation:
|
||||
return _strict_types_are_compatible(sender, receiver)
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
|
||||
"""
|
||||
Checks whether the sender type is equal to or a subtype of the receiver type under strict validation.
|
||||
|
||||
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
|
||||
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
|
||||
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
|
||||
|
||||
Consider simplifying the typing of your components if you observe unexpected errors during component connection.
|
||||
|
||||
:param sender: The sender type.
|
||||
:param receiver: The receiver type.
|
||||
:return: True if the sender type is strictly compatible with the receiver type, False otherwise.
|
||||
"""
|
||||
if sender == receiver or receiver is Any:
|
||||
return True
|
||||
@ -42,17 +55,19 @@ def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-
|
||||
receiver_origin = get_origin(receiver)
|
||||
|
||||
if sender_origin is not Union and receiver_origin is Union:
|
||||
return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
|
||||
return any(_strict_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
|
||||
|
||||
if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
|
||||
# Both must have origins and they must be equal
|
||||
if not (sender_origin and receiver_origin and sender_origin == receiver_origin):
|
||||
return False
|
||||
|
||||
# Compare generic type arguments
|
||||
sender_args = get_args(sender)
|
||||
receiver_args = get_args(receiver)
|
||||
if len(sender_args) > len(receiver_args):
|
||||
return False
|
||||
|
||||
return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
|
||||
return all(_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
|
||||
|
||||
|
||||
def _type_name(type_):
|
||||
|
@ -0,0 +1,5 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
We've introduced a new type_validation parameter to control type compatibility checks in pipeline connections.
|
||||
It can be set to True (default) or False which means no type checks will be done and everything is allowed.
|
@ -151,6 +151,7 @@ class TestOpenAPIConnector:
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"api": {
|
||||
"type": "haystack.components.connectors.openapi.OpenAPIConnector",
|
||||
|
@ -218,6 +218,7 @@ class TestOpenAPIServiceConnector:
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"connector": {
|
||||
"type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector",
|
||||
|
@ -171,6 +171,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
assert p.to_dict() == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"generator": {
|
||||
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
|
||||
|
@ -270,6 +270,7 @@ class TestHuggingFaceAPIChatGenerator:
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"generator": {
|
||||
"type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
|
||||
|
@ -354,6 +354,7 @@ class TestFileTypeRouter:
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"file_type_router": {
|
||||
"type": "haystack.components.routers.file_type_router.FileTypeRouter",
|
||||
|
@ -232,6 +232,7 @@ class TestToolInvoker:
|
||||
assert pipeline_dict == {
|
||||
"metadata": {},
|
||||
"max_runs_per_component": 100,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"invoker": {
|
||||
"type": "haystack.components.tools.tool_invoker.ToolInvoker",
|
||||
|
@ -308,6 +308,7 @@ class TestPipelineBase:
|
||||
expected = {
|
||||
"metadata": {"test": "test"},
|
||||
"max_runs_per_component": 42,
|
||||
"connection_type_validation": True,
|
||||
"components": {
|
||||
"add_two": {
|
||||
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
||||
|
@ -76,7 +76,7 @@ def generate_symmetric_cases():
|
||||
]
|
||||
|
||||
|
||||
def generate_asymmetric_cases():
|
||||
def generate_strict_asymmetric_cases():
|
||||
"""Generate asymmetric test cases with different sender and receiver types."""
|
||||
cases = []
|
||||
|
||||
@ -194,6 +194,11 @@ def generate_asymmetric_cases():
|
||||
(Dict[(str, Mapping[(Any, Dict[(Any, Any)])])]),
|
||||
id="nested-mapping-of-classes-to-nested-mapping-of-any-keys-and-values",
|
||||
),
|
||||
pytest.param(
|
||||
(Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]),
|
||||
(Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]),
|
||||
id="deeply-nested-complex-type",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@ -202,7 +207,7 @@ def generate_asymmetric_cases():
|
||||
|
||||
# Precompute test cases for reuse
|
||||
symmetric_cases = generate_symmetric_cases()
|
||||
asymmetric_cases = generate_asymmetric_cases()
|
||||
asymmetric_cases = generate_strict_asymmetric_cases()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -261,25 +266,28 @@ def test_type_name(type_, repr_):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sender_type, receiver_type", symmetric_cases)
|
||||
def test_same_types_are_compatible(sender_type, receiver_type):
|
||||
assert _types_are_compatible(sender_type, receiver_type)
|
||||
def test_same_types_are_compatible_strict(sender_type, receiver_type):
|
||||
assert _types_are_compatible(sender_type, receiver_type, "strict")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
|
||||
def test_asymmetric_types_are_compatible(sender_type, receiver_type):
|
||||
assert _types_are_compatible(sender_type, receiver_type)
|
||||
def test_asymmetric_types_are_compatible_strict(sender_type, receiver_type):
|
||||
assert _types_are_compatible(sender_type, receiver_type, "strict")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
|
||||
def test_asymmetric_types_are_not_compatible(sender_type, receiver_type):
|
||||
assert not _types_are_compatible(receiver_type, sender_type)
|
||||
def test_asymmetric_types_are_not_compatible_strict(sender_type, receiver_type):
|
||||
assert not _types_are_compatible(receiver_type, sender_type, "strict")
|
||||
|
||||
|
||||
incompatible_type_cases = [
|
||||
pytest.param(Tuple[int, str], Tuple[Any], id="tuple-of-primitive-to-tuple-of-any-different-lengths"),
|
||||
pytest.param(int, str, id="different-primitives"),
|
||||
pytest.param(Class1, Class2, id="different-classes"),
|
||||
pytest.param((List[int]), (List[str]), id="different-lists-of-primitives"),
|
||||
pytest.param((List[Class1]), (List[Class2]), id="different-lists-of-classes"),
|
||||
pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"),
|
||||
pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"),
|
||||
pytest.param(
|
||||
(List[Set[Sequence[str]]]), (List[Set[Sequence[bool]]]), id="nested-sequences-of-different-primitives"
|
||||
),
|
||||
@ -327,19 +335,11 @@ incompatible_type_cases = [
|
||||
(Dict[(str, Mapping[(str, Dict[(str, Class2)])])]),
|
||||
id="same-nested-mappings-of-class-to-subclass-values",
|
||||
),
|
||||
pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"),
|
||||
pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sender_type, receiver_type", incompatible_type_cases)
|
||||
def test_types_are_always_not_compatible(sender_type, receiver_type):
|
||||
assert not _types_are_compatible(sender_type, receiver_type)
|
||||
|
||||
|
||||
def test_deeply_nested_type_is_compatible_but_cannot_be_checked():
|
||||
sender_type = Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]
|
||||
receiver_type = Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]
|
||||
def test_types_are_always_not_compatible_strict(sender_type, receiver_type):
|
||||
assert not _types_are_compatible(sender_type, receiver_type)
|
||||
|
||||
|
||||
@ -350,7 +350,7 @@ def test_deeply_nested_type_is_compatible_but_cannot_be_checked():
|
||||
pytest.param((Union[(int, Class1)]), (Union[(int, Class2)]), id="partially-overlapping-unions-with-classes"),
|
||||
],
|
||||
)
|
||||
def test_partially_overlapping_unions(sender_type, receiver_type):
|
||||
def test_partially_overlapping_unions_are_not_compatible_strict(sender_type, receiver_type):
|
||||
assert not _types_are_compatible(sender_type, receiver_type)
|
||||
|
||||
|
||||
@ -362,4 +362,5 @@ def test_partially_overlapping_unions(sender_type, receiver_type):
|
||||
],
|
||||
)
|
||||
def test_list_of_primitive_to_list(sender_type, receiver_type):
|
||||
"""This currently doesn't work because we don't handle bare types without arguments."""
|
||||
assert not _types_are_compatible(sender_type, receiver_type)
|
||||
|
@ -7,6 +7,7 @@ components:
|
||||
init_parameters:
|
||||
an_init_param: null
|
||||
type: test.core.pipeline.test_pipeline_base.FakeComponent
|
||||
connection_type_validation: true
|
||||
connections:
|
||||
- receiver: Comp2.input_
|
||||
sender: Comp1.value
|
||||
|
Loading…
x
Reference in New Issue
Block a user