mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 18:06:17 +00:00
feat: Change Component's I/O dunder type (#6916)
* Add Pipeline.get_component_name() method * Add utility class to ease discoverability of Component I/O * Move InputOutput in component package * Rename InputOutput to _InputOutput * Raise if inputs or outputs field already exists * Fix tests * Add release notes * Move InputSocket and OutputSocket in types package * Move _InputOutput in socket package * Rename _InputOutput class to Sockets * Simplify Sockets class * Ditch I/O dunder fields in favour of inputs and outputs fields * Update Sockets docstrings * Update release notes * Fix mypy * Remove unnecessary assignment * Remove unused logging * Change SocketsType to SocketsIOType to avoid confusion * Change sockets type and name * Change Sockets.__repr__ to return component instance * Fix linting * Fix sockets tests * Revert to dunder fields for Component IO * Use singular in IO dunder fields * Delete release notes * Update haystack/core/component/types.py Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> --------- Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
parent
3bd6ba93ca
commit
0191b1e6e4
@ -1,7 +1,7 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from haystack.core.component.component import component, Component
|
from haystack.core.component.component import Component, component
|
||||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
from haystack.core.component.types import InputSocket, OutputSocket
|
||||||
|
|
||||||
__all__ = ["component", "Component", "InputSocket", "OutputSocket"]
|
__all__ = ["component", "Component", "InputSocket", "OutputSocket"]
|
||||||
|
@ -74,9 +74,11 @@ from copy import deepcopy
|
|||||||
from types import new_class
|
from types import new_class
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from haystack.core.component.sockets import InputSocket, OutputSocket, _empty
|
|
||||||
from haystack.core.errors import ComponentError
|
from haystack.core.errors import ComponentError
|
||||||
|
|
||||||
|
from .sockets import Sockets
|
||||||
|
from .types import InputSocket, OutputSocket, _empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -131,12 +133,14 @@ class ComponentMeta(type):
|
|||||||
# that stores the output specification.
|
# that stores the output specification.
|
||||||
# We deepcopy the content of the cache to transfer ownership from the class method
|
# We deepcopy the content of the cache to transfer ownership from the class method
|
||||||
# to the actual instance, so that different instances of the same class won't share this data.
|
# to the actual instance, so that different instances of the same class won't share this data.
|
||||||
instance.__haystack_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {}))
|
instance.__haystack_output__ = Sockets(
|
||||||
|
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
|
||||||
|
)
|
||||||
|
|
||||||
# Create the sockets if set_input_types() wasn't called in the constructor.
|
# Create the sockets if set_input_types() wasn't called in the constructor.
|
||||||
# If it was called and there are some parameters also in the `run()` method, these take precedence.
|
# If it was called and there are some parameters also in the `run()` method, these take precedence.
|
||||||
if not hasattr(instance, "__haystack_input__"):
|
if not hasattr(instance, "__haystack_input__"):
|
||||||
instance.__haystack_input__ = {}
|
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
|
||||||
run_signature = inspect.signature(getattr(cls, "run"))
|
run_signature = inspect.signature(getattr(cls, "run"))
|
||||||
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
|
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
|
||||||
if run_signature.parameters[param].kind not in (
|
if run_signature.parameters[param].kind not in (
|
||||||
@ -185,7 +189,7 @@ class _Component:
|
|||||||
:param default: default value of the input socket, defaults to _empty
|
:param default: default value of the input socket, defaults to _empty
|
||||||
"""
|
"""
|
||||||
if not hasattr(instance, "__haystack_input__"):
|
if not hasattr(instance, "__haystack_input__"):
|
||||||
instance.__haystack_input__ = {}
|
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
|
||||||
instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)
|
instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)
|
||||||
|
|
||||||
def set_input_types(self, instance, **types):
|
def set_input_types(self, instance, **types):
|
||||||
@ -229,7 +233,9 @@ class _Component:
|
|||||||
parameter mandatory as specified in `set_input_types`.
|
parameter mandatory as specified in `set_input_types`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
instance.__haystack_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}
|
instance.__haystack_input__ = Sockets(
|
||||||
|
instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
|
||||||
|
)
|
||||||
|
|
||||||
def set_output_types(self, instance, **types):
|
def set_output_types(self, instance, **types):
|
||||||
"""
|
"""
|
||||||
@ -251,7 +257,9 @@ class _Component:
|
|||||||
return {"output_1": 1, "output_2": "2"}
|
return {"output_1": 1, "output_2": "2"}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
instance.__haystack_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}
|
instance.__haystack_output__ = Sockets(
|
||||||
|
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
|
||||||
|
)
|
||||||
|
|
||||||
def output_types(self, **types):
|
def output_types(self, **types):
|
||||||
"""
|
"""
|
||||||
|
@ -1,57 +1,107 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, List, Type, get_args
|
|
||||||
|
|
||||||
from haystack.core.component.types import HAYSTACK_VARIADIC_ANNOTATION
|
import logging
|
||||||
|
from typing import Dict, Type, Union
|
||||||
|
|
||||||
|
from haystack.core.type_utils import _type_name
|
||||||
|
|
||||||
|
from .types import InputSocket, OutputSocket
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SocketsDict = Dict[str, Union[InputSocket, OutputSocket]]
|
||||||
class _empty:
|
SocketsIOType = Union[Type[InputSocket], Type[OutputSocket]]
|
||||||
"""Custom object for marking InputSocket.default_value as not set."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class Sockets:
|
||||||
class InputSocket:
|
"""
|
||||||
name: str
|
This class is used to represent the inputs or outputs of a `Component`.
|
||||||
type: Type
|
Depending on the type passed to the constructor, it will represent either the inputs or the outputs of
|
||||||
default_value: Any = _empty
|
the `Component`.
|
||||||
is_variadic: bool = field(init=False)
|
|
||||||
senders: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
Usage:
|
||||||
def is_mandatory(self):
|
```python
|
||||||
return self.default_value == _empty
|
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||||
|
|
||||||
def __post_init__(self):
|
prompt_template = \"""
|
||||||
|
Given these documents, answer the question.\nDocuments:
|
||||||
|
{% for doc in documents %}
|
||||||
|
{{ doc.content }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
\nQuestion: {{question}}
|
||||||
|
\nAnswer:
|
||||||
|
\"""
|
||||||
|
|
||||||
|
prompt_builder = PromptBuilder(template=prompt_template)
|
||||||
|
sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)}
|
||||||
|
inputs = Sockets(component=prompt_builder, sockets_dict=sockets, sockets_io_type=InputSocket)
|
||||||
|
inputs
|
||||||
|
>>> PromptBuilder inputs:
|
||||||
|
>>> - question: Any
|
||||||
|
>>> - documents: Any
|
||||||
|
|
||||||
|
inputs.question
|
||||||
|
>>> InputSocket(name='question', type=typing.Any, default_value=<class 'haystack.core.component.types._empty'>, is_variadic=False, senders=[])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We're using a forward declaration here to avoid a circular import.
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
component: "Component", # type: ignore[name-defined] # noqa: F821
|
||||||
|
sockets_dict: SocketsDict,
|
||||||
|
sockets_io_type: SocketsIOType,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new Sockets object.
|
||||||
|
We don't do any enforcement on the types of the sockets here, the `sockets_io_type` is only used for
|
||||||
|
the `__repr__` method.
|
||||||
|
We could do without it and use the type of a random value in the `sockets_dict` dict, but that wouldn't
|
||||||
|
work for components that have no sockets at all. Either input or output.
|
||||||
|
"""
|
||||||
|
self._sockets_io_type = sockets_io_type
|
||||||
|
self._component = component
|
||||||
|
self._sockets_dict = sockets_dict
|
||||||
|
self.__dict__.update(sockets_dict)
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
|
||||||
|
"""
|
||||||
|
Adds a new socket to this Sockets object.
|
||||||
|
This eases a bit updating the list of sockets after Sockets has been created.
|
||||||
|
That should happen only in the `component` decorator.
|
||||||
|
"""
|
||||||
|
self._sockets_dict[key] = socket
|
||||||
|
self.__dict__[key] = socket
|
||||||
|
|
||||||
|
def _component_name(self) -> str:
|
||||||
|
if pipeline := getattr(self._component, "__haystack_added_to_pipeline__"):
|
||||||
|
# This Component has been added in a Pipeline, let's get the name from there.
|
||||||
|
return pipeline.get_component_name(self._component)
|
||||||
|
|
||||||
|
# This Component has not been added to a Pipeline yet, so we can't know its name.
|
||||||
|
# Let's use the string representation of the component instance instead.
|
||||||
|
return str(self._component)
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
try:
|
try:
|
||||||
# __metadata__ is a tuple
|
sockets = object.__getattribute__(self, "_sockets")
|
||||||
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
|
if name in sockets:
|
||||||
|
return sockets[name]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self.is_variadic = False
|
pass
|
||||||
if self.is_variadic:
|
|
||||||
# We need to "unpack" the type inside the Variadic annotation,
|
|
||||||
# otherwise the pipeline connection api will try to match
|
|
||||||
# `Annotated[type, CANALS_VARIADIC_ANNOTATION]`.
|
|
||||||
#
|
|
||||||
# Note1: Variadic is expressed as an annotation of one single type,
|
|
||||||
# so the return value of get_args will always be a one-item tuple.
|
|
||||||
#
|
|
||||||
# Note2: a pipeline always passes a list of items when a component
|
|
||||||
# input is declared as Variadic, so the type itself always wraps
|
|
||||||
# an iterable of the declared type. For example, Variadic[int]
|
|
||||||
# is eventually an alias for Iterable[int]. Since we're interested
|
|
||||||
# in getting the inner type `int`, we call `get_args` twice: the
|
|
||||||
# first time to get `List[int]` out of `Variadic`, the second time
|
|
||||||
# to get `int` out of `List[int]`.
|
|
||||||
self.type = get_args(get_args(self.type)[0])[0]
|
|
||||||
|
|
||||||
|
return object.__getattribute__(self, name)
|
||||||
|
|
||||||
@dataclass
|
def __repr__(self) -> str:
|
||||||
class OutputSocket:
|
result = self._component_name()
|
||||||
name: str
|
if self._sockets_io_type == InputSocket:
|
||||||
type: type
|
result += " inputs:\n"
|
||||||
receivers: List[str] = field(default_factory=list)
|
elif self._sockets_io_type == OutputSocket:
|
||||||
|
result += " outputs:\n"
|
||||||
|
|
||||||
|
result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])
|
||||||
|
|
||||||
|
return result
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Iterable, TypeVar
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Iterable, List, Type, TypeVar, get_args
|
||||||
|
|
||||||
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
|
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
|
||||||
|
|
||||||
@ -13,3 +14,50 @@ T = TypeVar("T")
|
|||||||
# type so it can be used in the `InputSocket` creation where we
|
# type so it can be used in the `InputSocket` creation where we
|
||||||
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION
|
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION
|
||||||
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
|
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
|
||||||
|
|
||||||
|
|
||||||
|
class _empty:
|
||||||
|
"""Custom object for marking InputSocket.default_value as not set."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InputSocket:
|
||||||
|
name: str
|
||||||
|
type: Type
|
||||||
|
default_value: Any = _empty
|
||||||
|
is_variadic: bool = field(init=False)
|
||||||
|
senders: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mandatory(self):
|
||||||
|
return self.default_value == _empty
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
try:
|
||||||
|
# __metadata__ is a tuple
|
||||||
|
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
|
||||||
|
except AttributeError:
|
||||||
|
self.is_variadic = False
|
||||||
|
if self.is_variadic:
|
||||||
|
# We need to "unpack" the type inside the Variadic annotation,
|
||||||
|
# otherwise the pipeline connection api will try to match
|
||||||
|
# `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`.
|
||||||
|
#
|
||||||
|
# Note1: Variadic is expressed as an annotation of one single type,
|
||||||
|
# so the return value of get_args will always be a one-item tuple.
|
||||||
|
#
|
||||||
|
# Note2: a pipeline always passes a list of items when a component
|
||||||
|
# input is declared as Variadic, so the type itself always wraps
|
||||||
|
# an iterable of the declared type. For example, Variadic[int]
|
||||||
|
# is eventually an alias for Iterable[int]. Since we're interested
|
||||||
|
# in getting the inner type `int`, we call `get_args` twice: the
|
||||||
|
# first time to get `List[int]` out of `Variadic`, the second time
|
||||||
|
# to get `int` out of `List[int]`.
|
||||||
|
self.type = get_args(get_args(self.type)[0])[0]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OutputSocket:
|
||||||
|
name: str
|
||||||
|
type: type
|
||||||
|
receivers: List[str] = field(default_factory=list)
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import List, Dict
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import networkx # type:ignore
|
import networkx # type:ignore
|
||||||
|
|
||||||
|
from haystack.core.component.types import InputSocket, OutputSocket
|
||||||
from haystack.core.type_utils import _type_name
|
from haystack.core.type_utils import _type_name
|
||||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -197,16 +197,17 @@ class Pipeline:
|
|||||||
)
|
)
|
||||||
raise PipelineError(msg)
|
raise PipelineError(msg)
|
||||||
|
|
||||||
# Create the component's input and output sockets
|
|
||||||
input_sockets = getattr(instance, "__haystack_input__", {})
|
|
||||||
output_sockets = getattr(instance, "__haystack_output__", {})
|
|
||||||
|
|
||||||
setattr(instance, "__haystack_added_to_pipeline__", self)
|
setattr(instance, "__haystack_added_to_pipeline__", self)
|
||||||
|
|
||||||
# Add component to the graph, disconnected
|
# Add component to the graph, disconnected
|
||||||
logger.debug("Adding component '%s' (%s)", name, instance)
|
logger.debug("Adding component '%s' (%s)", name, instance)
|
||||||
|
# We're completely sure the fields exist so we ignore the type error
|
||||||
self.graph.add_node(
|
self.graph.add_node(
|
||||||
name, instance=instance, input_sockets=input_sockets, output_sockets=output_sockets, visits=0
|
name,
|
||||||
|
instance=instance,
|
||||||
|
input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined]
|
||||||
|
output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined]
|
||||||
|
visits=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def connect(self, connect_from: str, connect_to: str) -> None:
|
def connect(self, connect_from: str, connect_to: str) -> None:
|
||||||
@ -381,6 +382,16 @@ class Pipeline:
|
|||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
|
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
|
||||||
|
|
||||||
|
def get_component_name(self, instance: Component) -> str:
|
||||||
|
"""
|
||||||
|
Returns the name of a Component instance. If the Component has not been added to this Pipeline,
|
||||||
|
returns an empty string.
|
||||||
|
"""
|
||||||
|
for name, inst in self.graph.nodes(data="instance"):
|
||||||
|
if inst == instance:
|
||||||
|
return name
|
||||||
|
return ""
|
||||||
|
|
||||||
def inputs(self) -> Dict[str, Dict[str, Any]]:
|
def inputs(self) -> Dict[str, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary
|
Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary
|
||||||
@ -465,16 +476,16 @@ class Pipeline:
|
|||||||
if component_name not in self.graph.nodes:
|
if component_name not in self.graph.nodes:
|
||||||
raise ValueError(f"Component named {component_name} not found in the pipeline.")
|
raise ValueError(f"Component named {component_name} not found in the pipeline.")
|
||||||
instance = self.graph.nodes[component_name]["instance"]
|
instance = self.graph.nodes[component_name]["instance"]
|
||||||
for socket_name, socket in instance.__haystack_input__.items():
|
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
|
||||||
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
||||||
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
||||||
for input_name in component_inputs.keys():
|
for input_name in component_inputs.keys():
|
||||||
if input_name not in instance.__haystack_input__:
|
if input_name not in instance.__haystack_input__._sockets_dict:
|
||||||
raise ValueError(f"Input {input_name} not found in component {component_name}.")
|
raise ValueError(f"Input {input_name} not found in component {component_name}.")
|
||||||
|
|
||||||
for component_name in self.graph.nodes:
|
for component_name in self.graph.nodes:
|
||||||
instance = self.graph.nodes[component_name]["instance"]
|
instance = self.graph.nodes[component_name]["instance"]
|
||||||
for socket_name, socket in instance.__haystack_input__.items():
|
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
|
||||||
component_inputs = data.get(component_name, {})
|
component_inputs = data.get(component_name, {})
|
||||||
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
||||||
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
||||||
@ -518,7 +529,7 @@ class Pipeline:
|
|||||||
for component_input, input_value in component_inputs.items():
|
for component_input, input_value in component_inputs.items():
|
||||||
# Handle mutable input data
|
# Handle mutable input data
|
||||||
data[component_name][component_input] = copy(input_value)
|
data[component_name][component_input] = copy(input_value)
|
||||||
if instance.__haystack_input__[component_input].is_variadic:
|
if instance.__haystack_input__._sockets_dict[component_input].is_variadic:
|
||||||
# Components that have variadic inputs need to receive lists as input.
|
# Components that have variadic inputs need to receive lists as input.
|
||||||
# We don't want to force the user to always pass lists, so we convert single values to lists here.
|
# We don't want to force the user to always pass lists, so we convert single values to lists here.
|
||||||
# If it's already a list we assume the component takes a variadic input of lists, so we
|
# If it's already a list we assume the component takes a variadic input of lists, so we
|
||||||
@ -533,12 +544,12 @@ class Pipeline:
|
|||||||
for node_name in self.graph.nodes:
|
for node_name in self.graph.nodes:
|
||||||
component = self.graph.nodes[node_name]["instance"]
|
component = self.graph.nodes[node_name]["instance"]
|
||||||
|
|
||||||
if len(component.__haystack_input__) == 0:
|
if len(component.__haystack_input__._sockets_dict) == 0:
|
||||||
# Component has no input, can run right away
|
# Component has no input, can run right away
|
||||||
to_run.append((node_name, component))
|
to_run.append((node_name, component))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for socket in component.__haystack_input__.values():
|
for socket in component.__haystack_input__._sockets_dict.values():
|
||||||
if not socket.senders or socket.is_variadic:
|
if not socket.senders or socket.is_variadic:
|
||||||
# Component has at least one input not connected or is variadic, can run right away.
|
# Component has at least one input not connected or is variadic, can run right away.
|
||||||
to_run.append((node_name, component))
|
to_run.append((node_name, component))
|
||||||
@ -561,12 +572,12 @@ class Pipeline:
|
|||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
name, comp = to_run.pop(0)
|
name, comp = to_run.pop(0)
|
||||||
|
|
||||||
if any(socket.is_variadic for socket in comp.__haystack_input__.values()) and not getattr( # type: ignore
|
if any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) and not getattr( # type: ignore
|
||||||
comp, "is_greedy", False
|
comp, "is_greedy", False
|
||||||
):
|
):
|
||||||
there_are_non_variadics = False
|
there_are_non_variadics = False
|
||||||
for _, other_comp in to_run:
|
for _, other_comp in to_run:
|
||||||
if not any(socket.is_variadic for socket in other_comp.__haystack_input__.values()): # type: ignore
|
if not any(socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values()): # type: ignore
|
||||||
there_are_non_variadics = True
|
there_are_non_variadics = True
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -575,7 +586,7 @@ class Pipeline:
|
|||||||
waiting_for_input.append((name, comp))
|
waiting_for_input.append((name, comp))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if name in last_inputs and len(comp.__haystack_input__) == len(last_inputs[name]): # type: ignore
|
if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore
|
||||||
# This component has all the inputs it needs to run
|
# This component has all the inputs it needs to run
|
||||||
res = comp.run(**last_inputs[name])
|
res = comp.run(**last_inputs[name])
|
||||||
|
|
||||||
@ -649,7 +660,7 @@ class Pipeline:
|
|||||||
# This is our last resort, if there's no lazy variadic waiting for input
|
# This is our last resort, if there's no lazy variadic waiting for input
|
||||||
# we're stuck for real and we can't make any progress.
|
# we're stuck for real and we can't make any progress.
|
||||||
for name, comp in waiting_for_input:
|
for name, comp in waiting_for_input:
|
||||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore
|
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
if is_variadic and not getattr(comp, "is_greedy", False):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@ -680,14 +691,14 @@ class Pipeline:
|
|||||||
last_inputs[name] = {}
|
last_inputs[name] = {}
|
||||||
|
|
||||||
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
||||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore
|
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
if is_variadic and not getattr(comp, "is_greedy", False):
|
||||||
there_are_only_lazy_variadics = True
|
there_are_only_lazy_variadics = True
|
||||||
for other_name, other_comp in waiting_for_input:
|
for other_name, other_comp in waiting_for_input:
|
||||||
if name == other_name:
|
if name == other_name:
|
||||||
continue
|
continue
|
||||||
there_are_only_lazy_variadics &= any(
|
there_are_only_lazy_variadics &= any(
|
||||||
socket.is_variadic for socket in other_comp.__haystack_input__.values() # type: ignore
|
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||||
) and not getattr(other_comp, "is_greedy", False)
|
) and not getattr(other_comp, "is_greedy", False)
|
||||||
|
|
||||||
if not there_are_only_lazy_variadics:
|
if not there_are_only_lazy_variadics:
|
||||||
@ -695,7 +706,7 @@ class Pipeline:
|
|||||||
|
|
||||||
# Find the first component that has all the inputs it needs to run
|
# Find the first component that has all the inputs it needs to run
|
||||||
has_enough_inputs = True
|
has_enough_inputs = True
|
||||||
for input_socket in comp.__haystack_input__.values(): # type: ignore
|
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
|
||||||
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
|
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
|
||||||
has_enough_inputs = False
|
has_enough_inputs = False
|
||||||
break
|
break
|
||||||
|
@ -9,11 +9,11 @@ from haystack.core.component import component
|
|||||||
@component
|
@component
|
||||||
class Repeat:
|
class Repeat:
|
||||||
def __init__(self, outputs: List[str]):
|
def __init__(self, outputs: List[str]):
|
||||||
self.outputs = outputs
|
self._outputs = outputs
|
||||||
component.set_output_types(self, **{k: int for k in outputs})
|
component.set_output_types(self, **{k: int for k in outputs})
|
||||||
|
|
||||||
def run(self, value: int):
|
def run(self, value: int):
|
||||||
"""
|
"""
|
||||||
:param value: the value to repeat.
|
:param value: the value to repeat.
|
||||||
"""
|
"""
|
||||||
return {val: value for val in self.outputs}
|
return {val: value for val in self._outputs}
|
||||||
|
@ -15,16 +15,16 @@ class TestDynamicChatPromptBuilder:
|
|||||||
|
|
||||||
# we have inputs that contain: prompt_source, template_variables + runtime_variables
|
# we have inputs that contain: prompt_source, template_variables + runtime_variables
|
||||||
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
|
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
|
||||||
assert set(builder.__haystack_input__.keys()) == expected_keys
|
assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
|
||||||
|
|
||||||
# response is always prompt regardless of chat mode
|
# response is always prompt regardless of chat mode
|
||||||
assert set(builder.__haystack_output__.keys()) == {"prompt"}
|
assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
|
||||||
|
|
||||||
# prompt_source is a list of ChatMessage
|
# prompt_source is a list of ChatMessage
|
||||||
assert builder.__haystack_input__["prompt_source"].type == List[ChatMessage]
|
assert builder.__haystack_input__._sockets_dict["prompt_source"].type == List[ChatMessage]
|
||||||
|
|
||||||
# output is always prompt, but the type is different depending on the chat mode
|
# output is always prompt, but the type is different depending on the chat mode
|
||||||
assert builder.__haystack_output__["prompt"].type == List[ChatMessage]
|
assert builder.__haystack_output__._sockets_dict["prompt"].type == List[ChatMessage]
|
||||||
|
|
||||||
def test_non_empty_chat_messages(self):
|
def test_non_empty_chat_messages(self):
|
||||||
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
|
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
|
||||||
|
@ -16,16 +16,16 @@ class TestDynamicPromptBuilder:
|
|||||||
# regardless of the chat mode
|
# regardless of the chat mode
|
||||||
# we have inputs that contain: prompt_source, template_variables + runtime_variables
|
# we have inputs that contain: prompt_source, template_variables + runtime_variables
|
||||||
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
|
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
|
||||||
assert set(builder.__haystack_input__.keys()) == expected_keys
|
assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
|
||||||
|
|
||||||
# response is always prompt regardless of chat mode
|
# response is always prompt regardless of chat mode
|
||||||
assert set(builder.__haystack_output__.keys()) == {"prompt"}
|
assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
|
||||||
|
|
||||||
# prompt_source is a list of ChatMessage or a string
|
# prompt_source is a list of ChatMessage or a string
|
||||||
assert builder.__haystack_input__["prompt_source"].type == str
|
assert builder.__haystack_input__._sockets_dict["prompt_source"].type == str
|
||||||
|
|
||||||
# output is always prompt, but the type is different depending on the chat mode
|
# output is always prompt, but the type is different depending on the chat mode
|
||||||
assert builder.__haystack_output__["prompt"].type == str
|
assert builder.__haystack_output__._sockets_dict["prompt"].type == str
|
||||||
|
|
||||||
def test_processing_a_simple_template_with_provided_variables(self):
|
def test_processing_a_simple_template_with_provided_variables(self):
|
||||||
runtime_variables = ["var1", "var2", "var3"]
|
runtime_variables = ["var1", "var2", "var3"]
|
||||||
|
@ -86,8 +86,8 @@ class TestRouter:
|
|||||||
router = ConditionalRouter(routes)
|
router = ConditionalRouter(routes)
|
||||||
|
|
||||||
assert router.routes == routes
|
assert router.routes == routes
|
||||||
assert set(router.__haystack_input__.keys()) == {"query", "streams"}
|
assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
|
||||||
assert set(router.__haystack_output__.keys()) == {"query", "streams"}
|
assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}
|
||||||
|
|
||||||
def test_router_evaluate_condition_expressions(self, router):
|
def test_router_evaluate_condition_expressions(self, router):
|
||||||
# first route should be selected
|
# first route should be selected
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import typing
|
from typing import Any
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -89,6 +88,7 @@ def test_missing_run():
|
|||||||
|
|
||||||
|
|
||||||
def test_set_input_types():
|
def test_set_input_types():
|
||||||
|
@component
|
||||||
class MockComponent:
|
class MockComponent:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
component.set_input_types(self, value=Any)
|
component.set_input_types(self, value=Any)
|
||||||
@ -105,7 +105,7 @@ def test_set_input_types():
|
|||||||
return {"value": 1}
|
return {"value": 1}
|
||||||
|
|
||||||
comp = MockComponent()
|
comp = MockComponent()
|
||||||
assert comp.__haystack_input__ == {"value": InputSocket("value", Any)}
|
assert comp.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)}
|
||||||
assert comp.run() == {"value": 1}
|
assert comp.run() == {"value": 1}
|
||||||
|
|
||||||
|
|
||||||
@ -126,7 +126,7 @@ def test_set_output_types():
|
|||||||
return {"value": 1}
|
return {"value": 1}
|
||||||
|
|
||||||
comp = MockComponent()
|
comp = MockComponent()
|
||||||
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)}
|
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
|
||||||
|
|
||||||
|
|
||||||
def test_output_types_decorator_with_compatible_type():
|
def test_output_types_decorator_with_compatible_type():
|
||||||
@ -144,7 +144,7 @@ def test_output_types_decorator_with_compatible_type():
|
|||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
comp = MockComponent()
|
comp = MockComponent()
|
||||||
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)}
|
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
|
||||||
|
|
||||||
|
|
||||||
def test_component_decorator_set_it_as_component():
|
def test_component_decorator_set_it_as_component():
|
||||||
@ -173,8 +173,8 @@ def test_input_has_default_value():
|
|||||||
return {"value": value}
|
return {"value": value}
|
||||||
|
|
||||||
comp = MockComponent()
|
comp = MockComponent()
|
||||||
assert comp.__haystack_input__["value"].default_value == 42
|
assert comp.__haystack_input__._sockets_dict["value"].default_value == 42
|
||||||
assert not comp.__haystack_input__["value"].is_mandatory
|
assert not comp.__haystack_input__._sockets_dict["value"].is_mandatory
|
||||||
|
|
||||||
|
|
||||||
def test_keyword_only_args():
|
def test_keyword_only_args():
|
||||||
@ -187,5 +187,5 @@ def test_keyword_only_args():
|
|||||||
return {"value": arg}
|
return {"value": arg}
|
||||||
|
|
||||||
comp = MockComponent()
|
comp = MockComponent()
|
||||||
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__.items()}
|
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()}
|
||||||
assert component_inputs == {"arg": {"type": int}}
|
assert component_inputs == {"arg": {"type": int}}
|
||||||
|
57
test/core/component/test_sockets.py
Normal file
57
test/core/component/test_sockets.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.core.component.sockets import InputSocket, Sockets
|
||||||
|
from haystack.core.pipeline import Pipeline
|
||||||
|
from haystack.testing.factory import component_class
|
||||||
|
|
||||||
|
|
||||||
|
class TestSockets:
|
||||||
|
def test_init(self):
|
||||||
|
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
|
||||||
|
sockets = {"input_1": InputSocket("input_1", int), "input_2": InputSocket("input_2", int)}
|
||||||
|
io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket)
|
||||||
|
assert io._component == comp
|
||||||
|
assert "input_1" in io.__dict__
|
||||||
|
assert io.__dict__["input_1"] == comp.__haystack_input__._sockets_dict["input_1"]
|
||||||
|
assert "input_2" in io.__dict__
|
||||||
|
assert io.__dict__["input_2"] == comp.__haystack_input__._sockets_dict["input_2"]
|
||||||
|
|
||||||
|
def test_init_with_empty_sockets(self):
|
||||||
|
comp = component_class("SomeComponent")()
|
||||||
|
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
|
||||||
|
|
||||||
|
assert io._component == comp
|
||||||
|
assert io._sockets_dict == {}
|
||||||
|
|
||||||
|
def test_component_name(self):
|
||||||
|
comp = component_class("SomeComponent")()
|
||||||
|
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
|
||||||
|
assert io._component_name() == str(comp)
|
||||||
|
|
||||||
|
def test_component_name_added_to_pipeline(self):
|
||||||
|
comp = component_class("SomeComponent")()
|
||||||
|
pipeline = Pipeline()
|
||||||
|
pipeline.add_component("my_component", comp)
|
||||||
|
|
||||||
|
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
|
||||||
|
assert io._component_name() == "my_component"
|
||||||
|
|
||||||
|
def test_getattribute(self):
|
||||||
|
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
|
||||||
|
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
|
||||||
|
|
||||||
|
assert io.input_1 == comp.__haystack_input__._sockets_dict["input_1"]
|
||||||
|
assert io.input_2 == comp.__haystack_input__._sockets_dict["input_2"]
|
||||||
|
|
||||||
|
def test_getattribute_non_existing_socket(self):
|
||||||
|
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
|
||||||
|
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
io.input_3
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
|
||||||
|
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
|
||||||
|
res = repr(io)
|
||||||
|
assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int"
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
from haystack.core.component.types import InputSocket, OutputSocket
|
||||||
from haystack.core.errors import PipelineError, PipelineRuntimeError
|
from haystack.core.errors import PipelineError, PipelineRuntimeError
|
||||||
from haystack.core.pipeline import Pipeline
|
from haystack.core.pipeline import Pipeline
|
||||||
from haystack.testing.factory import component_class
|
from haystack.testing.factory import component_class
|
||||||
@ -28,6 +28,21 @@ def test_add_component_to_different_pipelines():
|
|||||||
second_pipe.add_component("some", some_component)
|
second_pipe.add_component("some", some_component)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_component_name():
|
||||||
|
pipe = Pipeline()
|
||||||
|
some_component = component_class("Some")()
|
||||||
|
pipe.add_component("some", some_component)
|
||||||
|
|
||||||
|
assert pipe.get_component_name(some_component) == "some"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_component_name_not_added_to_pipeline():
|
||||||
|
pipe = Pipeline()
|
||||||
|
some_component = component_class("Some")()
|
||||||
|
|
||||||
|
assert pipe.get_component_name(some_component) == ""
|
||||||
|
|
||||||
|
|
||||||
def test_run_with_component_that_does_not_return_dict():
|
def test_run_with_component_that_does_not_return_dict():
|
||||||
BrokenComponent = component_class(
|
BrokenComponent = component_class(
|
||||||
"BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore
|
"BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore
|
||||||
|
@ -5,8 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
from haystack.core.component.types import InputSocket, OutputSocket, Variadic
|
||||||
from haystack.core.component.types import Variadic
|
|
||||||
from haystack.core.errors import PipelineValidationError
|
from haystack.core.errors import PipelineValidationError
|
||||||
from haystack.core.pipeline import Pipeline
|
from haystack.core.pipeline import Pipeline
|
||||||
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
||||||
|
Loading…
x
Reference in New Issue
Block a user