mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
feat: Add Type Validation parameter for Pipeline Connections (#8875)
* Starting to refactor type util tests to be more systematic * refactoring * Expand tests * Update to type utils * Add missing subclass check * Expand and refactor tests, introduce type_validation Literal * More test refactoring * Test refactoring, adding type validation variable to pipeline base * Update relaxed version of type checking to pass all newly added tests * trim whitespace * Add tests * cleanup * Updates docstrings * Add reno * docs * Fix mypy and add docstrings * Changes based on advice from Tobi * Remove unused imports * Doc strings * Add connection type validation to to_dict and from_dict * Update tests * Fix test * Also save connection_type_validation at global pipeline level * Fix tests * Remove connection type validation from the connect level, only keep at pipeline level * Formatting * Fix tests * formatting
This commit is contained in:
parent
00fe4d157d
commit
296e31c182
@ -68,7 +68,7 @@ Some examples of what you can do with Haystack:
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
>
|
>
|
||||||
> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/).
|
> Would you like to deploy and serve Haystack pipelines as REST APIs yourself? [Hayhooks](https://github.com/deepset-ai/hayhooks) provides a simple way to wrap your pipelines with custom logic and expose them via HTTP endpoints, including OpenAI-compatible chat completion endpoints and compatibility with fully-featured chat interfaces like [open-webui](https://openwebui.com/).
|
||||||
|
|
||||||
## 🆕 deepset Studio: Your Development Environment for Haystack
|
## 🆕 deepset Studio: Your Development Environment for Haystack
|
||||||
|
|
||||||
|
@ -68,7 +68,12 @@ class PipelineBase:
|
|||||||
Builds a graph of components and orchestrates their execution according to the execution graph.
|
Builds a graph of components and orchestrates their execution according to the execution graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_component: int = 100):
|
def __init__(
|
||||||
|
self,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
max_runs_per_component: int = 100,
|
||||||
|
connection_type_validation: bool = True,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Creates the Pipeline.
|
Creates the Pipeline.
|
||||||
|
|
||||||
@ -79,12 +84,15 @@ class PipelineBase:
|
|||||||
How many times the `Pipeline` can run the same Component.
|
How many times the `Pipeline` can run the same Component.
|
||||||
If this limit is reached a `PipelineMaxComponentRuns` exception is raised.
|
If this limit is reached a `PipelineMaxComponentRuns` exception is raised.
|
||||||
If not set defaults to 100 runs per Component.
|
If not set defaults to 100 runs per Component.
|
||||||
|
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
|
||||||
|
Defaults to True.
|
||||||
"""
|
"""
|
||||||
self._telemetry_runs = 0
|
self._telemetry_runs = 0
|
||||||
self._last_telemetry_sent: Optional[datetime] = None
|
self._last_telemetry_sent: Optional[datetime] = None
|
||||||
self.metadata = metadata or {}
|
self.metadata = metadata or {}
|
||||||
self.graph = networkx.MultiDiGraph()
|
self.graph = networkx.MultiDiGraph()
|
||||||
self._max_runs_per_component = max_runs_per_component
|
self._max_runs_per_component = max_runs_per_component
|
||||||
|
self._connection_type_validation = connection_type_validation
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -142,6 +150,7 @@ class PipelineBase:
|
|||||||
"max_runs_per_component": self._max_runs_per_component,
|
"max_runs_per_component": self._max_runs_per_component,
|
||||||
"components": components,
|
"components": components,
|
||||||
"connections": connections,
|
"connections": connections,
|
||||||
|
"connection_type_validation": self._connection_type_validation,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -164,7 +173,12 @@ class PipelineBase:
|
|||||||
data_copy = deepcopy(data) # to prevent modification of original data
|
data_copy = deepcopy(data) # to prevent modification of original data
|
||||||
metadata = data_copy.get("metadata", {})
|
metadata = data_copy.get("metadata", {})
|
||||||
max_runs_per_component = data_copy.get("max_runs_per_component", 100)
|
max_runs_per_component = data_copy.get("max_runs_per_component", 100)
|
||||||
pipe = cls(metadata=metadata, max_runs_per_component=max_runs_per_component)
|
connection_type_validation = data_copy.get("connection_type_validation", True)
|
||||||
|
pipe = cls(
|
||||||
|
metadata=metadata,
|
||||||
|
max_runs_per_component=max_runs_per_component,
|
||||||
|
connection_type_validation=connection_type_validation,
|
||||||
|
)
|
||||||
components_to_reuse = kwargs.get("components", {})
|
components_to_reuse = kwargs.get("components", {})
|
||||||
for name, component_data in data_copy.get("components", {}).items():
|
for name, component_data in data_copy.get("components", {}).items():
|
||||||
if name in components_to_reuse:
|
if name in components_to_reuse:
|
||||||
@ -402,6 +416,8 @@ class PipelineBase:
|
|||||||
:param receiver:
|
:param receiver:
|
||||||
The component that receives the value. This can be either just a component name or can be
|
The component that receives the value. This can be either just a component name or can be
|
||||||
in the format `component_name.connection_name` if the component has multiple inputs.
|
in the format `component_name.connection_name` if the component has multiple inputs.
|
||||||
|
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
|
||||||
|
Defaults to the value set in the pipeline.
|
||||||
:returns:
|
:returns:
|
||||||
The Pipeline instance.
|
The Pipeline instance.
|
||||||
|
|
||||||
@ -418,48 +434,51 @@ class PipelineBase:
|
|||||||
|
|
||||||
# Get the nodes data.
|
# Get the nodes data.
|
||||||
try:
|
try:
|
||||||
from_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
|
sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
|
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
|
||||||
try:
|
try:
|
||||||
to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
|
receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
|
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
|
||||||
|
|
||||||
# If the name of either socket is given, get the socket
|
# If the name of either socket is given, get the socket
|
||||||
sender_socket: Optional[OutputSocket] = None
|
sender_socket: Optional[OutputSocket] = None
|
||||||
if sender_socket_name:
|
if sender_socket_name:
|
||||||
sender_socket = from_sockets.get(sender_socket_name)
|
sender_socket = sender_sockets.get(sender_socket_name)
|
||||||
if not sender_socket:
|
if not sender_socket:
|
||||||
raise PipelineConnectError(
|
raise PipelineConnectError(
|
||||||
f"'{sender} does not exist. "
|
f"'{sender} does not exist. "
|
||||||
f"Output connections of {sender_component_name} are: "
|
f"Output connections of {sender_component_name} are: "
|
||||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
|
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
|
||||||
)
|
)
|
||||||
|
|
||||||
receiver_socket: Optional[InputSocket] = None
|
receiver_socket: Optional[InputSocket] = None
|
||||||
if receiver_socket_name:
|
if receiver_socket_name:
|
||||||
receiver_socket = to_sockets.get(receiver_socket_name)
|
receiver_socket = receiver_sockets.get(receiver_socket_name)
|
||||||
if not receiver_socket:
|
if not receiver_socket:
|
||||||
raise PipelineConnectError(
|
raise PipelineConnectError(
|
||||||
f"'{receiver} does not exist. "
|
f"'{receiver} does not exist. "
|
||||||
f"Input connections of {receiver_component_name} are: "
|
f"Input connections of {receiver_component_name} are: "
|
||||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
|
+ ", ".join(
|
||||||
|
[f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Look for a matching connection among the possible ones.
|
# Look for a matching connection among the possible ones.
|
||||||
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
|
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
|
||||||
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
|
sender_socket_candidates: List[OutputSocket] = (
|
||||||
|
[sender_socket] if sender_socket else list(sender_sockets.values())
|
||||||
|
)
|
||||||
receiver_socket_candidates: List[InputSocket] = (
|
receiver_socket_candidates: List[InputSocket] = (
|
||||||
[receiver_socket] if receiver_socket else list(to_sockets.values())
|
[receiver_socket] if receiver_socket else list(receiver_sockets.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find all possible connections between these two components
|
# Find all possible connections between these two components
|
||||||
possible_connections = [
|
possible_connections = []
|
||||||
(sender_sock, receiver_sock)
|
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
|
||||||
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
|
if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
|
||||||
if _types_are_compatible(sender_sock.type, receiver_sock.type)
|
possible_connections.append((sender_sock, receiver_sock))
|
||||||
]
|
|
||||||
|
|
||||||
# We need this status for error messages, since we might need it in multiple places we calculate it here
|
# We need this status for error messages, since we might need it in multiple places we calculate it here
|
||||||
status = _connections_status(
|
status = _connections_status(
|
||||||
@ -860,7 +879,7 @@ class PipelineBase:
|
|||||||
|
|
||||||
def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
|
def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
|
||||||
"""
|
"""
|
||||||
Utility function to find all Components that receive input form `component_name`.
|
Utility function to find all Components that receive input from `component_name`.
|
||||||
|
|
||||||
:param component_name:
|
:param component_name:
|
||||||
Name of the sender Component
|
Name of the sender Component
|
||||||
@ -1179,7 +1198,7 @@ class PipelineBase:
|
|||||||
|
|
||||||
def _connections_status(
|
def _connections_status(
|
||||||
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
|
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Lists the status of the sockets, for error messages.
|
Lists the status of the sockets, for error messages.
|
||||||
"""
|
"""
|
||||||
|
@ -2,29 +2,42 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Any, Union, get_args, get_origin
|
from typing import Any, TypeVar, Union, get_args, get_origin
|
||||||
|
|
||||||
from haystack import logging
|
from haystack import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
def _is_optional(type_: type) -> bool:
|
|
||||||
"""
|
|
||||||
Utility method that returns whether a type is Optional.
|
|
||||||
"""
|
|
||||||
return get_origin(type_) is Union and type(None) in get_args(type_)
|
|
||||||
|
|
||||||
|
|
||||||
def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
|
def _types_are_compatible(sender, receiver, type_validation: bool = True) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections.
|
Determines if two types are compatible based on the specified validation mode.
|
||||||
|
|
||||||
|
:param sender: The sender type.
|
||||||
|
:param receiver: The receiver type.
|
||||||
|
:param type_validation: Whether to perform strict type validation.
|
||||||
|
:return: True if the types are compatible, False otherwise.
|
||||||
|
"""
|
||||||
|
if type_validation:
|
||||||
|
return _strict_types_are_compatible(sender, receiver)
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
|
||||||
|
"""
|
||||||
|
Checks whether the sender type is equal to or a subtype of the receiver type under strict validation.
|
||||||
|
|
||||||
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
|
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
|
||||||
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
|
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
|
||||||
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
|
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
|
||||||
|
|
||||||
Consider simplifying the typing of your components if you observe unexpected errors during component connection.
|
Consider simplifying the typing of your components if you observe unexpected errors during component connection.
|
||||||
|
|
||||||
|
:param sender: The sender type.
|
||||||
|
:param receiver: The receiver type.
|
||||||
|
:return: True if the sender type is strictly compatible with the receiver type, False otherwise.
|
||||||
"""
|
"""
|
||||||
if sender == receiver or receiver is Any:
|
if sender == receiver or receiver is Any:
|
||||||
return True
|
return True
|
||||||
@ -42,17 +55,19 @@ def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-
|
|||||||
receiver_origin = get_origin(receiver)
|
receiver_origin = get_origin(receiver)
|
||||||
|
|
||||||
if sender_origin is not Union and receiver_origin is Union:
|
if sender_origin is not Union and receiver_origin is Union:
|
||||||
return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
|
return any(_strict_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
|
||||||
|
|
||||||
if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
|
# Both must have origins and they must be equal
|
||||||
|
if not (sender_origin and receiver_origin and sender_origin == receiver_origin):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Compare generic type arguments
|
||||||
sender_args = get_args(sender)
|
sender_args = get_args(sender)
|
||||||
receiver_args = get_args(receiver)
|
receiver_args = get_args(receiver)
|
||||||
if len(sender_args) > len(receiver_args):
|
if len(sender_args) > len(receiver_args):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
|
return all(_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
|
||||||
|
|
||||||
|
|
||||||
def _type_name(type_):
|
def _type_name(type_):
|
||||||
|
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
We've introduced a new type_validation parameter to control type compatibility checks in pipeline connections.
|
||||||
|
It can be set to True (default) or False which means no type checks will be done and everything is allowed.
|
@ -151,6 +151,7 @@ class TestOpenAPIConnector:
|
|||||||
assert pipeline_dict == {
|
assert pipeline_dict == {
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"max_runs_per_component": 100,
|
"max_runs_per_component": 100,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"api": {
|
"api": {
|
||||||
"type": "haystack.components.connectors.openapi.OpenAPIConnector",
|
"type": "haystack.components.connectors.openapi.OpenAPIConnector",
|
||||||
|
@ -218,6 +218,7 @@ class TestOpenAPIServiceConnector:
|
|||||||
assert pipeline_dict == {
|
assert pipeline_dict == {
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"max_runs_per_component": 100,
|
"max_runs_per_component": 100,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"connector": {
|
"connector": {
|
||||||
"type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector",
|
"type": "haystack.components.connectors.openapi_service.OpenAPIServiceConnector",
|
||||||
|
@ -171,6 +171,7 @@ class TestAzureOpenAIChatGenerator:
|
|||||||
assert p.to_dict() == {
|
assert p.to_dict() == {
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"max_runs_per_component": 100,
|
"max_runs_per_component": 100,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"generator": {
|
"generator": {
|
||||||
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
|
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
|
||||||
|
@ -270,6 +270,7 @@ class TestHuggingFaceAPIChatGenerator:
|
|||||||
assert pipeline_dict == {
|
assert pipeline_dict == {
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"max_runs_per_component": 100,
|
"max_runs_per_component": 100,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"generator": {
|
"generator": {
|
||||||
"type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
|
"type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
|
||||||
|
@ -354,6 +354,7 @@ class TestFileTypeRouter:
|
|||||||
assert pipeline_dict == {
|
assert pipeline_dict == {
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"max_runs_per_component": 100,
|
"max_runs_per_component": 100,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"file_type_router": {
|
"file_type_router": {
|
||||||
"type": "haystack.components.routers.file_type_router.FileTypeRouter",
|
"type": "haystack.components.routers.file_type_router.FileTypeRouter",
|
||||||
|
@ -232,6 +232,7 @@ class TestToolInvoker:
|
|||||||
assert pipeline_dict == {
|
assert pipeline_dict == {
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"max_runs_per_component": 100,
|
"max_runs_per_component": 100,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"invoker": {
|
"invoker": {
|
||||||
"type": "haystack.components.tools.tool_invoker.ToolInvoker",
|
"type": "haystack.components.tools.tool_invoker.ToolInvoker",
|
||||||
|
@ -308,6 +308,7 @@ class TestPipelineBase:
|
|||||||
expected = {
|
expected = {
|
||||||
"metadata": {"test": "test"},
|
"metadata": {"test": "test"},
|
||||||
"max_runs_per_component": 42,
|
"max_runs_per_component": 42,
|
||||||
|
"connection_type_validation": True,
|
||||||
"components": {
|
"components": {
|
||||||
"add_two": {
|
"add_two": {
|
||||||
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
"type": "haystack.testing.sample_components.add_value.AddFixedValue",
|
||||||
|
@ -76,7 +76,7 @@ def generate_symmetric_cases():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def generate_asymmetric_cases():
|
def generate_strict_asymmetric_cases():
|
||||||
"""Generate asymmetric test cases with different sender and receiver types."""
|
"""Generate asymmetric test cases with different sender and receiver types."""
|
||||||
cases = []
|
cases = []
|
||||||
|
|
||||||
@ -194,6 +194,11 @@ def generate_asymmetric_cases():
|
|||||||
(Dict[(str, Mapping[(Any, Dict[(Any, Any)])])]),
|
(Dict[(str, Mapping[(Any, Dict[(Any, Any)])])]),
|
||||||
id="nested-mapping-of-classes-to-nested-mapping-of-any-keys-and-values",
|
id="nested-mapping-of-classes-to-nested-mapping-of-any-keys-and-values",
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
(Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]),
|
||||||
|
(Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]),
|
||||||
|
id="deeply-nested-complex-type",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -202,7 +207,7 @@ def generate_asymmetric_cases():
|
|||||||
|
|
||||||
# Precompute test cases for reuse
|
# Precompute test cases for reuse
|
||||||
symmetric_cases = generate_symmetric_cases()
|
symmetric_cases = generate_symmetric_cases()
|
||||||
asymmetric_cases = generate_asymmetric_cases()
|
asymmetric_cases = generate_strict_asymmetric_cases()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -261,25 +266,28 @@ def test_type_name(type_, repr_):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sender_type, receiver_type", symmetric_cases)
|
@pytest.mark.parametrize("sender_type, receiver_type", symmetric_cases)
|
||||||
def test_same_types_are_compatible(sender_type, receiver_type):
|
def test_same_types_are_compatible_strict(sender_type, receiver_type):
|
||||||
assert _types_are_compatible(sender_type, receiver_type)
|
assert _types_are_compatible(sender_type, receiver_type, "strict")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
|
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
|
||||||
def test_asymmetric_types_are_compatible(sender_type, receiver_type):
|
def test_asymmetric_types_are_compatible_strict(sender_type, receiver_type):
|
||||||
assert _types_are_compatible(sender_type, receiver_type)
|
assert _types_are_compatible(sender_type, receiver_type, "strict")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
|
@pytest.mark.parametrize("sender_type, receiver_type", asymmetric_cases)
|
||||||
def test_asymmetric_types_are_not_compatible(sender_type, receiver_type):
|
def test_asymmetric_types_are_not_compatible_strict(sender_type, receiver_type):
|
||||||
assert not _types_are_compatible(receiver_type, sender_type)
|
assert not _types_are_compatible(receiver_type, sender_type, "strict")
|
||||||
|
|
||||||
|
|
||||||
incompatible_type_cases = [
|
incompatible_type_cases = [
|
||||||
|
pytest.param(Tuple[int, str], Tuple[Any], id="tuple-of-primitive-to-tuple-of-any-different-lengths"),
|
||||||
pytest.param(int, str, id="different-primitives"),
|
pytest.param(int, str, id="different-primitives"),
|
||||||
pytest.param(Class1, Class2, id="different-classes"),
|
pytest.param(Class1, Class2, id="different-classes"),
|
||||||
pytest.param((List[int]), (List[str]), id="different-lists-of-primitives"),
|
pytest.param((List[int]), (List[str]), id="different-lists-of-primitives"),
|
||||||
pytest.param((List[Class1]), (List[Class2]), id="different-lists-of-classes"),
|
pytest.param((List[Class1]), (List[Class2]), id="different-lists-of-classes"),
|
||||||
|
pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"),
|
||||||
|
pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
(List[Set[Sequence[str]]]), (List[Set[Sequence[bool]]]), id="nested-sequences-of-different-primitives"
|
(List[Set[Sequence[str]]]), (List[Set[Sequence[bool]]]), id="nested-sequences-of-different-primitives"
|
||||||
),
|
),
|
||||||
@ -327,19 +335,11 @@ incompatible_type_cases = [
|
|||||||
(Dict[(str, Mapping[(str, Dict[(str, Class2)])])]),
|
(Dict[(str, Mapping[(str, Dict[(str, Class2)])])]),
|
||||||
id="same-nested-mappings-of-class-to-subclass-values",
|
id="same-nested-mappings-of-class-to-subclass-values",
|
||||||
),
|
),
|
||||||
pytest.param((Literal["a", "b", "c"]), (Literal["x", "y"]), id="different-literal-of-same-primitive"),
|
|
||||||
pytest.param((Literal[Enum1.TEST1]), (Literal[Enum1.TEST2]), id="different-literal-of-same-enum"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sender_type, receiver_type", incompatible_type_cases)
|
@pytest.mark.parametrize("sender_type, receiver_type", incompatible_type_cases)
|
||||||
def test_types_are_always_not_compatible(sender_type, receiver_type):
|
def test_types_are_always_not_compatible_strict(sender_type, receiver_type):
|
||||||
assert not _types_are_compatible(sender_type, receiver_type)
|
|
||||||
|
|
||||||
|
|
||||||
def test_deeply_nested_type_is_compatible_but_cannot_be_checked():
|
|
||||||
sender_type = Tuple[Optional[Literal["a", "b", "c"]], Union[(Path, Dict[(int, Class1)])]]
|
|
||||||
receiver_type = Tuple[Literal["a", "b", "c"], Union[(Path, Dict[(int, Class1)])]]
|
|
||||||
assert not _types_are_compatible(sender_type, receiver_type)
|
assert not _types_are_compatible(sender_type, receiver_type)
|
||||||
|
|
||||||
|
|
||||||
@ -350,7 +350,7 @@ def test_deeply_nested_type_is_compatible_but_cannot_be_checked():
|
|||||||
pytest.param((Union[(int, Class1)]), (Union[(int, Class2)]), id="partially-overlapping-unions-with-classes"),
|
pytest.param((Union[(int, Class1)]), (Union[(int, Class2)]), id="partially-overlapping-unions-with-classes"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_partially_overlapping_unions(sender_type, receiver_type):
|
def test_partially_overlapping_unions_are_not_compatible_strict(sender_type, receiver_type):
|
||||||
assert not _types_are_compatible(sender_type, receiver_type)
|
assert not _types_are_compatible(sender_type, receiver_type)
|
||||||
|
|
||||||
|
|
||||||
@ -362,4 +362,5 @@ def test_partially_overlapping_unions(sender_type, receiver_type):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_list_of_primitive_to_list(sender_type, receiver_type):
|
def test_list_of_primitive_to_list(sender_type, receiver_type):
|
||||||
|
"""This currently doesn't work because we don't handle bare types without arguments."""
|
||||||
assert not _types_are_compatible(sender_type, receiver_type)
|
assert not _types_are_compatible(sender_type, receiver_type)
|
||||||
|
@ -7,6 +7,7 @@ components:
|
|||||||
init_parameters:
|
init_parameters:
|
||||||
an_init_param: null
|
an_init_param: null
|
||||||
type: test.core.pipeline.test_pipeline_base.FakeComponent
|
type: test.core.pipeline.test_pipeline_base.FakeComponent
|
||||||
|
connection_type_validation: true
|
||||||
connections:
|
connections:
|
||||||
- receiver: Comp2.input_
|
- receiver: Comp2.input_
|
||||||
sender: Comp1.value
|
sender: Comp1.value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user