From 509944f47da103ac2cb09cf4f3bcf6552ba35261 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Mon, 2 May 2022 14:41:07 +0200 Subject: [PATCH] Add support for positional args in pipeline.get_config() (#2478) * add support for positional args in pipeline.get_config() * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- haystack/nodes/base.py | 13 ++++++++----- test/test_pipeline.py | 12 ++++-------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/haystack/nodes/base.py b/haystack/nodes/base.py index 636c3e5d5..d2418afe4 100644 --- a/haystack/nodes/base.py +++ b/haystack/nodes/base.py @@ -27,11 +27,6 @@ def exportable_to_yaml(init_func): # Call the actuall __init__ function with all the arguments init_func(self, *args, **kwargs) - # Warn for unnamed input params - should be rare - if args: - logger.warning( - "Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!" - ) # Create the configuration dictionary if it doesn't exist yet if not self._component_config: self._component_config = {"params": {}, "type": type(self).__name__} @@ -46,6 +41,14 @@ def exportable_to_yaml(init_func): for k, v in kwargs.items(): self._component_config["params"][k] = v + # Store unnamed input parameters in self._component_config too by inferring their names + sig = inspect.signature(init_func) + parameter_names = list(sig.parameters.keys()) + # we can be sure that the first one is always "self" + arg_names = parameter_names[1 : 1 + len(args)] + for arg, arg_name in zip(args, arg_names): + self._component_config["params"][arg_name] = arg + return wrapper_exportable_to_yaml diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 6292b4e90..483eb4977 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -386,21 +386,17 @@ def test_get_config_custom_node_with_params(): assert pipeline.get_config()["components"][0]["params"] == {"param": 10} -def test_get_config_custom_node_with_positional_params(caplog): +def test_get_config_custom_node_with_positional_params(): class CustomNode(MockNode): def __init__(self, param: int = 1): super().__init__() self.param = param pipeline = Pipeline() - with caplog.at_level(logging.WARNING): - pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"]) - assert ( - "Unnamed __init__ parameters will not be saved to YAML " - "if Pipeline.save_to_yaml() is called" in caplog.text - ) + pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"]) + assert len(pipeline.get_config()["components"]) == 1 - assert pipeline.get_config()["components"][0]["params"] == {} + assert pipeline.get_config()["components"][0]["params"] == {"param": 10} def test_generate_code_simple_pipeline():