feat: Add __repr__ method to all Components (#6927)

* Add __repr__ to show Component I/O

* Add release notes

* Change Component repr to show full module path and name in Pipeline

* Fix linting
This commit is contained in:
Silvano Cerza 2024-02-08 11:46:10 +01:00 committed by GitHub
parent 74683fe74d
commit 2f965fb176
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 61 additions and 22 deletions

View File

@ -160,6 +160,21 @@ class ComponentMeta(type):
return instance
def _component_repr(component: Component) -> str:
"""
All Components override their __repr__ method with this one.
It prints the component name and the input/output sockets.
"""
result = object.__repr__(component)
if pipeline := getattr(component, "__haystack_added_to_pipeline__"):
# This Component has been added in a Pipeline, let's get the name from there.
result += f"\n{pipeline.get_component_name(component)}"
# We're explicitly ignoring the type here because we're sure that the component
# has the __haystack_input__ and __haystack_output__ attributes at this point
return f"{result}\n{component.__haystack_input__}\n{component.__haystack_output__}" # type: ignore[attr-defined]
class _Component:
"""
See module's docstring.
@ -332,6 +347,9 @@ class _Component:
self.registry[class_path] = class_
logger.debug("Registered Component %s", class_)
# Override the __repr__ method with a default one
class_.__repr__ = _component_repr
return class_
def __call__(self, class_):

View File

@ -82,8 +82,9 @@ class Sockets:
return pipeline.get_component_name(self._component)
# This Component has not been added to a Pipeline yet, so we can't know its name.
# Let's use the class name instead.
return str(self._component)
# Let's use default __repr__. We don't call repr() directly as Components have a custom
# __repr__ method and that would lead to infinite recursion since we call Sockets.__repr__ in it.
return object.__repr__(self._component)
def __getattribute__(self, name):
try:
@ -96,12 +97,10 @@ class Sockets:
return object.__getattribute__(self, name)
def __repr__(self) -> str:
result = self._component_name()
result = ""
if self._sockets_io_type == InputSocket:
result += " inputs:\n"
result = "Inputs:\n"
elif self._sockets_io_type == OutputSocket:
result += " outputs:\n"
result = "Outputs:\n"
result += "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])
return result
return result + "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
Add `__repr__` to all Components to print their I/O.
This can also be useful in Jupyter notebooks as this will be shown as a cell output
if the it's the last expression in a cell.

View File

@ -4,6 +4,7 @@ import pytest
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import ComponentError
from haystack.core.pipeline import Pipeline
def test_correct_declaration():
@ -189,3 +190,31 @@ def test_keyword_only_args():
comp = MockComponent()
component_inputs = {name: {"type": socket.type} for name, socket in comp.__haystack_input__._sockets_dict.items()}
assert component_inputs == {"arg": {"type": int}}
def test_repr():
@component
class MockComponent:
def __init__(self):
component.set_output_types(self, value=int)
def run(self, value: int):
return {"value": value}
comp = MockComponent()
assert repr(comp) == f"{object.__repr__(comp)}\nInputs:\n - value: int\nOutputs:\n - value: int"
def test_repr_added_to_pipeline():
@component
class MockComponent:
def __init__(self):
component.set_output_types(self, value=int)
def run(self, value: int):
return {"value": value}
pipe = Pipeline()
comp = MockComponent()
pipe.add_component("my_component", comp)
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"

View File

@ -23,19 +23,6 @@ class TestSockets:
assert io._component == comp
assert io._sockets_dict == {}
def test_component_name(self):
comp = component_class("SomeComponent")()
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component_name() == str(comp)
def test_component_name_added_to_pipeline(self):
comp = component_class("SomeComponent")()
pipeline = Pipeline()
pipeline.add_component("my_component", comp)
io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)
assert io._component_name() == "my_component"
def test_getattribute(self):
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
@ -54,4 +41,4 @@ class TestSockets:
comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
io = Sockets(component=comp, sockets_dict=comp.__haystack_input__._sockets_dict, sockets_io_type=InputSocket)
res = repr(io)
assert res == f"{comp} inputs:\n - input_1: int\n - input_2: int"
assert res == "Inputs:\n - input_1: int\n - input_2: int"