feat: add methods to remove and replace components in a pipeline (#7820)

* add remove_component method plus unit tests

* add docstrings

* add reno

* add type annotation to remove_component method

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* solve bug not allowing a component to be reatached to a pipeline after being removed

* Properly remove Component from Pipeline

* Ignore mypy

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
Carlos Fernández 2024-06-10 14:54:07 +02:00 committed by GitHub
parent 639ee598fd
commit 7fe0244258
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 111 additions and 0 deletions

View File

@ -308,6 +308,48 @@ class PipelineBase:
visits=0,
)
def remove_component(self, name: str) -> Component:
"""
Remove and returns component from the pipeline.
Remove an existing component from the pipeline by providing its name.
All edges that connect to the component will also be deleted.
:param name:
The name of the component to remove.
:returns:
The removed Component instance.
:raises ValueError:
If there is no component with that name already in the Pipeline.
"""
# Check that a component with that name is in the Pipeline
try:
instance = self.get_component(name)
except ValueError as exc:
raise ValueError(
f"There is no component named '{name}' in the pipeline. The valid component names are: ",
", ".join(n for n in self.graph.nodes),
) from exc
# Delete component from the graph, deleting all its connections
self.graph.remove_node(name)
# Reset the Component sockets' senders and receivers
input_sockets = instance.__haystack_input__._sockets_dict # type: ignore[attr-defined]
for socket in input_sockets.values():
socket.senders = []
output_sockets = instance.__haystack_output__._sockets_dict # type: ignore[attr-defined]
for socket in output_sockets.values():
socket.receivers = []
# Reset the Component's pipeline reference
setattr(instance, "__haystack_added_to_pipeline__", None)
return instance
def connect(self, sender: str, receiver: str) -> "PipelineBase":
"""
Connects two components together.

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Added the 'remove_component' method in 'PipelineBase' to delete components and its connections.

View File

@ -119,6 +119,71 @@ class TestPipeline:
with pytest.raises(PipelineError):
second_pipe.add_component("some", some_component)
def test_remove_component_raises_if_invalid_component_name(self):
pipe = Pipeline()
component = component_class("Some")()
pipe.add_component("1", component)
with pytest.raises(ValueError):
pipe.remove_component("2")
def test_remove_component_removes_component_and_its_edges(self):
pipe = Pipeline()
component_1 = component_class("Type1")()
component_2 = component_class("Type2")()
component_3 = component_class("Type3")()
component_4 = component_class("Type4")()
pipe.add_component("1", component_1)
pipe.add_component("2", component_2)
pipe.add_component("3", component_3)
pipe.add_component("4", component_4)
pipe.connect("1", "2")
pipe.connect("2", "3")
pipe.connect("3", "4")
pipe.remove_component("2")
assert ["1", "3", "4"] == sorted(pipe.graph.nodes)
assert [("3", "4")] == sorted([(u, v) for (u, v) in pipe.graph.edges()])
def test_remove_component_allows_you_to_reuse_the_component(self):
pipe = Pipeline()
Some = component_class("Some", input_types={"in": int}, output_types={"out": int})
pipe.add_component("component_1", Some())
pipe.add_component("component_2", Some())
pipe.add_component("component_3", Some())
pipe.connect("component_1", "component_2")
pipe.connect("component_2", "component_3")
component_2 = pipe.remove_component("component_2")
assert component_2.__haystack_added_to_pipeline__ is None
assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])}
assert component_2.__haystack_output__._sockets_dict == {
"out": OutputSocket(name="out", type=int, receivers=[])
}
pipe2 = Pipeline()
pipe2.add_component("component_4", Some())
pipe2.add_component("component_2", component_2)
pipe2.add_component("component_5", Some())
pipe2.connect("component_4", "component_2")
pipe2.connect("component_2", "component_5")
assert component_2.__haystack_added_to_pipeline__ is pipe2
assert component_2.__haystack_input__._sockets_dict == {
"in": InputSocket(name="in", type=int, senders=["component_4"])
}
assert component_2.__haystack_output__._sockets_dict == {
"out": OutputSocket(name="out", type=int, receivers=["component_5"])
}
# instance = pipe2.get_component("some")
# assert instance == component
# UNIT
def test_get_component_name(self):
pipe = Pipeline()