fix: small improvement to pipeline v2 tests (#5153)

* add missing return

* improve test

* docstring
This commit is contained in:
ZanSara 2023-06-16 12:07:00 +02:00 committed by GitHub
parent 23a22be03c
commit f52477d31b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 4 deletions

View File

@ -2,7 +2,7 @@ from typing import List, Dict, Any, Optional, Callable
from pathlib import Path from pathlib import Path
from canals.component import ComponentInput from canals.component import ComponentInput, ComponentOutput
from canals.pipeline import ( from canals.pipeline import (
Pipeline as CanalsPipeline, Pipeline as CanalsPipeline,
PipelineError, PipelineError,
@ -55,9 +55,14 @@ class Pipeline(CanalsPipeline):
except KeyError as e: except KeyError as e:
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
def run(self, data: Dict[str, ComponentInput], debug: bool = False): def run(self, data: Dict[str, ComponentInput], debug: bool = False) -> Dict[str, ComponentOutput]:
""" """
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes. Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
:params data: the inputs to give to the input components of the Pipeline.
:params parameters: a dictionary with all the parameters of all the components, namespaced by component.
:params debug: whether to collect and return debug information.
:returns A dictionary with the outputs of the output components of the Pipeline.
""" """
# Get all nodes in this pipelines instance # Get all nodes in this pipelines instance
for node_name in self.graph.nodes: for node_name in self.graph.nodes:
@ -72,7 +77,7 @@ class Pipeline(CanalsPipeline):
node.defaults["stores"] = self.stores node.defaults["stores"] = self.stores
# Run the pipeline # Run the pipeline
super().run(data=data, debug=debug) return super().run(data=data, debug=debug)
def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None): def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):

View File

@ -53,4 +53,4 @@ def test_pipeline_stores_in_params():
pipe.add_store(name="first_store", store=store_1) pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2) pipe.add_store(name="second_store", store=store_2)
pipe.run(data={"component": MockComponent.Input(value=1)}) assert pipe.run(data={"component": MockComponent.Input(value=1)}) == {"component": MockComponent.Output(value=1)}