From 7fe0244258fefbbe9f4c8c2f1bdb4c9419b9f3d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Fern=C3=A1ndez?= <67836662+CarlosFerLo@users.noreply.github.com> Date: Mon, 10 Jun 2024 14:54:07 +0200 Subject: [PATCH] feat: add methods to remove and replace components in a pipeline (#7820) * add remove_component method plus unit tests * add docstrings * add reno * add type annotation to remove_component method Co-authored-by: Madeesh Kannan * solve bug not allowing a component to be reatached to a pipeline after being removed * Properly remove Component from Pipeline * Ignore mypy --------- Co-authored-by: Madeesh Kannan Co-authored-by: Silvano Cerza --- haystack/core/pipeline/base.py | 42 ++++++++++++ ...hod-to-pipeline-base-c7ca1aa68b0f396b.yaml | 4 ++ test/core/pipeline/test_pipeline.py | 65 +++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 releasenotes/notes/add-remove-method-to-pipeline-base-c7ca1aa68b0f396b.yaml diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index e3c139b6f..d782e204d 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -308,6 +308,48 @@ class PipelineBase: visits=0, ) + def remove_component(self, name: str) -> Component: + """ + Remove and returns component from the pipeline. + + Remove an existing component from the pipeline by providing its name. + All edges that connect to the component will also be deleted. + + :param name: + The name of the component to remove. + :returns: + The removed Component instance. + + :raises ValueError: + If there is no component with that name already in the Pipeline. + """ + + # Check that a component with that name is in the Pipeline + try: + instance = self.get_component(name) + except ValueError as exc: + raise ValueError( + f"There is no component named '{name}' in the pipeline. The valid component names are: ", + ", ".join(n for n in self.graph.nodes), + ) from exc + + # Delete component from the graph, deleting all its connections + self.graph.remove_node(name) + + # Reset the Component sockets' senders and receivers + input_sockets = instance.__haystack_input__._sockets_dict # type: ignore[attr-defined] + for socket in input_sockets.values(): + socket.senders = [] + + output_sockets = instance.__haystack_output__._sockets_dict # type: ignore[attr-defined] + for socket in output_sockets.values(): + socket.receivers = [] + + # Reset the Component's pipeline reference + setattr(instance, "__haystack_added_to_pipeline__", None) + + return instance + def connect(self, sender: str, receiver: str) -> "PipelineBase": """ Connects two components together. diff --git a/releasenotes/notes/add-remove-method-to-pipeline-base-c7ca1aa68b0f396b.yaml b/releasenotes/notes/add-remove-method-to-pipeline-base-c7ca1aa68b0f396b.yaml new file mode 100644 index 000000000..bfc1dc442 --- /dev/null +++ b/releasenotes/notes/add-remove-method-to-pipeline-base-c7ca1aa68b0f396b.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added the 'remove_component' method in 'PipelineBase' to delete components and its connections. diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index b4c550a71..5433e4e4d 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -119,6 +119,71 @@ class TestPipeline: with pytest.raises(PipelineError): second_pipe.add_component("some", some_component) + def test_remove_component_raises_if_invalid_component_name(self): + pipe = Pipeline() + component = component_class("Some")() + + pipe.add_component("1", component) + + with pytest.raises(ValueError): + pipe.remove_component("2") + + def test_remove_component_removes_component_and_its_edges(self): + pipe = Pipeline() + component_1 = component_class("Type1")() + component_2 = component_class("Type2")() + component_3 = component_class("Type3")() + component_4 = component_class("Type4")() + + pipe.add_component("1", component_1) + pipe.add_component("2", component_2) + pipe.add_component("3", component_3) + pipe.add_component("4", component_4) + + pipe.connect("1", "2") + pipe.connect("2", "3") + pipe.connect("3", "4") + + pipe.remove_component("2") + + assert ["1", "3", "4"] == sorted(pipe.graph.nodes) + assert [("3", "4")] == sorted([(u, v) for (u, v) in pipe.graph.edges()]) + + def test_remove_component_allows_you_to_reuse_the_component(self): + pipe = Pipeline() + Some = component_class("Some", input_types={"in": int}, output_types={"out": int}) + + pipe.add_component("component_1", Some()) + pipe.add_component("component_2", Some()) + pipe.add_component("component_3", Some()) + pipe.connect("component_1", "component_2") + pipe.connect("component_2", "component_3") + component_2 = pipe.remove_component("component_2") + + assert component_2.__haystack_added_to_pipeline__ is None + assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])} + assert component_2.__haystack_output__._sockets_dict == { + "out": OutputSocket(name="out", type=int, receivers=[]) + } + + pipe2 = Pipeline() + pipe2.add_component("component_4", Some()) + pipe2.add_component("component_2", component_2) + pipe2.add_component("component_5", Some()) + + pipe2.connect("component_4", "component_2") + pipe2.connect("component_2", "component_5") + assert component_2.__haystack_added_to_pipeline__ is pipe2 + assert component_2.__haystack_input__._sockets_dict == { + "in": InputSocket(name="in", type=int, senders=["component_4"]) + } + assert component_2.__haystack_output__._sockets_dict == { + "out": OutputSocket(name="out", type=int, receivers=["component_5"]) + } + + # instance = pipe2.get_component("some") + # assert instance == component + # UNIT def test_get_component_name(self): pipe = Pipeline()