Add support for positional args in pipeline.get_config() (#2478)

* add support for positional args in pipeline.get_config()

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
tstadel 2022-05-02 14:41:07 +02:00 committed by GitHub
parent 7d6b3fe954
commit 509944f47d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 13 deletions

View File

@ -27,11 +27,6 @@ def exportable_to_yaml(init_func):
# Call the actuall __init__ function with all the arguments # Call the actuall __init__ function with all the arguments
init_func(self, *args, **kwargs) init_func(self, *args, **kwargs)
# Warn for unnamed input params - should be rare
if args:
logger.warning(
"Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!"
)
# Create the configuration dictionary if it doesn't exist yet # Create the configuration dictionary if it doesn't exist yet
if not self._component_config: if not self._component_config:
self._component_config = {"params": {}, "type": type(self).__name__} self._component_config = {"params": {}, "type": type(self).__name__}
@ -46,6 +41,14 @@ def exportable_to_yaml(init_func):
for k, v in kwargs.items(): for k, v in kwargs.items():
self._component_config["params"][k] = v self._component_config["params"][k] = v
# Store unnamed input parameters in self._component_config too by inferring their names
sig = inspect.signature(init_func)
parameter_names = list(sig.parameters.keys())
# we can be sure that the first one is always "self"
arg_names = parameter_names[1 : 1 + len(args)]
for arg, arg_name in zip(args, arg_names):
self._component_config["params"][arg_name] = arg
return wrapper_exportable_to_yaml return wrapper_exportable_to_yaml

View File

@ -386,21 +386,17 @@ def test_get_config_custom_node_with_params():
assert pipeline.get_config()["components"][0]["params"] == {"param": 10} assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
def test_get_config_custom_node_with_positional_params(caplog): def test_get_config_custom_node_with_positional_params():
class CustomNode(MockNode): class CustomNode(MockNode):
def __init__(self, param: int = 1): def __init__(self, param: int = 1):
super().__init__() super().__init__()
self.param = param self.param = param
pipeline = Pipeline() pipeline = Pipeline()
with caplog.at_level(logging.WARNING):
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"]) pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
assert (
"Unnamed __init__ parameters will not be saved to YAML "
"if Pipeline.save_to_yaml() is called" in caplog.text
)
assert len(pipeline.get_config()["components"]) == 1 assert len(pipeline.get_config()["components"]) == 1
assert pipeline.get_config()["components"][0]["params"] == {} assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
def test_generate_code_simple_pipeline(): def test_generate_code_simple_pipeline():