diff --git a/haystack/core/component/__init__.py b/haystack/core/component/__init__.py index 3a292edaf..dea761412 100644 --- a/haystack/core/component/__init__.py +++ b/haystack/core/component/__init__.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from haystack.core.component.component import component, Component -from haystack.core.component.sockets import InputSocket, OutputSocket +from haystack.core.component.component import Component, component +from haystack.core.component.types import InputSocket, OutputSocket __all__ = ["component", "Component", "InputSocket", "OutputSocket"] diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index 5fab8455f..95ae87d73 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -74,9 +74,11 @@ from copy import deepcopy from types import new_class from typing import Any, Protocol, runtime_checkable -from haystack.core.component.sockets import InputSocket, OutputSocket, _empty from haystack.core.errors import ComponentError +from .sockets import Sockets +from .types import InputSocket, OutputSocket, _empty + logger = logging.getLogger(__name__) @@ -131,12 +133,14 @@ class ComponentMeta(type): # that stores the output specification. # We deepcopy the content of the cache to transfer ownership from the class method # to the actual instance, so that different instances of the same class won't share this data. - instance.__haystack_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {})) + instance.__haystack_output__ = Sockets( + instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket + ) # Create the sockets if set_input_types() wasn't called in the constructor. # If it was called and there are some parameters also in the `run()` method, these take precedence. if not hasattr(instance, "__haystack_input__"): - instance.__haystack_input__ = {} + instance.__haystack_input__ = Sockets(instance, {}, InputSocket) run_signature = inspect.signature(getattr(cls, "run")) for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter. if run_signature.parameters[param].kind not in ( @@ -185,7 +189,7 @@ class _Component: :param default: default value of the input socket, defaults to _empty """ if not hasattr(instance, "__haystack_input__"): - instance.__haystack_input__ = {} + instance.__haystack_input__ = Sockets(instance, {}, InputSocket) instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default) def set_input_types(self, instance, **types): @@ -229,7 +233,9 @@ class _Component: parameter mandatory as specified in `set_input_types`. """ - instance.__haystack_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()} + instance.__haystack_input__ = Sockets( + instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket + ) def set_output_types(self, instance, **types): """ @@ -251,7 +257,9 @@ class _Component: return {"output_1": 1, "output_2": "2"} ``` """ - instance.__haystack_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()} + instance.__haystack_output__ = Sockets( + instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket + ) def output_types(self, **types): """ diff --git a/haystack/core/component/sockets.py b/haystack/core/component/sockets.py index ff3080dcf..25bf4fdc8 100644 --- a/haystack/core/component/sockets.py +++ b/haystack/core/component/sockets.py @@ -1,57 +1,107 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import logging -from dataclasses import dataclass, field -from typing import Any, List, Type, get_args -from haystack.core.component.types import HAYSTACK_VARIADIC_ANNOTATION +import logging +from typing import Dict, Type, Union + +from haystack.core.type_utils import _type_name + +from .types import InputSocket, OutputSocket logger = logging.getLogger(__name__) - -class _empty: - """Custom object for marking InputSocket.default_value as not set.""" +SocketsDict = Dict[str, Union[InputSocket, OutputSocket]] +SocketsIOType = Union[Type[InputSocket], Type[OutputSocket]] -@dataclass -class InputSocket: - name: str - type: Type - default_value: Any = _empty - is_variadic: bool = field(init=False) - senders: List[str] = field(default_factory=list) +class Sockets: + """ + This class is used to represent the inputs or outputs of a `Component`. + Depending on the type passed to the constructor, it will represent either the inputs or the outputs of + the `Component`. - @property - def is_mandatory(self): - return self.default_value == _empty + Usage: + ```python + from haystack.components.builders.prompt_builder import PromptBuilder - def __post_init__(self): + prompt_template = \""" + Given these documents, answer the question.\nDocuments: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + \nQuestion: {{question}} + \nAnswer: + \""" + + prompt_builder = PromptBuilder(template=prompt_template) + sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)} + inputs = Sockets(component=prompt_builder, sockets=sockets, sockets_type=InputSocket) + inputs + >>> PromptBuilder inputs: + >>> - question: Any + >>> - documents: Any + + inputs.question + >>> InputSocket(name='question', type=typing.Any, default_value=, is_variadic=False, senders=[]) + ``` + """ + + # We're using a forward declaration here to avoid a circular import. + def __init__( + self, + component: "Component", # type: ignore[name-defined] # noqa: F821 + sockets_dict: SocketsDict, + sockets_io_type: SocketsIOType, + ): + """ + Create a new Sockets object. + We don't do any enforcement on the types of the sockets here, the `sockets_type` is only used for + the `__repr__` method. + We could do without it and use the type of a random value in the `sockets` dict, but that wouldn't + work for components that have no sockets at all. Either input or output. + """ + self._sockets_io_type = sockets_io_type + self._component = component + self._sockets_dict = sockets_dict + self.__dict__.update(sockets_dict) + + def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]): + """ + Adds a new socket to this Sockets object. + This eases a bit updating the list of sockets after Sockets has been created. + That should happen only in the `component` decorator. + """ + self._sockets_dict[key] = socket + self.__dict__[key] = socket + + def _component_name(self) -> str: + if pipeline := getattr(self._component, "__haystack_added_to_pipeline__"): + # This Component has been added in a Pipeline, let's get the name from there. + return pipeline.get_component_name(self._component) + + # This Component has not been added to a Pipeline yet, so we can't know its name. + # Let's use the class name instead. + return str(self._component) + + def __getattribute__(self, name): try: - # __metadata__ is a tuple - self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION + sockets = object.__getattribute__(self, "_sockets") + if name in sockets: + return sockets[name] except AttributeError: - self.is_variadic = False - if self.is_variadic: - # We need to "unpack" the type inside the Variadic annotation, - # otherwise the pipeline connection api will try to match - # `Annotated[type, CANALS_VARIADIC_ANNOTATION]`. - # - # Note1: Variadic is expressed as an annotation of one single type, - # so the return value of get_args will always be a one-item tuple. - # - # Note2: a pipeline always passes a list of items when a component - # input is declared as Variadic, so the type itself always wraps - # an iterable of the declared type. For example, Variadic[int] - # is eventually an alias for Iterable[int]. Since we're interested - # in getting the inner type `int`, we call `get_args` twice: the - # first time to get `List[int]` out of `Variadic`, the second time - # to get `int` out of `List[int]`. - self.type = get_args(get_args(self.type)[0])[0] + pass + return object.__getattribute__(self, name) -@dataclass -class OutputSocket: - name: str - type: type - receivers: List[str] = field(default_factory=list) + def __repr__(self) -> str: + result = self._component_name() + if self._sockets_io_type == InputSocket: + result += " inputs:\n" + elif self._sockets_io_type == OutputSocket: + result += " outputs:\n" + + result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()]) + + return result diff --git a/haystack/core/component/types.py b/haystack/core/component/types.py index 252787902..663bccf4f 100644 --- a/haystack/core/component/types.py +++ b/haystack/core/component/types.py @@ -1,4 +1,5 @@ -from typing import Iterable, TypeVar +from dataclasses import dataclass, field +from typing import Any, Iterable, List, Type, TypeVar, get_args from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility @@ -13,3 +14,50 @@ T = TypeVar("T") # type so it can be used in the `InputSocket` creation where we # check that its annotation equals to CANALS_VARIADIC_ANNOTATION Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION] + + +class _empty: + """Custom object for marking InputSocket.default_value as not set.""" + + +@dataclass +class InputSocket: + name: str + type: Type + default_value: Any = _empty + is_variadic: bool = field(init=False) + senders: List[str] = field(default_factory=list) + + @property + def is_mandatory(self): + return self.default_value == _empty + + def __post_init__(self): + try: + # __metadata__ is a tuple + self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION + except AttributeError: + self.is_variadic = False + if self.is_variadic: + # We need to "unpack" the type inside the Variadic annotation, + # otherwise the pipeline connection api will try to match + # `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`. + # + # Note1: Variadic is expressed as an annotation of one single type, + # so the return value of get_args will always be a one-item tuple. + # + # Note2: a pipeline always passes a list of items when a component + # input is declared as Variadic, so the type itself always wraps + # an iterable of the declared type. For example, Variadic[int] + # is eventually an alias for Iterable[int]. Since we're interested + # in getting the inner type `int`, we call `get_args` twice: the + # first time to get `List[int]` out of `Variadic`, the second time + # to get `int` out of `List[int]`. + self.type = get_args(get_args(self.type)[0])[0] + + +@dataclass +class OutputSocket: + name: str + type: type + receivers: List[str] = field(default_factory=list) diff --git a/haystack/core/pipeline/descriptions.py b/haystack/core/pipeline/descriptions.py index 406bf1967..0e7c04209 100644 --- a/haystack/core/pipeline/descriptions.py +++ b/haystack/core/pipeline/descriptions.py @@ -1,14 +1,13 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Dict import logging +from typing import Dict, List import networkx # type:ignore +from haystack.core.component.types import InputSocket, OutputSocket from haystack.core.type_utils import _type_name -from haystack.core.component.sockets import InputSocket, OutputSocket - logger = logging.getLogger(__name__) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index eeecdcea4..a435f5202 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -197,16 +197,17 @@ class Pipeline: ) raise PipelineError(msg) - # Create the component's input and output sockets - input_sockets = getattr(instance, "__haystack_input__", {}) - output_sockets = getattr(instance, "__haystack_output__", {}) - setattr(instance, "__haystack_added_to_pipeline__", self) # Add component to the graph, disconnected logger.debug("Adding component '%s' (%s)", name, instance) + # We're completely sure the fields exist so we ignore the type error self.graph.add_node( - name, instance=instance, input_sockets=input_sockets, output_sockets=output_sockets, visits=0 + name, + instance=instance, + input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined] + output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined] + visits=0, ) def connect(self, connect_from: str, connect_to: str) -> None: @@ -381,6 +382,16 @@ class Pipeline: except KeyError as exc: raise ValueError(f"Component named {name} not found in the pipeline.") from exc + def get_component_name(self, instance: Component) -> str: + """ + Returns the name of a Component instance. If the Component has not been added to this Pipeline, + returns an empty string. + """ + for name, inst in self.graph.nodes(data="instance"): + if inst == instance: + return name + return "" + def inputs(self) -> Dict[str, Dict[str, Any]]: """ Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary @@ -465,16 +476,16 @@ class Pipeline: if component_name not in self.graph.nodes: raise ValueError(f"Component named {component_name} not found in the pipeline.") instance = self.graph.nodes[component_name]["instance"] - for socket_name, socket in instance.__haystack_input__.items(): + for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: raise ValueError(f"Missing input for component {component_name}: {socket_name}") for input_name in component_inputs.keys(): - if input_name not in instance.__haystack_input__: + if input_name not in instance.__haystack_input__._sockets_dict: raise ValueError(f"Input {input_name} not found in component {component_name}.") for component_name in self.graph.nodes: instance = self.graph.nodes[component_name]["instance"] - for socket_name, socket in instance.__haystack_input__.items(): + for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): component_inputs = data.get(component_name, {}) if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: raise ValueError(f"Missing input for component {component_name}: {socket_name}") @@ -518,7 +529,7 @@ class Pipeline: for component_input, input_value in component_inputs.items(): # Handle mutable input data data[component_name][component_input] = copy(input_value) - if instance.__haystack_input__[component_input].is_variadic: + if instance.__haystack_input__._sockets_dict[component_input].is_variadic: # Components that have variadic inputs need to receive lists as input. # We don't want to force the user to always pass lists, so we convert single values to lists here. # If it's already a list we assume the component takes a variadic input of lists, so we @@ -533,12 +544,12 @@ class Pipeline: for node_name in self.graph.nodes: component = self.graph.nodes[node_name]["instance"] - if len(component.__haystack_input__) == 0: + if len(component.__haystack_input__._sockets_dict) == 0: # Component has no input, can run right away to_run.append((node_name, component)) continue - for socket in component.__haystack_input__.values(): + for socket in component.__haystack_input__._sockets_dict.values(): if not socket.senders or socket.is_variadic: # Component has at least one input not connected or is variadic, can run right away. to_run.append((node_name, component)) @@ -561,12 +572,12 @@ class Pipeline: while len(to_run) > 0: name, comp = to_run.pop(0) - if any(socket.is_variadic for socket in comp.__haystack_input__.values()) and not getattr( # type: ignore + if any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) and not getattr( # type: ignore comp, "is_greedy", False ): there_are_non_variadics = False for _, other_comp in to_run: - if not any(socket.is_variadic for socket in other_comp.__haystack_input__.values()): # type: ignore + if not any(socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values()): # type: ignore there_are_non_variadics = True break @@ -575,7 +586,7 @@ class Pipeline: waiting_for_input.append((name, comp)) continue - if name in last_inputs and len(comp.__haystack_input__) == len(last_inputs[name]): # type: ignore + if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore # This component has all the inputs it needs to run res = comp.run(**last_inputs[name]) @@ -649,7 +660,7 @@ class Pipeline: # This is our last resort, if there's no lazy variadic waiting for input # we're stuck for real and we can't make any progress. for name, comp in waiting_for_input: - is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore + is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore if is_variadic and not getattr(comp, "is_greedy", False): break else: @@ -680,14 +691,14 @@ class Pipeline: last_inputs[name] = {} # Lazy variadics must be removed only if there's nothing else to run at this stage - is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__.values()) # type: ignore + is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore if is_variadic and not getattr(comp, "is_greedy", False): there_are_only_lazy_variadics = True for other_name, other_comp in waiting_for_input: if name == other_name: continue there_are_only_lazy_variadics &= any( - socket.is_variadic for socket in other_comp.__haystack_input__.values() # type: ignore + socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore ) and not getattr(other_comp, "is_greedy", False) if not there_are_only_lazy_variadics: @@ -695,7 +706,7 @@ class Pipeline: # Find the first component that has all the inputs it needs to run has_enough_inputs = True - for input_socket in comp.__haystack_input__.values(): # type: ignore + for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore if input_socket.is_mandatory and input_socket.name not in last_inputs[name]: has_enough_inputs = False break diff --git a/haystack/testing/sample_components/repeat.py b/haystack/testing/sample_components/repeat.py index 73e25097a..a1f628798 100644 --- a/haystack/testing/sample_components/repeat.py +++ b/haystack/testing/sample_components/repeat.py @@ -9,11 +9,11 @@ from haystack.core.component import component @component class Repeat: def __init__(self, outputs: List[str]): - self.outputs = outputs + self._outputs = outputs component.set_output_types(self, **{k: int for k in outputs}) def run(self, value: int): """ :param value: the value to repeat. """ - return {val: value for val in self.outputs} + return {val: value for val in self._outputs} diff --git a/test/components/builders/test_dynamic_chat_prompt_builder.py b/test/components/builders/test_dynamic_chat_prompt_builder.py index 67183d028..d75e0051c 100644 --- a/test/components/builders/test_dynamic_chat_prompt_builder.py +++ b/test/components/builders/test_dynamic_chat_prompt_builder.py @@ -15,16 +15,16 @@ class TestDynamicChatPromptBuilder: # we have inputs that contain: prompt_source, template_variables + runtime_variables expected_keys = set(runtime_variables + ["prompt_source", "template_variables"]) - assert set(builder.__haystack_input__.keys()) == expected_keys + assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys # response is always prompt regardless of chat mode - assert set(builder.__haystack_output__.keys()) == {"prompt"} + assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"} # prompt_source is a list of ChatMessage - assert builder.__haystack_input__["prompt_source"].type == List[ChatMessage] + assert builder.__haystack_input__._sockets_dict["prompt_source"].type == List[ChatMessage] # output is always prompt, but the type is different depending on the chat mode - assert builder.__haystack_output__["prompt"].type == List[ChatMessage] + assert builder.__haystack_output__._sockets_dict["prompt"].type == List[ChatMessage] def test_non_empty_chat_messages(self): prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"]) diff --git a/test/components/builders/test_dynamic_prompt_builder.py b/test/components/builders/test_dynamic_prompt_builder.py index c3508b255..7afacd2ca 100644 --- a/test/components/builders/test_dynamic_prompt_builder.py +++ b/test/components/builders/test_dynamic_prompt_builder.py @@ -16,16 +16,16 @@ class TestDynamicPromptBuilder: # regardless of the chat mode # we have inputs that contain: prompt_source, template_variables + runtime_variables expected_keys = set(runtime_variables + ["prompt_source", "template_variables"]) - assert set(builder.__haystack_input__.keys()) == expected_keys + assert set(builder.__haystack_input__._sockets_dict.keys()) == expected_keys # response is always prompt regardless of chat mode - assert set(builder.__haystack_output__.keys()) == {"prompt"} + assert set(builder.__haystack_output__._sockets_dict.keys()) == {"prompt"} # prompt_source is a list of ChatMessage or a string - assert builder.__haystack_input__["prompt_source"].type == str + assert builder.__haystack_input__._sockets_dict["prompt_source"].type == str # output is always prompt, but the type is different depending on the chat mode - assert builder.__haystack_output__["prompt"].type == str + assert builder.__haystack_output__._sockets_dict["prompt"].type == str def test_processing_a_simple_template_with_provided_variables(self): runtime_variables = ["var1", "var2", "var3"] diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 726d7cd4c..dcb52cff1 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -86,8 +86,8 @@ class TestRouter: router = ConditionalRouter(routes) assert router.routes == routes - assert set(router.__haystack_input__.keys()) == {"query", "streams"} - assert set(router.__haystack_output__.keys()) == {"query", "streams"} + assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"} + assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"} def test_router_evaluate_condition_expressions(self, router): # first route should be selected diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index a60ea6bae..bbe2605f0 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -1,5 +1,4 @@ -import typing -from typing import Any, Optional +from typing import Any import pytest @@ -89,6 +88,7 @@ def test_missing_run(): def test_set_input_types(): + @component class MockComponent: def __init__(self): component.set_input_types(self, value=Any) @@ -105,7 +105,7 @@ def test_set_input_types(): return {"value": 1} comp = MockComponent() - assert comp.__haystack_input__ == {"value": InputSocket("value", Any)} + assert comp.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)} assert comp.run() == {"value": 1} @@ -126,7 +126,7 @@ def test_set_output_types(): return {"value": 1} comp = MockComponent() - assert comp.__haystack_output__ == {"value": OutputSocket("value", int)} + assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)} def test_output_types_decorator_with_compatible_type(): @@ -144,7 +144,7 @@ def test_output_types_decorator_with_compatible_type(): return cls() comp = MockComponent() - assert comp.__haystack_output__ == {"value": OutputSocket("value", int)} + assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)} def test_component_decorator_set_it_as_component(): @@ -173,8 +173,8 @@ def test_input_has_default_value(): return {"value": value} comp = MockComponent() - assert comp.__haystack_input__["value"].default_value == 42 - assert not comp.__haystack_input__["value"].is_mandatory + assert comp.__haystack_input__._sockets_dict["value"].default_value == 42 + assert not comp.__haystack_input__._sockets_dict["value"].is_mandatory def test_keyword_only_args(): @@ -187,5 +187,5 @@ def test_keyword_only_args(): return {"value": arg} comp = MockComponent() - component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__.items()} + component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()} assert component_inputs == {"arg": {"type": int}} diff --git a/test/core/component/test_sockets.py b/test/core/component/test_sockets.py new file mode 100644 index 000000000..ac3b01bda --- /dev/null +++ b/test/core/component/test_sockets.py @@ -0,0 +1,57 @@ +import pytest + +from haystack.core.component.sockets import InputSocket, Sockets +from haystack.core.pipeline import Pipeline +from haystack.testing.factory import component_class + + +class TestSockets: + def test_init(self): + comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})() + sockets = {"input_1": InputSocket("input_1", int), "input_2": InputSocket("input_2", int)} + io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket) + assert io._component == comp + assert "input_1" in io.__dict__ + assert io.__dict__["input_1"] == comp.__haystack_input__._sockets_dict["input_1"] + assert "input_2" in io.__dict__ + assert io.__dict__["input_2"] == comp.__haystack_input__._sockets_dict["input_2"] + + def test_init_with_empty_sockets(self): + comp = component_class("SomeComponent")() + io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket) + + assert io._component == comp + assert io._sockets_dict == {} + + def test_component_name(self): + comp = component_class("SomeComponent")() + io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket) + assert io._component_name() == str(comp) + + def test_component_name_added_to_pipeline(self): + comp = component_class("SomeComponent")() + pipeline = Pipeline() + pipeline.add_component("my_component", comp) + + io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket) + assert io._component_name() == "my_component" + + def test_getattribute(self): + comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})() + io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket) + + assert io.input_1 == comp.__haystack_input__._sockets_dict["input_1"] + assert io.input_2 == comp.__haystack_input__._sockets_dict["input_2"] + + def test_getattribute_non_existing_socket(self): + comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})() + io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket) + + with pytest.raises(AttributeError): + io.input_3 + + def test_repr(self): + comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})() + io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket) + res = repr(io) + assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int" diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 5930f5e57..45f1c883d 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -6,7 +6,7 @@ from typing import Optional import pytest -from haystack.core.component.sockets import InputSocket, OutputSocket +from haystack.core.component.types import InputSocket, OutputSocket from haystack.core.errors import PipelineError, PipelineRuntimeError from haystack.core.pipeline import Pipeline from haystack.testing.factory import component_class @@ -28,6 +28,21 @@ def test_add_component_to_different_pipelines(): second_pipe.add_component("some", some_component) +def test_get_component_name(): + pipe = Pipeline() + some_component = component_class("Some")() + pipe.add_component("some", some_component) + + assert pipe.get_component_name(some_component) == "some" + + +def test_get_component_name_not_added_to_pipeline(): + pipe = Pipeline() + some_component = component_class("Some")() + + assert pipe.get_component_name(some_component) == "" + + def test_run_with_component_that_does_not_return_dict(): BrokenComponent = component_class( "BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore diff --git a/test/core/pipeline/test_validation_pipeline_io.py b/test/core/pipeline/test_validation_pipeline_io.py index 47fb4c592..f9160799f 100644 --- a/test/core/pipeline/test_validation_pipeline_io.py +++ b/test/core/pipeline/test_validation_pipeline_io.py @@ -5,8 +5,7 @@ from typing import Optional import pytest -from haystack.core.component.sockets import InputSocket, OutputSocket -from haystack.core.component.types import Variadic +from haystack.core.component.types import InputSocket, OutputSocket, Variadic from haystack.core.errors import PipelineValidationError from haystack.core.pipeline import Pipeline from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs