mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 01:46:33 +00:00
feat: Change Component's I/O dunder type (#6916)
* Add Pipeline.get_component_name() method * Add utility class to ease discoverability of Component I/O * Move InputOutput in component package * Rename InputOutput to _InputOutput * Raise if inputs or outputs field already exist * Fix tests * Add release notes * Move InputSocket and OutputSocket in types package * Move _InputOutput in socket package * Rename _InputOutput class to Sockets * Simplify Sockets class * Ditch I/O dunder fields in favour of inputs and outputs fields * Update Sockets docstrings * Update release notes * Fix mypy * Remove unnecessary assignment * Remove unused logging * Change SocketsType to SocketsIOType to avoid confusion * Change sockets type and name * Change Sockets.__repr__ to return component instance * Fix linting * Fix sockets tests * Revert to dunder fields for Component IO * Use singular in IO dunder fields * Delete release notes * Update haystack/core/component/types.py Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> --------- Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
parent
3bd6ba93ca
commit
0191b1e6e4
@ -1,7 +1,7 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from haystack.core.component.component import component, Component
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.component.component import Component, component
|
||||
from haystack.core.component.types import InputSocket, OutputSocket
|
||||
|
||||
__all__ = ["component", "Component", "InputSocket", "OutputSocket"]
|
||||
|
@ -74,9 +74,11 @@ from copy import deepcopy
|
||||
from types import new_class
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket, _empty
|
||||
from haystack.core.errors import ComponentError
|
||||
|
||||
from .sockets import Sockets
|
||||
from .types import InputSocket, OutputSocket, _empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -131,12 +133,14 @@ class ComponentMeta(type):
|
||||
# that stores the output specification.
|
||||
# We deepcopy the content of the cache to transfer ownership from the class method
|
||||
# to the actual instance, so that different instances of the same class won't share this data.
|
||||
instance.__haystack_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {}))
|
||||
instance.__haystack_output__ = Sockets(
|
||||
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
|
||||
)
|
||||
|
||||
# Create the sockets if set_input_types() wasn't called in the constructor.
|
||||
# If it was called and there are some parameters also in the `run()` method, these take precedence.
|
||||
if not hasattr(instance, "__haystack_input__"):
|
||||
instance.__haystack_input__ = {}
|
||||
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
|
||||
run_signature = inspect.signature(getattr(cls, "run"))
|
||||
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
|
||||
if run_signature.parameters[param].kind not in (
|
||||
@ -185,7 +189,7 @@ class _Component:
|
||||
:param default: default value of the input socket, defaults to _empty
|
||||
"""
|
||||
if not hasattr(instance, "__haystack_input__"):
|
||||
instance.__haystack_input__ = {}
|
||||
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
|
||||
instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)
|
||||
|
||||
def set_input_types(self, instance, **types):
|
||||
@ -229,7 +233,9 @@ class _Component:
|
||||
parameter mandatory as specified in `set_input_types`.
|
||||
|
||||
"""
|
||||
instance.__haystack_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}
|
||||
instance.__haystack_input__ = Sockets(
|
||||
instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
|
||||
)
|
||||
|
||||
def set_output_types(self, instance, **types):
|
||||
"""
|
||||
@ -251,7 +257,9 @@ class _Component:
|
||||
return {"output_1": 1, "output_2": "2"}
|
||||
```
|
||||
"""
|
||||
instance.__haystack_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}
|
||||
instance.__haystack_output__ = Sockets(
|
||||
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
|
||||
)
|
||||
|
||||
def output_types(self, **types):
|
||||
"""
|
||||
|
@ -1,57 +1,107 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Type, get_args
|
||||
|
||||
from haystack.core.component.types import HAYSTACK_VARIADIC_ANNOTATION
|
||||
import logging
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from haystack.core.type_utils import _type_name
|
||||
|
||||
from .types import InputSocket, OutputSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _empty:
|
||||
"""Custom object for marking InputSocket.default_value as not set."""
|
||||
SocketsDict = Dict[str, Union[InputSocket, OutputSocket]]
|
||||
SocketsIOType = Union[Type[InputSocket], Type[OutputSocket]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputSocket:
|
||||
name: str
|
||||
type: Type
|
||||
default_value: Any = _empty
|
||||
is_variadic: bool = field(init=False)
|
||||
senders: List[str] = field(default_factory=list)
|
||||
class Sockets:
    """
    This class is used to represent the inputs or outputs of a `Component`.
    Depending on the type passed to the constructor, it will represent either the inputs or the outputs of
    the `Component`.

    Each socket can be read as an attribute (`inputs.question`) and new sockets can be added with
    item assignment (`inputs["question"] = InputSocket(...)`).

    Usage:
    ```python
    from haystack.components.builders.prompt_builder import PromptBuilder

    prompt_template = \"""
    Given these documents, answer the question.\nDocuments:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    \nQuestion: {{question}}
    \nAnswer:
    \"""

    prompt_builder = PromptBuilder(template=prompt_template)
    sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)}
    inputs = Sockets(component=prompt_builder, sockets_dict=sockets, sockets_io_type=InputSocket)
    inputs
    >>> PromptBuilder inputs:
    >>> - question: Any
    >>> - documents: Any

    inputs.question
    >>> InputSocket(name='question', type=typing.Any, default_value=<class 'haystack.core.component.types._empty'>, is_variadic=False, senders=[])
    ```
    """

    # We're using a forward declaration here to avoid a circular import.
    def __init__(
        self,
        component: "Component",  # type: ignore[name-defined] # noqa: F821
        sockets_dict: SocketsDict,
        sockets_io_type: SocketsIOType,
    ):
        """
        Create a new Sockets object.

        We don't do any enforcement on the types of the sockets here, the `sockets_io_type` is only used for
        the `__repr__` method.
        We could do without it and use the type of a random value in the `sockets_dict` dict, but that wouldn't
        work for components that have no sockets at all. Either input or output.

        :param component: the component these sockets belong to.
        :param sockets_dict: mapping of socket name to socket. The dict is stored as is, not copied.
        :param sockets_io_type: either `InputSocket` or `OutputSocket`.
        """
        self._sockets_io_type = sockets_io_type
        self._component = component
        self._sockets_dict = sockets_dict
        # Mirror the sockets in the instance dict so they also show up in `vars()` and similar tools.
        self.__dict__.update(sockets_dict)

    def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
        """
        Adds a new socket to this Sockets object.

        This eases a bit updating the list of sockets after Sockets has been created.
        That should happen only in the `component` decorator.
        """
        self._sockets_dict[key] = socket
        self.__dict__[key] = socket

    def _component_name(self) -> str:
        """
        Returns the name the owning component has in its Pipeline, or `str(component)` if it has not been
        added to any Pipeline.
        """
        # Default to None so that objects that never got the attribute don't raise here.
        if pipeline := getattr(self._component, "__haystack_added_to_pipeline__", None):
            # This Component has been added in a Pipeline, let's get the name from there.
            return pipeline.get_component_name(self._component)

        # This Component has not been added to a Pipeline yet, so we can't know its name.
        # Let's use its string representation instead.
        return str(self._component)

    def __getattribute__(self, name):
        """
        Returns the socket called `name` if there is one, otherwise falls back to regular attribute lookup.
        """
        try:
            # The attribute is `_sockets_dict`; looking up any other name would always raise and
            # silently turn this branch into dead code.
            sockets = object.__getattribute__(self, "_sockets_dict")
            if name in sockets:
                return sockets[name]
        except AttributeError:
            # `_sockets_dict` is not set yet, e.g. while the instance is still being created.
            pass

        return object.__getattribute__(self, name)

    def __repr__(self) -> str:
        """
        Returns the component name followed by one line per socket with its name and type.
        """
        result = self._component_name()
        if self._sockets_io_type == InputSocket:
            result += " inputs:\n"
        elif self._sockets_io_type == OutputSocket:
            result += " outputs:\n"

        result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])

        return result
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Iterable, TypeVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable, List, Type, TypeVar, get_args
|
||||
|
||||
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
|
||||
|
||||
@ -13,3 +14,50 @@ T = TypeVar("T")
|
||||
# type so it can be used in the `InputSocket` creation where we
|
||||
# check that its annotation equals HAYSTACK_VARIADIC_ANNOTATION
|
||||
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
|
||||
|
||||
|
||||
class _empty:
    """
    Custom object for marking InputSocket.default_value as not set.

    The class itself is used as the marker, it is never instantiated for that purpose.
    """
|
||||
|
||||
|
||||
@dataclass
class InputSocket:
    """
    Describes one input of a component.

    :param name: name of the socket.
    :param type: type accepted by the socket. If it is annotated as `Variadic`, it is replaced with
        the inner type in `__post_init__`.
    :param default_value: default value of the socket, `_empty` when no default is set.
    :param senders: presumably the names of the components connected to this socket (to be confirmed
        against the Pipeline connection code); starts empty.
    """

    name: str
    type: Type
    default_value: Any = _empty
    # Computed in __post_init__ from `type`, can't be passed to the constructor.
    is_variadic: bool = field(init=False)
    senders: List[str] = field(default_factory=list)

    @property
    def is_mandatory(self):
        """
        Returns True if the socket has no default value.
        """
        # `_empty` is a sentinel, so we compare by identity. Using `==` would call the `__eq__`
        # of the default value, which might not return a plain bool (e.g. arrays).
        return self.default_value is _empty

    def __post_init__(self):
        try:
            # __metadata__ is a tuple
            self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
        except AttributeError:
            # Only `Annotated` types have `__metadata__`
            self.is_variadic = False
        if self.is_variadic:
            # We need to "unpack" the type inside the Variadic annotation,
            # otherwise the pipeline connection api will try to match
            # `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`.
            #
            # Note1: Variadic is expressed as an annotation of one single type,
            # so the return value of get_args will always be a one-item tuple.
            #
            # Note2: a pipeline always passes a list of items when a component
            # input is declared as Variadic, so the type itself always wraps
            # an iterable of the declared type. For example, Variadic[int]
            # is eventually an alias for Iterable[int]. Since we're interested
            # in getting the inner type `int`, we call `get_args` twice: the
            # first time to get `Iterable[int]` out of `Variadic`, the second time
            # to get `int` out of `Iterable[int]`.
            self.type = get_args(get_args(self.type)[0])[0]
|
||||
|
||||
|
||||
@dataclass
class OutputSocket:
    """
    Describes one output of a component.

    :param name: name of the socket.
    :param type: type of the value produced on this socket.
    :param receivers: presumably the names of the components connected to this socket (to be confirmed
        against the Pipeline connection code); starts empty.
    """

    name: str
    type: type
    # default_factory gives every socket its own list instead of a shared one
    receivers: List[str] = field(default_factory=list)
|
||||
|
@ -1,14 +1,13 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List, Dict
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import networkx # type:ignore
|
||||
|
||||
from haystack.core.component.types import InputSocket, OutputSocket
|
||||
from haystack.core.type_utils import _type_name
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -197,16 +197,17 @@ class Pipeline:
|
||||
)
|
||||
raise PipelineError(msg)
|
||||
|
||||
# Create the component's input and output sockets
|
||||
input_sockets = getattr(instance, "__haystack_input__", {})
|
||||
output_sockets = getattr(instance, "__haystack_output__", {})
|
||||
|
||||
setattr(instance, "__haystack_added_to_pipeline__", self)
|
||||
|
||||
# Add component to the graph, disconnected
|
||||
logger.debug("Adding component '%s' (%s)", name, instance)
|
||||
# We're completely sure the fields exist so we ignore the type error
|
||||
self.graph.add_node(
|
||||
name, instance=instance, input_sockets=input_sockets, output_sockets=output_sockets, visits=0
|
||||
name,
|
||||
instance=instance,
|
||||
input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined]
|
||||
output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined]
|
||||
visits=0,
|
||||
)
|
||||
|
||||
def connect(self, connect_from: str, connect_to: str) -> None:
|
||||
@ -381,6 +382,16 @@ class Pipeline:
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
|
||||
|
||||
def get_component_name(self, instance: Component) -> str:
    """
    Returns the name of a Component instance. If the Component has not been added to this Pipeline,
    returns an empty string.

    :param instance: the Component instance to look for.
    :returns: the name the instance is registered with in this Pipeline's graph, or `""` if it's not there.
    """
    for name, inst in self.graph.nodes(data="instance"):
        # We're looking for this very instance, so we compare by identity. Two distinct components
        # that compare equal (e.g. same class and same init parameters) must not share a name.
        if inst is instance:
            return name
    return ""
|
||||
|
||||
def inputs(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary
|
||||
@ -465,16 +476,16 @@ class Pipeline:
|
||||
if component_name not in self.graph.nodes:
|
||||
raise ValueError(f"Component named {component_name} not found in the pipeline.")
|
||||
instance = self.graph.nodes[component_name]["instance"]
|
||||
for socket_name, socket in instance.__haystack_input__.items():
|
||||
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
|
||||
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
||||
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
||||
for input_name in component_inputs.keys():
|
||||
if input_name not in instance.__haystack_input__:
|
||||
if input_name not in instance.__haystack_input__._sockets_dict:
|
||||
raise ValueError(f"Input {input_name} not found in component {component_name}.")
|
||||
|
||||
for component_name in self.graph.nodes:
|
||||
instance = self.graph.nodes[component_name]["instance"]
|
||||
for socket_name, socket in instance.__haystack_input__.items():
|
||||
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
|
||||
component_inputs = data.get(component_name, {})
|
||||
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
||||
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
||||
@ -518,7 +529,7 @@ class Pipeline:
|
||||
for component_input, input_value in component_inputs.items():
|
||||
# Handle mutable input data
|
||||
data[component_name][component_input] = copy(input_value)
|
||||
if instance.__haystack_input__[component_input].is_variadic:
|
||||
if instance.__haystack_input__._sockets_dict[component_input].is_variadic:
|
||||
# Components that have variadic inputs need to receive lists as input.
|
||||
# We don't want to force the user to always pass lists, so we convert single values to lists here.
|
||||
# If it's already a list we assume the component takes a variadic input of lists, so we
|
||||
@ -533,12 +544,12 @@ class Pipeline:
|
||||
for node_name in self.graph.nodes:
|
||||
component = self.graph.nodes[node_name]["instance"]
|
||||
|
||||
if len(component.__haystack_input__) == 0:
|
||||
if len(component.__haystack_input__._sockets_dict) == 0:
|
||||
# Component has no input, can run right away
|
||||
to_run.append((node_name, component))
|
||||
continue
|
||||
|
||||
for socket in component.__haystack_input__.values():
|
||||
for socket in component.__haystack_input__._sockets_dict.values():
|
||||
if not socket.senders or socket.is_variadic:
|
||||
# Component has at least one input not connected or is variadic, can run right away.
|
||||
to_run.append((node_name, component))
|
||||
@ -561,12 +572,12 @@ class Pipeline:
|
||||
while len(to_run) > 0:
|
||||
name, comp = to_run.pop(0)
|
||||
|
||||
if any(socket.is_variadic for socket in comp.__haystack_input__.values()) and not getattr( # type: ignore
|
||||
if any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) and not getattr( # type: ignore
|
||||
comp, "is_greedy", False
|
||||
):
|
||||
there_are_non_variadics = False
|
||||
for _, other_comp in to_run:
|
||||
if not any(socket.is_variadic for socket in other_comp.__haystack_input__.values()): # type: ignore
|
||||
if not any(socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values()): # type: ignore
|
||||
there_are_non_variadics = True
|
||||
break
|
||||
|
||||
@ -575,7 +586,7 @@ class Pipeline:
|
||||
waiting_for_input.append((name, comp))
|
||||
continue
|
||||
|
||||
if name in last_inputs and len(comp.__haystack_input__) == len(last_inputs[name]): # type: ignore
|
||||
if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore
|
||||
# This component has all the inputs it needs to run
|
||||
res = comp.run(**last_inputs[name])
|
||||
|
||||
@ -649,7 +660,7 @@ class Pipeline:
|
||||
# This is our last resort, if there's no lazy variadic waiting for input
|
||||
# we're stuck for real and we can't make any progress.
|
||||
for name, comp in waiting_for_input:
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
||||
break
|
||||
else:
|
||||
@ -680,14 +691,14 @@ class Pipeline:
|
||||
last_inputs[name] = {}
|
||||
|
||||
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
||||
there_are_only_lazy_variadics = True
|
||||
for other_name, other_comp in waiting_for_input:
|
||||
if name == other_name:
|
||||
continue
|
||||
there_are_only_lazy_variadics &= any(
|
||||
socket.is_variadic for socket in other_comp.__haystack_input__.values() # type: ignore
|
||||
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
) and not getattr(other_comp, "is_greedy", False)
|
||||
|
||||
if not there_are_only_lazy_variadics:
|
||||
@ -695,7 +706,7 @@ class Pipeline:
|
||||
|
||||
# Find the first component that has all the inputs it needs to run
|
||||
has_enough_inputs = True
|
||||
for input_socket in comp.__haystack_input__.values(): # type: ignore
|
||||
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
|
||||
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
|
||||
has_enough_inputs = False
|
||||
break
|
||||
|
@ -9,11 +9,11 @@ from haystack.core.component import component
|
||||
@component
|
||||
class Repeat:
|
||||
def __init__(self, outputs: List[str]):
|
||||
self.outputs = outputs
|
||||
self._outputs = outputs
|
||||
component.set_output_types(self, **{k: int for k in outputs})
|
||||
|
||||
def run(self, value: int):
|
||||
"""
|
||||
:param value: the value to repeat.
|
||||
"""
|
||||
return {val: value for val in self.outputs}
|
||||
return {val: value for val in self._outputs}
|
||||
|
@ -15,16 +15,16 @@ class TestDynamicChatPromptBuilder:
|
||||
|
||||
# we have inputs that contain: prompt_source, template_variables + runtime_variables
|
||||
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
|
||||
assert set(builder.__haystack_input__.keys()) == expected_keys
|
||||
assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
|
||||
|
||||
# response is always prompt regardless of chat mode
|
||||
assert set(builder.__haystack_output__.keys()) == {"prompt"}
|
||||
assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
|
||||
|
||||
# prompt_source is a list of ChatMessage
|
||||
assert builder.__haystack_input__["prompt_source"].type == List[ChatMessage]
|
||||
assert builder.__haystack_input__._sockets_dict["prompt_source"].type == List[ChatMessage]
|
||||
|
||||
# output is always prompt, but the type is different depending on the chat mode
|
||||
assert builder.__haystack_output__["prompt"].type == List[ChatMessage]
|
||||
assert builder.__haystack_output__._sockets_dict["prompt"].type == List[ChatMessage]
|
||||
|
||||
def test_non_empty_chat_messages(self):
|
||||
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])
|
||||
|
@ -16,16 +16,16 @@ class TestDynamicPromptBuilder:
|
||||
# regardless of the chat mode
|
||||
# we have inputs that contain: prompt_source, template_variables + runtime_variables
|
||||
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
|
||||
assert set(builder.__haystack_input__.keys()) == expected_keys
|
||||
assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
|
||||
|
||||
# response is always prompt regardless of chat mode
|
||||
assert set(builder.__haystack_output__.keys()) == {"prompt"}
|
||||
assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
|
||||
|
||||
# prompt_source is a list of ChatMessage or a string
|
||||
assert builder.__haystack_input__["prompt_source"].type == str
|
||||
assert builder.__haystack_input__._sockets_dict["prompt_source"].type == str
|
||||
|
||||
# output is always prompt, but the type is different depending on the chat mode
|
||||
assert builder.__haystack_output__["prompt"].type == str
|
||||
assert builder.__haystack_output__._sockets_dict["prompt"].type == str
|
||||
|
||||
def test_processing_a_simple_template_with_provided_variables(self):
|
||||
runtime_variables = ["var1", "var2", "var3"]
|
||||
|
@ -86,8 +86,8 @@ class TestRouter:
|
||||
router = ConditionalRouter(routes)
|
||||
|
||||
assert router.routes == routes
|
||||
assert set(router.__haystack_input__.keys()) == {"query", "streams"}
|
||||
assert set(router.__haystack_output__.keys()) == {"query", "streams"}
|
||||
assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
|
||||
assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}
|
||||
|
||||
def test_router_evaluate_condition_expressions(self, router):
|
||||
# first route should be selected
|
||||
|
@ -1,5 +1,4 @@
|
||||
import typing
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@ -89,6 +88,7 @@ def test_missing_run():
|
||||
|
||||
|
||||
def test_set_input_types():
|
||||
@component
|
||||
class MockComponent:
|
||||
def __init__(self):
|
||||
component.set_input_types(self, value=Any)
|
||||
@ -105,7 +105,7 @@ def test_set_input_types():
|
||||
return {"value": 1}
|
||||
|
||||
comp = MockComponent()
|
||||
assert comp.__haystack_input__ == {"value": InputSocket("value", Any)}
|
||||
assert comp.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)}
|
||||
assert comp.run() == {"value": 1}
|
||||
|
||||
|
||||
@ -126,7 +126,7 @@ def test_set_output_types():
|
||||
return {"value": 1}
|
||||
|
||||
comp = MockComponent()
|
||||
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)}
|
||||
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
|
||||
|
||||
|
||||
def test_output_types_decorator_with_compatible_type():
|
||||
@ -144,7 +144,7 @@ def test_output_types_decorator_with_compatible_type():
|
||||
return cls()
|
||||
|
||||
comp = MockComponent()
|
||||
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)}
|
||||
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
|
||||
|
||||
|
||||
def test_component_decorator_set_it_as_component():
|
||||
@ -173,8 +173,8 @@ def test_input_has_default_value():
|
||||
return {"value": value}
|
||||
|
||||
comp = MockComponent()
|
||||
assert comp.__haystack_input__["value"].default_value == 42
|
||||
assert not comp.__haystack_input__["value"].is_mandatory
|
||||
assert comp.__haystack_input__._sockets_dict["value"].default_value == 42
|
||||
assert not comp.__haystack_input__._sockets_dict["value"].is_mandatory
|
||||
|
||||
|
||||
def test_keyword_only_args():
|
||||
@ -187,5 +187,5 @@ def test_keyword_only_args():
|
||||
return {"value": arg}
|
||||
|
||||
comp = MockComponent()
|
||||
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__.items()}
|
||||
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()}
|
||||
assert component_inputs == {"arg": {"type": int}}
|
||||
|
57
test/core/component/test_sockets.py
Normal file
57
test/core/component/test_sockets.py
Normal file
@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
|
||||
from haystack.core.component.sockets import InputSocket, Sockets
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.factory import component_class
|
||||
|
||||
|
||||
class TestSockets:
    """Tests for the `Sockets` container."""

    def test_init(self):
        comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
        sockets = {"input_1": InputSocket("input_1", int), "input_2": InputSocket("input_2", int)}
        io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket)
        assert io._component == comp
        # Sockets are mirrored in the instance dict
        assert "input_1" in io.__dict__
        assert io.__dict__["input_1"] == comp.__haystack_input__._sockets_dict["input_1"]
        assert "input_2" in io.__dict__
        assert io.__dict__["input_2"] == comp.__haystack_input__._sockets_dict["input_2"]

    def test_init_with_empty_sockets(self):
        comp = component_class("SomeComponent")()
        io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)

        assert io._component == comp
        assert io._sockets_dict == {}

    def test_component_name(self):
        # Not added to any Pipeline, the string representation of the component is used
        comp = component_class("SomeComponent")()
        io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
        assert io._component_name() == str(comp)

    def test_component_name_added_to_pipeline(self):
        comp = component_class("SomeComponent")()
        pipeline = Pipeline()
        pipeline.add_component("my_component", comp)

        io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
        assert io._component_name() == "my_component"

    def test_getattribute(self):
        comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
        io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)

        assert io.input_1 == comp.__haystack_input__._sockets_dict["input_1"]
        assert io.input_2 == comp.__haystack_input__._sockets_dict["input_2"]

    def test_getattribute_non_existing_socket(self):
        comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
        io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)

        with pytest.raises(AttributeError):
            io.input_3

    def test_repr(self):
        comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
        io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
        res = repr(io)
        assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int"
|
@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.component.types import InputSocket, OutputSocket
|
||||
from haystack.core.errors import PipelineError, PipelineRuntimeError
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.testing.factory import component_class
|
||||
@ -28,6 +28,21 @@ def test_add_component_to_different_pipelines():
|
||||
second_pipe.add_component("some", some_component)
|
||||
|
||||
|
||||
def test_get_component_name():
    """A component added to the pipeline is found under the name it was added with."""
    pipe = Pipeline()
    some_component = component_class("Some")()
    pipe.add_component("some", some_component)

    assert pipe.get_component_name(some_component) == "some"
|
||||
|
||||
|
||||
def test_get_component_name_not_added_to_pipeline():
    """A component that was never added to the pipeline yields an empty name."""
    pipe = Pipeline()
    some_component = component_class("Some")()

    assert pipe.get_component_name(some_component) == ""
|
||||
|
||||
|
||||
def test_run_with_component_that_does_not_return_dict():
|
||||
BrokenComponent = component_class(
|
||||
"BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore
|
||||
|
@ -5,8 +5,7 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.core.component.sockets import InputSocket, OutputSocket
|
||||
from haystack.core.component.types import Variadic
|
||||
from haystack.core.component.types import InputSocket, OutputSocket, Variadic
|
||||
from haystack.core.errors import PipelineValidationError
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
|
||||
|
Loading…
x
Reference in New Issue
Block a user