diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index 95ae87d73..c231e41df 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -160,6 +160,21 @@ class ComponentMeta(type): return instance +def _component_repr(component: Component) -> str: + """ + All Components override their __repr__ method with this one. + It prints the component name and the input/output sockets. + """ + result = object.__repr__(component) + if pipeline := getattr(component, "__haystack_added_to_pipeline__"): + # This Component has been added in a Pipeline, let's get the name from there. + result += f"\n{pipeline.get_component_name(component)}" + + # We're explicitly ignoring the type here because we're sure that the component + # has the __haystack_input__ and __haystack_output__ attributes at this point + return f"{result}\n{component.__haystack_input__}\n{component.__haystack_output__}" # type: ignore[attr-defined] + + class _Component: """ See module's docstring. @@ -332,6 +347,9 @@ class _Component: self.registry[class_path] = class_ logger.debug("Registered Component %s", class_) + # Override the __repr__ method with a default one + class_.__repr__ = _component_repr + return class_ def __call__(self, class_): diff --git a/haystack/core/component/sockets.py b/haystack/core/component/sockets.py index 25bf4fdc8..374ae6303 100644 --- a/haystack/core/component/sockets.py +++ b/haystack/core/component/sockets.py @@ -82,8 +82,9 @@ class Sockets: return pipeline.get_component_name(self._component) # This Component has not been added to a Pipeline yet, so we can't know its name. - # Let's use the class name instead. - return str(self._component) + # Let's use default __repr__. We don't call repr() directly as Components have a custom + # __repr__ method and that would lead to infinite recursion since we call Sockets.__repr__ in it. + return object.__repr__(self._component) def __getattribute__(self, name): try: @@ -96,12 +97,10 @@ class Sockets: return object.__getattribute__(self, name) def __repr__(self) -> str: - result = self._component_name() + result = "" if self._sockets_io_type == InputSocket: - result += " inputs:\n" + result = "Inputs:\n" elif self._sockets_io_type == OutputSocket: - result += " outputs:\n" + result = "Outputs:\n" - result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()]) - - return result + return result + "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()]) diff --git a/releasenotes/notes/component-repr-a6486af81530bc3b.yaml b/releasenotes/notes/component-repr-a6486af81530bc3b.yaml new file mode 100644 index 000000000..3a7439e92 --- /dev/null +++ b/releasenotes/notes/component-repr-a6486af81530bc3b.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Add `__repr__` to all Components to print their I/O. + This can also be useful in Jupyter notebooks as this will be shown as a cell output + if the it's the last expression in a cell. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index bbe2605f0..b093c32b8 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -4,6 +4,7 @@ import pytest from haystack.core.component import Component, InputSocket, OutputSocket, component from haystack.core.errors import ComponentError +from haystack.core.pipeline import Pipeline def test_correct_declaration(): @@ -189,3 +190,31 @@ def test_keyword_only_args(): comp = MockComponent() component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()} assert component_inputs == {"arg": {"type": int}} + + +def test_repr(): + @component + class MockComponent: + def __init__(self): + component.set_output_types(self, value=int) + + def run(self, value: int): + return {"value": value} + + comp = MockComponent() + assert repr(comp) == f"{object.__repr__(comp)}\nInputs:\n - value: int\nOutputs:\n - value: int" + + +def test_repr_added_to_pipeline(): + @component + class MockComponent: + def __init__(self): + component.set_output_types(self, value=int) + + def run(self, value: int): + return {"value": value} + + pipe = Pipeline() + comp = MockComponent() + pipe.add_component("my_component", comp) + assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int" diff --git a/test/core/component/test_sockets.py b/test/core/component/test_sockets.py index ac3b01bda..6e942b84f 100644 --- a/test/core/component/test_sockets.py +++ b/test/core/component/test_sockets.py @@ -23,19 +23,6 @@ class TestSockets: assert io._component == comp assert io._sockets_dict == {} - def test_component_name(self): - comp = component_class("SomeComponent")() - io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket) - assert io._component_name() == str(comp) - - def test_component_name_added_to_pipeline(self): - comp = component_class("SomeComponent")() - pipeline = Pipeline() - pipeline.add_component("my_component", comp) - - io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket) - assert io._component_name() == "my_component" - def test_getattribute(self): comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})() io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket) @@ -54,4 +41,4 @@ class TestSockets: comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})() io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket) res = repr(io) - assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int" + assert res == "Inputs:\n - input_1: int\n - input_2: int"