feat: Change Component's I/O dunder type (#6916)

* Add Pipeline.get_component_name() method

* Add utility class to ease discoverability of Component I/O

* Move InputOutput in component package

* Rename InputOutput to _InputOutput

* Raise if inputs or outputs field already exist

* Fix tests

* Add release notes

* Move InputSocket and OutputSocket in types package

* Move _InputOutput in socket package

* Rename _InputOutput class to Sockets

* Simplify Sockets class

* Dictch I/O dunder fields in favour of inputs and outputs fields

* Update Sockets docstrings

* Update release notes

* Fix mypy

* Remove unnecessary assignment

* Remove unused logging

* Change SocketsType to SocketsIOType to avoid confusion

* Change sockets type and name

* Change Sockets.__repr__ to return component instance

* Fix linting

* Fix sockets tests

* Revert to dunder fields for Component IO

* Use singular in IO dunder fields

* Delete release notes

* Update haystack/core/component/types.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Silvano Cerza 2024-02-05 17:46:45 +01:00 committed by GitHub
parent 3bd6ba93ca
commit 0191b1e6e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 282 additions and 95 deletions

View File

@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from haystack.core.component.component import component, Component from haystack.core.component.component import Component, component
from haystack.core.component.sockets import InputSocket, OutputSocket from haystack.core.component.types import InputSocket, OutputSocket
__all__ = ["component", "Component", "InputSocket", "OutputSocket"] __all__ = ["component", "Component", "InputSocket", "OutputSocket"]

View File

@ -74,9 +74,11 @@ from copy import deepcopy
from types import new_class from types import new_class
from typing import Any, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
from haystack.core.component.sockets import InputSocket, OutputSocket, _empty
from haystack.core.errors import ComponentError from haystack.core.errors import ComponentError
from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -131,12 +133,14 @@ class ComponentMeta(type):
# that stores the output specification. # that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method # We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data. # to the actual instance, so that different instances of the same class won't share this data.
instance.__haystack_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {})) instance.__haystack_output__ = Sockets(
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
)
# Create the sockets if set_input_types() wasn't called in the constructor. # Create the sockets if set_input_types() wasn't called in the constructor.
# If it was called and there are some parameters also in the `run()` method, these take precedence. # If it was called and there are some parameters also in the `run()` method, these take precedence.
if not hasattr(instance, "__haystack_input__"): if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = {} instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
run_signature = inspect.signature(getattr(cls, "run")) run_signature = inspect.signature(getattr(cls, "run"))
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter. for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
if run_signature.parameters[param].kind not in ( if run_signature.parameters[param].kind not in (
@ -185,7 +189,7 @@ class _Component:
:param default: default value of the input socket, defaults to _empty :param default: default value of the input socket, defaults to _empty
""" """
if not hasattr(instance, "__haystack_input__"): if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = {} instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default) instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)
def set_input_types(self, instance, **types): def set_input_types(self, instance, **types):
@ -229,7 +233,9 @@ class _Component:
parameter mandatory as specified in `set_input_types`. parameter mandatory as specified in `set_input_types`.
""" """
instance.__haystack_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()} instance.__haystack_input__ = Sockets(
instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
)
def set_output_types(self, instance, **types): def set_output_types(self, instance, **types):
""" """
@ -251,7 +257,9 @@ class _Component:
return {"output_1": 1, "output_2": "2"} return {"output_1": 1, "output_2": "2"}
``` ```
""" """
instance.__haystack_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()} instance.__haystack_output__ = Sockets(
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
)
def output_types(self, **types): def output_types(self, **types):
""" """

View File

@ -1,57 +1,107 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass, field
from typing import Any, List, Type, get_args
from haystack.core.component.types import HAYSTACK_VARIADIC_ANNOTATION import logging
from typing import Dict, Type, Union
from haystack.core.type_utils import _type_name
from .types import InputSocket, OutputSocket
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SocketsDict = Dict[str, Union[InputSocket, OutputSocket]]
class _empty: SocketsIOType = Union[Type[InputSocket], Type[OutputSocket]]
"""Custom object for marking InputSocket.default_value as not set."""
@dataclass class Sockets:
class InputSocket: """
name: str This class is used to represent the inputs or outputs of a `Component`.
type: Type Depending on the type passed to the constructor, it will represent either the inputs or the outputs of
default_value: Any = _empty the `Component`.
is_variadic: bool = field(init=False)
senders: List[str] = field(default_factory=list)
@property Usage:
def is_mandatory(self): ```python
return self.default_value == _empty from haystack.components.builders.prompt_builder import PromptBuilder
def __post_init__(self): prompt_template = \"""
Given these documents, answer the question.\nDocuments:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
\nQuestion: {{question}}
\nAnswer:
\"""
prompt_builder = PromptBuilder(template=prompt_template)
sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)}
inputs = Sockets(component=prompt_builder, sockets=sockets, sockets_type=InputSocket)
inputs
>>> PromptBuilder inputs:
>>> - question: Any
>>> - documents: Any
inputs.question
>>> InputSocket(name='question', type=typing.Any, default_value=<class 'haystack.core.component.types._empty'>, is_variadic=False, senders=[])
```
"""
# We're using a forward declaration here to avoid a circular import.
def __init__(
self,
component: "Component", # type: ignore[name-defined] # noqa: F821
sockets_dict: SocketsDict,
sockets_io_type: SocketsIOType,
):
"""
Create a new Sockets object.
We don't do any enforcement on the types of the sockets here, the `sockets_type` is only used for
the `__repr__` method.
We could do without it and use the type of a random value in the `sockets` dict, but that wouldn't
work for components that have no sockets at all. Either input or output.
"""
self._sockets_io_type = sockets_io_type
self._component = component
self._sockets_dict = sockets_dict
self.__dict__.update(sockets_dict)
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
"""
Adds a new socket to this Sockets object.
This eases a bit updating the list of sockets after Sockets has been created.
That should happen only in the `component` decorator.
"""
self._sockets_dict[key] = socket
self.__dict__[key] = socket
def _component_name(self) -> str:
if pipeline := getattr(self._component, "__haystack_added_to_pipeline__"):
# This Component has been added in a Pipeline, let's get the name from there.
return pipeline.get_component_name(self._component)
# This Component has not been added to a Pipeline yet, so we can't know its name.
# Let's use the class name instead.
return str(self._component)
def __getattribute__(self, name):
try: try:
# __metadata__ is a tuple sockets = object.__getattribute__(self, "_sockets")
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION if name in sockets:
return sockets[name]
except AttributeError: except AttributeError:
self.is_variadic = False pass
if self.is_variadic:
# We need to "unpack" the type inside the Variadic annotation,
# otherwise the pipeline connection api will try to match
# `Annotated[type, CANALS_VARIADIC_ANNOTATION]`.
#
# Note1: Variadic is expressed as an annotation of one single type,
# so the return value of get_args will always be a one-item tuple.
#
# Note2: a pipeline always passes a list of items when a component
# input is declared as Variadic, so the type itself always wraps
# an iterable of the declared type. For example, Variadic[int]
# is eventually an alias for Iterable[int]. Since we're interested
# in getting the inner type `int`, we call `get_args` twice: the
# first time to get `List[int]` out of `Variadic`, the second time
# to get `int` out of `List[int]`.
self.type = get_args(get_args(self.type)[0])[0]
return object.__getattribute__(self, name)
@dataclass def __repr__(self) -> str:
class OutputSocket: result = self._component_name()
name: str if self._sockets_io_type == InputSocket:
type: type result += " inputs:\n"
receivers: List[str] = field(default_factory=list) elif self._sockets_io_type == OutputSocket:
result += " outputs:\n"
result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])
return result

View File

@ -1,4 +1,5 @@
from typing import Iterable, TypeVar from dataclasses import dataclass, field
from typing import Any, Iterable, List, Type, TypeVar, get_args
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
@ -13,3 +14,50 @@ T = TypeVar("T")
# type so it can be used in the `InputSocket` creation where we # type so it can be used in the `InputSocket` creation where we
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION # check that its annotation equals to CANALS_VARIADIC_ANNOTATION
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION] Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
class _empty:
"""Custom object for marking InputSocket.default_value as not set."""
@dataclass
class InputSocket:
name: str
type: Type
default_value: Any = _empty
is_variadic: bool = field(init=False)
senders: List[str] = field(default_factory=list)
@property
def is_mandatory(self):
return self.default_value == _empty
def __post_init__(self):
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
except AttributeError:
self.is_variadic = False
if self.is_variadic:
# We need to "unpack" the type inside the Variadic annotation,
# otherwise the pipeline connection api will try to match
# `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`.
#
# Note1: Variadic is expressed as an annotation of one single type,
# so the return value of get_args will always be a one-item tuple.
#
# Note2: a pipeline always passes a list of items when a component
# input is declared as Variadic, so the type itself always wraps
# an iterable of the declared type. For example, Variadic[int]
# is eventually an alias for Iterable[int]. Since we're interested
# in getting the inner type `int`, we call `get_args` twice: the
# first time to get `List[int]` out of `Variadic`, the second time
# to get `int` out of `List[int]`.
self.type = get_args(get_args(self.type)[0])[0]
@dataclass
class OutputSocket:
name: str
type: type
receivers: List[str] = field(default_factory=list)

View File

@ -1,14 +1,13 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Dict
import logging import logging
from typing import Dict, List
import networkx # type:ignore import networkx # type:ignore
from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.type_utils import _type_name from haystack.core.type_utils import _type_name
from haystack.core.component.sockets import InputSocket, OutputSocket
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -197,16 +197,17 @@ class Pipeline:
) )
raise PipelineError(msg) raise PipelineError(msg)
# Create the component's input and output sockets
input_sockets = getattr(instance, "__haystack_input__", {})
output_sockets = getattr(instance, "__haystack_output__", {})
setattr(instance, "__haystack_added_to_pipeline__", self) setattr(instance, "__haystack_added_to_pipeline__", self)
# Add component to the graph, disconnected # Add component to the graph, disconnected
logger.debug("Adding component '%s' (%s)", name, instance) logger.debug("Adding component '%s' (%s)", name, instance)
# We're completely sure the fields exist so we ignore the type error
self.graph.add_node( self.graph.add_node(
name, instance=instance, input_sockets=input_sockets, output_sockets=output_sockets, visits=0 name,
instance=instance,
input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined]
output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined]
visits=0,
) )
def connect(self, connect_from: str, connect_to: str) -> None: def connect(self, connect_from: str, connect_to: str) -> None:
@ -381,6 +382,16 @@ class Pipeline:
except KeyError as exc: except KeyError as exc:
raise ValueError(f"Component named {name} not found in the pipeline.") from exc raise ValueError(f"Component named {name} not found in the pipeline.") from exc
def get_component_name(self, instance: Component) -> str:
"""
Returns the name of a Component instance. If the Component has not been added to this Pipeline,
returns an empty string.
"""
for name, inst in self.graph.nodes(data="instance"):
if inst == instance:
return name
return ""
def inputs(self) -> Dict[str, Dict[str, Any]]: def inputs(self) -> Dict[str, Dict[str, Any]]:
""" """
Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary
@ -465,16 +476,16 @@ class Pipeline:
if component_name not in self.graph.nodes: if component_name not in self.graph.nodes:
raise ValueError(f"Component named {component_name} not found in the pipeline.") raise ValueError(f"Component named {component_name} not found in the pipeline.")
instance = self.graph.nodes[component_name]["instance"] instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__.items(): for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}") raise ValueError(f"Missing input for component {component_name}: {socket_name}")
for input_name in component_inputs.keys(): for input_name in component_inputs.keys():
if input_name not in instance.__haystack_input__: if input_name not in instance.__haystack_input__._sockets_dict:
raise ValueError(f"Input {input_name} not found in component {component_name}.") raise ValueError(f"Input {input_name} not found in component {component_name}.")
for component_name in self.graph.nodes: for component_name in self.graph.nodes:
instance = self.graph.nodes[component_name]["instance"] instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__.items(): for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
component_inputs = data.get(component_name, {}) component_inputs = data.get(component_name, {})
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}") raise ValueError(f"Missing input for component {component_name}: {socket_name}")
@ -518,7 +529,7 @@ class Pipeline:
for component_input, input_value in component_inputs.items(): for component_input, input_value in component_inputs.items():
# Handle mutable input data # Handle mutable input data
data[component_name][component_input] = copy(input_value) data[component_name][component_input] = copy(input_value)
if instance.__haystack_input__[component_input].is_variadic: if instance.__haystack_input__._sockets_dict[component_input].is_variadic:
# Components that have variadic inputs need to receive lists as input. # Components that have variadic inputs need to receive lists as input.
# We don't want to force the user to always pass lists, so we convert single values to lists here. # We don't want to force the user to always pass lists, so we convert single values to lists here.
# If it's already a list we assume the component takes a variadic input of lists, so we # If it's already a list we assume the component takes a variadic input of lists, so we
@ -533,12 +544,12 @@ class Pipeline:
for node_name in self.graph.nodes: for node_name in self.graph.nodes:
component = self.graph.nodes[node_name]["instance"] component = self.graph.nodes[node_name]["instance"]
if len(component.__haystack_input__) == 0: if len(component.__haystack_input__._sockets_dict) == 0:
# Component has no input, can run right away # Component has no input, can run right away
to_run.append((node_name, component)) to_run.append((node_name, component))
continue continue
for socket in component.__haystack_input__.values(): for socket in component.__haystack_input__._sockets_dict.values():
if not socket.senders or socket.is_variadic: if not socket.senders or socket.is_variadic:
# Component has at least one input not connected or is variadic, can run right away. # Component has at least one input not connected or is variadic, can run right away.
to_run.append((node_name, component)) to_run.append((node_name, component))
@ -561,12 +572,12 @@ class Pipeline:
while len(to_run) > 0: while len(to_run) > 0:
name, comp = to_run.pop(0) name, comp = to_run.pop(0)
if any(socket.is_variadic for socket in comp.__haystack_input__.values()) and not getattr( # type: ignore if any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) and not getattr( # type: ignore
comp, "is_greedy", False comp, "is_greedy", False
): ):
there_are_non_variadics = False there_are_non_variadics = False
for _, other_comp in to_run: for _, other_comp in to_run:
if not any(socket.is_variadic for socket in other_comp.__haystack_input__.values()): # type: ignore if not any(socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values()): # type: ignore
there_are_non_variadics = True there_are_non_variadics = True
break break
@ -575,7 +586,7 @@ class Pipeline:
waiting_for_input.append((name, comp)) waiting_for_input.append((name, comp))
continue continue
if name in last_inputs and len(comp.__haystack_input__) == len(last_inputs[name]): # type: ignore if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore
# This component has all the inputs it needs to run # This component has all the inputs it needs to run
res = comp.run(**last_inputs[name]) res = comp.run(**last_inputs[name])
@ -649,7 +660,7 @@ class Pipeline:
# This is our last resort, if there's no lazy variadic waiting for input # This is our last resort, if there's no lazy variadic waiting for input
# we're stuck for real and we can't make any progress. # we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input: for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False): if is_variadic and not getattr(comp, "is_greedy", False):
break break
else: else:
@ -680,14 +691,14 @@ class Pipeline:
last_inputs[name] = {} last_inputs[name] = {}
# Lazy variadics must be removed only if there's nothing else to run at this stage # Lazy variadics must be removed only if there's nothing else to run at this stage
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False): if is_variadic and not getattr(comp, "is_greedy", False):
there_are_only_lazy_variadics = True there_are_only_lazy_variadics = True
for other_name, other_comp in waiting_for_input: for other_name, other_comp in waiting_for_input:
if name == other_name: if name == other_name:
continue continue
there_are_only_lazy_variadics &= any( there_are_only_lazy_variadics &= any(
socket.is_variadic for socket in other_comp.__haystack_input__.values() # type: ignore socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
) and not getattr(other_comp, "is_greedy", False) ) and not getattr(other_comp, "is_greedy", False)
if not there_are_only_lazy_variadics: if not there_are_only_lazy_variadics:
@ -695,7 +706,7 @@ class Pipeline:
# Find the first component that has all the inputs it needs to run # Find the first component that has all the inputs it needs to run
has_enough_inputs = True has_enough_inputs = True
for input_socket in comp.__haystack_input__.values(): # type: ignore for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]: if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
has_enough_inputs = False has_enough_inputs = False
break break

View File

@ -9,11 +9,11 @@ from haystack.core.component import component
@component @component
class Repeat: class Repeat:
def __init__(self, outputs: List[str]): def __init__(self, outputs: List[str]):
self.outputs = outputs self._outputs = outputs
component.set_output_types(self, **{k: int for k in outputs}) component.set_output_types(self, **{k: int for k in outputs})
def run(self, value: int): def run(self, value: int):
""" """
:param value: the value to repeat. :param value: the value to repeat.
""" """
return {val: value for val in self.outputs} return {val: value for val in self._outputs}

View File

@ -15,16 +15,16 @@ class TestDynamicChatPromptBuilder:
# we have inputs that contain: prompt_source, template_variables + runtime_variables # we have inputs that contain: prompt_source, template_variables + runtime_variables
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"]) expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
assert set(builder.__haystack_input__.keys()) == expected_keys assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
# response is always prompt regardless of chat mode # response is always prompt regardless of chat mode
assert set(builder.__haystack_output__.keys()) == {"prompt"} assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
# prompt_source is a list of ChatMessage # prompt_source is a list of ChatMessage
assert builder.__haystack_input__["prompt_source"].type == List[ChatMessage] assert builder.__haystack_input__._sockets_dict["prompt_source"].type == List[ChatMessage]
# output is always prompt, but the type is different depending on the chat mode # output is always prompt, but the type is different depending on the chat mode
assert builder.__haystack_output__["prompt"].type == List[ChatMessage] assert builder.__haystack_output__._sockets_dict["prompt"].type == List[ChatMessage]
def test_non_empty_chat_messages(self): def test_non_empty_chat_messages(self):
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"]) prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])

View File

@ -16,16 +16,16 @@ class TestDynamicPromptBuilder:
# regardless of the chat mode # regardless of the chat mode
# we have inputs that contain: prompt_source, template_variables + runtime_variables # we have inputs that contain: prompt_source, template_variables + runtime_variables
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"]) expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
assert set(builder.__haystack_input__.keys()) == expected_keys assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
# response is always prompt regardless of chat mode # response is always prompt regardless of chat mode
assert set(builder.__haystack_output__.keys()) == {"prompt"} assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
# prompt_source is a list of ChatMessage or a string # prompt_source is a list of ChatMessage or a string
assert builder.__haystack_input__["prompt_source"].type == str assert builder.__haystack_input__._sockets_dict["prompt_source"].type == str
# output is always prompt, but the type is different depending on the chat mode # output is always prompt, but the type is different depending on the chat mode
assert builder.__haystack_output__["prompt"].type == str assert builder.__haystack_output__._sockets_dict["prompt"].type == str
def test_processing_a_simple_template_with_provided_variables(self): def test_processing_a_simple_template_with_provided_variables(self):
runtime_variables = ["var1", "var2", "var3"] runtime_variables = ["var1", "var2", "var3"]

View File

@ -86,8 +86,8 @@ class TestRouter:
router = ConditionalRouter(routes) router = ConditionalRouter(routes)
assert router.routes == routes assert router.routes == routes
assert set(router.__haystack_input__.keys()) == {"query", "streams"} assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
assert set(router.__haystack_output__.keys()) == {"query", "streams"} assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}
def test_router_evaluate_condition_expressions(self, router): def test_router_evaluate_condition_expressions(self, router):
# first route should be selected # first route should be selected

View File

@ -1,5 +1,4 @@
import typing from typing import Any
from typing import Any, Optional
import pytest import pytest
@ -89,6 +88,7 @@ def test_missing_run():
def test_set_input_types(): def test_set_input_types():
@component
class MockComponent: class MockComponent:
def __init__(self): def __init__(self):
component.set_input_types(self, value=Any) component.set_input_types(self, value=Any)
@ -105,7 +105,7 @@ def test_set_input_types():
return {"value": 1} return {"value": 1}
comp = MockComponent() comp = MockComponent()
assert comp.__haystack_input__ == {"value": InputSocket("value", Any)} assert comp.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)}
assert comp.run() == {"value": 1} assert comp.run() == {"value": 1}
@ -126,7 +126,7 @@ def test_set_output_types():
return {"value": 1} return {"value": 1}
comp = MockComponent() comp = MockComponent()
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)} assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
def test_output_types_decorator_with_compatible_type(): def test_output_types_decorator_with_compatible_type():
@ -144,7 +144,7 @@ def test_output_types_decorator_with_compatible_type():
return cls() return cls()
comp = MockComponent() comp = MockComponent()
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)} assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
def test_component_decorator_set_it_as_component(): def test_component_decorator_set_it_as_component():
@ -173,8 +173,8 @@ def test_input_has_default_value():
return {"value": value} return {"value": value}
comp = MockComponent() comp = MockComponent()
assert comp.__haystack_input__["value"].default_value == 42 assert comp.__haystack_input__._sockets_dict["value"].default_value == 42
assert not comp.__haystack_input__["value"].is_mandatory assert not comp.__haystack_input__._sockets_dict["value"].is_mandatory
def test_keyword_only_args(): def test_keyword_only_args():
@ -187,5 +187,5 @@ def test_keyword_only_args():
return {"value": arg} return {"value": arg}
comp = MockComponent() comp = MockComponent()
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__.items()} component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()}
assert component_inputs == {"arg": {"type": int}} assert component_inputs == {"arg": {"type": int}}

View File

@ -0,0 +1,57 @@
import pytest
from haystack.core.component.sockets import InputSocket, Sockets
from haystack.core.pipeline import Pipeline
from haystack.testing.factory import component_class
class TestSockets:
def test_init(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
sockets = {"input_1": InputSocket("input_1", int), "input_2": InputSocket("input_2", int)}
io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket)
assert io._component == comp
assert "input_1" in io.__dict__
assert io.__dict__["input_1"] == comp.__haystack_input__._sockets_dict["input_1"]
assert "input_2" in io.__dict__
assert io.__dict__["input_2"] == comp.__haystack_input__._sockets_dict["input_2"]
def test_init_with_empty_sockets(self):
comp = component_class("SomeComponent")()
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component == comp
assert io._sockets_dict == {}
def test_component_name(self):
comp = component_class("SomeComponent")()
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component_name() == str(comp)
def test_component_name_added_to_pipeline(self):
comp = component_class("SomeComponent")()
pipeline = Pipeline()
pipeline.add_component("my_component", comp)
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component_name() == "my_component"
def test_getattribute(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
assert io.input_1 == comp.__haystack_input__._sockets_dict["input_1"]
assert io.input_2 == comp.__haystack_input__._sockets_dict["input_2"]
def test_getattribute_non_existing_socket(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
with pytest.raises(AttributeError):
io.input_3
def test_repr(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
res = repr(io)
assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int"

View File

@ -6,7 +6,7 @@ from typing import Optional
import pytest import pytest
from haystack.core.component.sockets import InputSocket, OutputSocket from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.errors import PipelineError, PipelineRuntimeError from haystack.core.errors import PipelineError, PipelineRuntimeError
from haystack.core.pipeline import Pipeline from haystack.core.pipeline import Pipeline
from haystack.testing.factory import component_class from haystack.testing.factory import component_class
@ -28,6 +28,21 @@ def test_add_component_to_different_pipelines():
second_pipe.add_component("some", some_component) second_pipe.add_component("some", some_component)
def test_get_component_name():
pipe = Pipeline()
some_component = component_class("Some")()
pipe.add_component("some", some_component)
assert pipe.get_component_name(some_component) == "some"
def test_get_component_name_not_added_to_pipeline():
pipe = Pipeline()
some_component = component_class("Some")()
assert pipe.get_component_name(some_component) == ""
def test_run_with_component_that_does_not_return_dict(): def test_run_with_component_that_does_not_return_dict():
BrokenComponent = component_class( BrokenComponent = component_class(
"BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore "BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore

View File

@ -5,8 +5,7 @@ from typing import Optional
import pytest import pytest
from haystack.core.component.sockets import InputSocket, OutputSocket from haystack.core.component.types import InputSocket, OutputSocket, Variadic
from haystack.core.component.types import Variadic
from haystack.core.errors import PipelineValidationError from haystack.core.errors import PipelineValidationError
from haystack.core.pipeline import Pipeline from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs