feat: Change Component's I/O dunder type (#6916)

* Add Pipeline.get_component_name() method

* Add utility class to ease discoverability of Component I/O

* Move InputOutput in component package

* Rename InputOutput to _InputOutput

* Raise if inputs or outputs field already exist

* Fix tests

* Add release notes

* Move InputSocket and OutputSocket in types package

* Move _InputOutput in socket package

* Rename _InputOutput class to Sockets

* Simplify Sockets class

* Dictch I/O dunder fields in favour of inputs and outputs fields

* Update Sockets docstrings

* Update release notes

* Fix mypy

* Remove unnecessary assignment

* Remove unused logging

* Change SocketsType to SocketsIOType to avoid confusion

* Change sockets type and name

* Change Sockets.__repr__ to return component instance

* Fix linting

* Fix sockets tests

* Revert to dunder fields for Component IO

* Use singular in IO dunder fields

* Delete release notes

* Update haystack/core/component/types.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Silvano Cerza 2024-02-05 17:46:45 +01:00 committed by GitHub
parent 3bd6ba93ca
commit 0191b1e6e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 282 additions and 95 deletions

View File

@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from haystack.core.component.component import component, Component
from haystack.core.component.sockets import InputSocket, OutputSocket
from haystack.core.component.component import Component, component
from haystack.core.component.types import InputSocket, OutputSocket
__all__ = ["component", "Component", "InputSocket", "OutputSocket"]

View File

@ -74,9 +74,11 @@ from copy import deepcopy
from types import new_class
from typing import Any, Protocol, runtime_checkable
from haystack.core.component.sockets import InputSocket, OutputSocket, _empty
from haystack.core.errors import ComponentError
from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty
logger = logging.getLogger(__name__)
@ -131,12 +133,14 @@ class ComponentMeta(type):
# that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data.
instance.__haystack_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {}))
instance.__haystack_output__ = Sockets(
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
)
# Create the sockets if set_input_types() wasn't called in the constructor.
# If it was called and there are some parameters also in the `run()` method, these take precedence.
if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = {}
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
run_signature = inspect.signature(getattr(cls, "run"))
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
if run_signature.parameters[param].kind not in (
@ -185,7 +189,7 @@ class _Component:
:param default: default value of the input socket, defaults to _empty
"""
if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = {}
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)
def set_input_types(self, instance, **types):
@ -229,7 +233,9 @@ class _Component:
parameter mandatory as specified in `set_input_types`.
"""
instance.__haystack_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}
instance.__haystack_input__ = Sockets(
instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
)
def set_output_types(self, instance, **types):
"""
@ -251,7 +257,9 @@ class _Component:
return {"output_1": 1, "output_2": "2"}
```
"""
instance.__haystack_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}
instance.__haystack_output__ = Sockets(
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
)
def output_types(self, **types):
"""

View File

@ -1,57 +1,107 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass, field
from typing import Any, List, Type, get_args
from haystack.core.component.types import HAYSTACK_VARIADIC_ANNOTATION
import logging
from typing import Dict, Type, Union
from haystack.core.type_utils import _type_name
from .types import InputSocket, OutputSocket
logger = logging.getLogger(__name__)
class _empty:
"""Custom object for marking InputSocket.default_value as not set."""
SocketsDict = Dict[str, Union[InputSocket, OutputSocket]]
SocketsIOType = Union[Type[InputSocket], Type[OutputSocket]]
@dataclass
class InputSocket:
name: str
type: Type
default_value: Any = _empty
is_variadic: bool = field(init=False)
senders: List[str] = field(default_factory=list)
class Sockets:
"""
This class is used to represent the inputs or outputs of a `Component`.
Depending on the type passed to the constructor, it will represent either the inputs or the outputs of
the `Component`.
@property
def is_mandatory(self):
return self.default_value == _empty
Usage:
```python
from haystack.components.builders.prompt_builder import PromptBuilder
def __post_init__(self):
prompt_template = \"""
Given these documents, answer the question.\nDocuments:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
\nQuestion: {{question}}
\nAnswer:
\"""
prompt_builder = PromptBuilder(template=prompt_template)
sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)}
inputs = Sockets(component=prompt_builder, sockets=sockets, sockets_type=InputSocket)
inputs
>>> PromptBuilder inputs:
>>> - question: Any
>>> - documents: Any
inputs.question
>>> InputSocket(name='question', type=typing.Any, default_value=<class 'haystack.core.component.types._empty'>, is_variadic=False, senders=[])
```
"""
# We're using a forward declaration here to avoid a circular import.
def __init__(
self,
component: "Component", # type: ignore[name-defined] # noqa: F821
sockets_dict: SocketsDict,
sockets_io_type: SocketsIOType,
):
"""
Create a new Sockets object.
We don't do any enforcement on the types of the sockets here, the `sockets_type` is only used for
the `__repr__` method.
We could do without it and use the type of a random value in the `sockets` dict, but that wouldn't
work for components that have no sockets at all. Either input or output.
"""
self._sockets_io_type = sockets_io_type
self._component = component
self._sockets_dict = sockets_dict
self.__dict__.update(sockets_dict)
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
"""
Adds a new socket to this Sockets object.
This eases a bit updating the list of sockets after Sockets has been created.
That should happen only in the `component` decorator.
"""
self._sockets_dict[key] = socket
self.__dict__[key] = socket
def _component_name(self) -> str:
if pipeline := getattr(self._component, "__haystack_added_to_pipeline__"):
# This Component has been added in a Pipeline, let's get the name from there.
return pipeline.get_component_name(self._component)
# This Component has not been added to a Pipeline yet, so we can't know its name.
# Let's use the class name instead.
return str(self._component)
def __getattribute__(self, name):
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
sockets = object.__getattribute__(self, "_sockets")
if name in sockets:
return sockets[name]
except AttributeError:
self.is_variadic = False
if self.is_variadic:
# We need to "unpack" the type inside the Variadic annotation,
# otherwise the pipeline connection api will try to match
# `Annotated[type, CANALS_VARIADIC_ANNOTATION]`.
#
# Note1: Variadic is expressed as an annotation of one single type,
# so the return value of get_args will always be a one-item tuple.
#
# Note2: a pipeline always passes a list of items when a component
# input is declared as Variadic, so the type itself always wraps
# an iterable of the declared type. For example, Variadic[int]
# is eventually an alias for Iterable[int]. Since we're interested
# in getting the inner type `int`, we call `get_args` twice: the
# first time to get `List[int]` out of `Variadic`, the second time
# to get `int` out of `List[int]`.
self.type = get_args(get_args(self.type)[0])[0]
pass
return object.__getattribute__(self, name)
@dataclass
class OutputSocket:
name: str
type: type
receivers: List[str] = field(default_factory=list)
def __repr__(self) -> str:
result = self._component_name()
if self._sockets_io_type == InputSocket:
result += " inputs:\n"
elif self._sockets_io_type == OutputSocket:
result += " outputs:\n"
result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])
return result

View File

@ -1,4 +1,5 @@
from typing import Iterable, TypeVar
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Type, TypeVar, get_args
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
@ -13,3 +14,50 @@ T = TypeVar("T")
# type so it can be used in the `InputSocket` creation where we
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
class _empty:
"""Custom object for marking InputSocket.default_value as not set."""
@dataclass
class InputSocket:
name: str
type: Type
default_value: Any = _empty
is_variadic: bool = field(init=False)
senders: List[str] = field(default_factory=list)
@property
def is_mandatory(self):
return self.default_value == _empty
def __post_init__(self):
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
except AttributeError:
self.is_variadic = False
if self.is_variadic:
# We need to "unpack" the type inside the Variadic annotation,
# otherwise the pipeline connection api will try to match
# `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`.
#
# Note1: Variadic is expressed as an annotation of one single type,
# so the return value of get_args will always be a one-item tuple.
#
# Note2: a pipeline always passes a list of items when a component
# input is declared as Variadic, so the type itself always wraps
# an iterable of the declared type. For example, Variadic[int]
# is eventually an alias for Iterable[int]. Since we're interested
# in getting the inner type `int`, we call `get_args` twice: the
# first time to get `List[int]` out of `Variadic`, the second time
# to get `int` out of `List[int]`.
self.type = get_args(get_args(self.type)[0])[0]
@dataclass
class OutputSocket:
name: str
type: type
receivers: List[str] = field(default_factory=list)

View File

@ -1,14 +1,13 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Dict
import logging
from typing import Dict, List
import networkx # type:ignore
from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.type_utils import _type_name
from haystack.core.component.sockets import InputSocket, OutputSocket
logger = logging.getLogger(__name__)

View File

@ -197,16 +197,17 @@ class Pipeline:
)
raise PipelineError(msg)
# Create the component's input and output sockets
input_sockets = getattr(instance, "__haystack_input__", {})
output_sockets = getattr(instance, "__haystack_output__", {})
setattr(instance, "__haystack_added_to_pipeline__", self)
# Add component to the graph, disconnected
logger.debug("Adding component '%s' (%s)", name, instance)
# We're completely sure the fields exist so we ignore the type error
self.graph.add_node(
name, instance=instance, input_sockets=input_sockets, output_sockets=output_sockets, visits=0
name,
instance=instance,
input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined]
output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined]
visits=0,
)
def connect(self, connect_from: str, connect_to: str) -> None:
@ -381,6 +382,16 @@ class Pipeline:
except KeyError as exc:
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
def get_component_name(self, instance: Component) -> str:
"""
Returns the name of a Component instance. If the Component has not been added to this Pipeline,
returns an empty string.
"""
for name, inst in self.graph.nodes(data="instance"):
if inst == instance:
return name
return ""
def inputs(self) -> Dict[str, Dict[str, Any]]:
"""
Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary
@ -465,16 +476,16 @@ class Pipeline:
if component_name not in self.graph.nodes:
raise ValueError(f"Component named {component_name} not found in the pipeline.")
instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__.items():
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
for input_name in component_inputs.keys():
if input_name not in instance.__haystack_input__:
if input_name not in instance.__haystack_input__._sockets_dict:
raise ValueError(f"Input {input_name} not found in component {component_name}.")
for component_name in self.graph.nodes:
instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__.items():
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
component_inputs = data.get(component_name, {})
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
@ -518,7 +529,7 @@ class Pipeline:
for component_input, input_value in component_inputs.items():
# Handle mutable input data
data[component_name][component_input] = copy(input_value)
if instance.__haystack_input__[component_input].is_variadic:
if instance.__haystack_input__._sockets_dict[component_input].is_variadic:
# Components that have variadic inputs need to receive lists as input.
# We don't want to force the user to always pass lists, so we convert single values to lists here.
# If it's already a list we assume the component takes a variadic input of lists, so we
@ -533,12 +544,12 @@ class Pipeline:
for node_name in self.graph.nodes:
component = self.graph.nodes[node_name]["instance"]
if len(component.__haystack_input__) == 0:
if len(component.__haystack_input__._sockets_dict) == 0:
# Component has no input, can run right away
to_run.append((node_name, component))
continue
for socket in component.__haystack_input__.values():
for socket in component.__haystack_input__._sockets_dict.values():
if not socket.senders or socket.is_variadic:
# Component has at least one input not connected or is variadic, can run right away.
to_run.append((node_name, component))
@ -561,12 +572,12 @@ class Pipeline:
while len(to_run) > 0:
name, comp = to_run.pop(0)
if any(socket.is_variadic for socket in comp.__haystack_input__.values()) and not getattr( # type: ignore
if any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) and not getattr( # type: ignore
comp, "is_greedy", False
):
there_are_non_variadics = False
for _, other_comp in to_run:
if not any(socket.is_variadic for socket in other_comp.__haystack_input__.values()): # type: ignore
if not any(socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values()): # type: ignore
there_are_non_variadics = True
break
@ -575,7 +586,7 @@ class Pipeline:
waiting_for_input.append((name, comp))
continue
if name in last_inputs and len(comp.__haystack_input__) == len(last_inputs[name]): # type: ignore
if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore
# This component has all the inputs it needs to run
res = comp.run(**last_inputs[name])
@ -649,7 +660,7 @@ class Pipeline:
# This is our last resort, if there's no lazy variadic waiting for input
# we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False):
break
else:
@ -680,14 +691,14 @@ class Pipeline:
last_inputs[name] = {}
# Lazy variadics must be removed only if there's nothing else to run at this stage
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False):
there_are_only_lazy_variadics = True
for other_name, other_comp in waiting_for_input:
if name == other_name:
continue
there_are_only_lazy_variadics &= any(
socket.is_variadic for socket in other_comp.__haystack_input__.values() # type: ignore
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
) and not getattr(other_comp, "is_greedy", False)
if not there_are_only_lazy_variadics:
@ -695,7 +706,7 @@ class Pipeline:
# Find the first component that has all the inputs it needs to run
has_enough_inputs = True
for input_socket in comp.__haystack_input__.values(): # type: ignore
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
has_enough_inputs = False
break

View File

@ -9,11 +9,11 @@ from haystack.core.component import component
@component
class Repeat:
def __init__(self, outputs: List[str]):
self.outputs = outputs
self._outputs = outputs
component.set_output_types(self, **{k: int for k in outputs})
def run(self, value: int):
"""
:param value: the value to repeat.
"""
return {val: value for val in self.outputs}
return {val: value for val in self._outputs}

View File

@ -15,16 +15,16 @@ class TestDynamicChatPromptBuilder:
# we have inputs that contain: prompt_source, template_variables + runtime_variables
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
assert set(builder.__haystack_input__.keys()) == expected_keys
assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
# response is always prompt regardless of chat mode
assert set(builder.__haystack_output__.keys()) == {"prompt"}
assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
# prompt_source is a list of ChatMessage
assert builder.__haystack_input__["prompt_source"].type == List[ChatMessage]
assert builder.__haystack_input__._sockets_dict["prompt_source"].type == List[ChatMessage]
# output is always prompt, but the type is different depending on the chat mode
assert builder.__haystack_output__["prompt"].type == List[ChatMessage]
assert builder.__haystack_output__._sockets_dict["prompt"].type == List[ChatMessage]
def test_non_empty_chat_messages(self):
prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"])

View File

@ -16,16 +16,16 @@ class TestDynamicPromptBuilder:
# regardless of the chat mode
# we have inputs that contain: prompt_source, template_variables + runtime_variables
expected_keys = set(runtime_variables + ["prompt_source", "template_variables"])
assert set(builder.__haystack_input__.keys()) == expected_keys
assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys
# response is always prompt regardless of chat mode
assert set(builder.__haystack_output__.keys()) == {"prompt"}
assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"}
# prompt_source is a list of ChatMessage or a string
assert builder.__haystack_input__["prompt_source"].type == str
assert builder.__haystack_input__._sockets_dict["prompt_source"].type == str
# output is always prompt, but the type is different depending on the chat mode
assert builder.__haystack_output__["prompt"].type == str
assert builder.__haystack_output__._sockets_dict["prompt"].type == str
def test_processing_a_simple_template_with_provided_variables(self):
runtime_variables = ["var1", "var2", "var3"]

View File

@ -86,8 +86,8 @@ class TestRouter:
router = ConditionalRouter(routes)
assert router.routes == routes
assert set(router.__haystack_input__.keys()) == {"query", "streams"}
assert set(router.__haystack_output__.keys()) == {"query", "streams"}
assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}
def test_router_evaluate_condition_expressions(self, router):
# first route should be selected

View File

@ -1,5 +1,4 @@
import typing
from typing import Any, Optional
from typing import Any
import pytest
@ -89,6 +88,7 @@ def test_missing_run():
def test_set_input_types():
@component
class MockComponent:
def __init__(self):
component.set_input_types(self, value=Any)
@ -105,7 +105,7 @@ def test_set_input_types():
return {"value": 1}
comp = MockComponent()
assert comp.__haystack_input__ == {"value": InputSocket("value", Any)}
assert comp.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)}
assert comp.run() == {"value": 1}
@ -126,7 +126,7 @@ def test_set_output_types():
return {"value": 1}
comp = MockComponent()
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)}
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
def test_output_types_decorator_with_compatible_type():
@ -144,7 +144,7 @@ def test_output_types_decorator_with_compatible_type():
return cls()
comp = MockComponent()
assert comp.__haystack_output__ == {"value": OutputSocket("value", int)}
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
def test_component_decorator_set_it_as_component():
@ -173,8 +173,8 @@ def test_input_has_default_value():
return {"value": value}
comp = MockComponent()
assert comp.__haystack_input__["value"].default_value == 42
assert not comp.__haystack_input__["value"].is_mandatory
assert comp.__haystack_input__._sockets_dict["value"].default_value == 42
assert not comp.__haystack_input__._sockets_dict["value"].is_mandatory
def test_keyword_only_args():
@ -187,5 +187,5 @@ def test_keyword_only_args():
return {"value": arg}
comp = MockComponent()
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__.items()}
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()}
assert component_inputs == {"arg": {"type": int}}

View File

@ -0,0 +1,57 @@
import pytest
from haystack.core.component.sockets import InputSocket, Sockets
from haystack.core.pipeline import Pipeline
from haystack.testing.factory import component_class
class TestSockets:
def test_init(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
sockets = {"input_1": InputSocket("input_1", int), "input_2": InputSocket("input_2", int)}
io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket)
assert io._component == comp
assert "input_1" in io.__dict__
assert io.__dict__["input_1"] == comp.__haystack_input__._sockets_dict["input_1"]
assert "input_2" in io.__dict__
assert io.__dict__["input_2"] == comp.__haystack_input__._sockets_dict["input_2"]
def test_init_with_empty_sockets(self):
comp = component_class("SomeComponent")()
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component == comp
assert io._sockets_dict == {}
def test_component_name(self):
comp = component_class("SomeComponent")()
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component_name() == str(comp)
def test_component_name_added_to_pipeline(self):
comp = component_class("SomeComponent")()
pipeline = Pipeline()
pipeline.add_component("my_component", comp)
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component_name() == "my_component"
def test_getattribute(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
assert io.input_1 == comp.__haystack_input__._sockets_dict["input_1"]
assert io.input_2 == comp.__haystack_input__._sockets_dict["input_2"]
def test_getattribute_non_existing_socket(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
with pytest.raises(AttributeError):
io.input_3
def test_repr(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
res = repr(io)
assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int"

View File

@ -6,7 +6,7 @@ from typing import Optional
import pytest
from haystack.core.component.sockets import InputSocket, OutputSocket
from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.errors import PipelineError, PipelineRuntimeError
from haystack.core.pipeline import Pipeline
from haystack.testing.factory import component_class
@ -28,6 +28,21 @@ def test_add_component_to_different_pipelines():
second_pipe.add_component("some", some_component)
def test_get_component_name():
pipe = Pipeline()
some_component = component_class("Some")()
pipe.add_component("some", some_component)
assert pipe.get_component_name(some_component) == "some"
def test_get_component_name_not_added_to_pipeline():
pipe = Pipeline()
some_component = component_class("Some")()
assert pipe.get_component_name(some_component) == ""
def test_run_with_component_that_does_not_return_dict():
BrokenComponent = component_class(
"BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore

View File

@ -5,8 +5,7 @@ from typing import Optional
import pytest
from haystack.core.component.sockets import InputSocket, OutputSocket
from haystack.core.component.types import Variadic
from haystack.core.component.types import InputSocket, OutputSocket, Variadic
from haystack.core.errors import PipelineValidationError
from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs