Update preview Pipelines following Canals changes (#4821)

This commit is contained in:
Silvano Cerza 2023-05-05 19:47:32 +02:00 committed by GitHub
parent 43509c88bf
commit 705a2c025f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 24 deletions

View File

@ -1,4 +1,5 @@
from typing import List, Dict, Union, Any, Tuple, Optional, Callable
import inspect
from pathlib import Path
@ -53,25 +54,23 @@ class Pipeline(CanalsPipeline):
except KeyError as e:
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
def run(
self,
data: Union[Dict[str, Any], List[Tuple[str, Any]]],
parameters: Optional[Dict[str, Dict[str, Any]]] = None,
debug: bool = False,
):
def run(self, data: Union[Dict[str, Any], List[Tuple[str, Any]]], debug: bool = False):
"""
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
"""
if not parameters:
parameters = {}
for node in self.graph.nodes:
if not node in parameters.keys():
parameters[node] = {"stores": self.stores}
else:
parameters[node] = {"stores": self.stores, **parameters[node]}
# Get all nodes in this pipelines instance
for node_name in self.graph.nodes:
node = self.graph.nodes[node_name]["instance"]
# Get node inputs
input_params = inspect.signature(node.run).parameters
# If the node needs a store adds the list of stores to its default inputs
if "stores" in input_params:
if not hasattr(node, "defaults"):
setattr(node, "defaults", {})
node.defaults["stores"] = self.stores
super().run(data=data, parameters=parameters, debug=debug)
super().run(data=data, debug=debug)
def load_pipelines(path: Path, _reader: Optional[Callable[..., Any]] = None):

View File

@ -1,4 +1,5 @@
from typing import Dict, Any
from dataclasses import dataclass
import pytest
@ -33,16 +34,13 @@ def test_pipeline_stores_in_params():
@component
class MockComponent:
def __init__(self):
self.inputs = ["value"]
self.outputs = ["value"]
self.init_parameters = {}
@dataclass
class Output:
value: int
def run(self, name: str, data: Dict[str, Any], parameters: Dict[str, Dict[str, Any]]):
assert name in parameters.keys()
assert "stores" in parameters[name].keys()
assert parameters[name]["stores"] == {"first_store": store_1, "second_store": store_2}
return ({"value": None}, parameters or {})
def run(self, value: int, stores: Dict[str, Any]) -> Output:
assert stores == {"first_store": store_1, "second_store": store_2}
return MockComponent.Output(value=value)
pipe = Pipeline()
pipe.add_component("component", MockComponent())
@ -50,4 +48,4 @@ def test_pipeline_stores_in_params():
pipe.add_store(name="first_store", store=store_1)
pipe.add_store(name="second_store", store=store_2)
pipe.run(data={"value": None})
pipe.run(data={"component": {"value": None}})