mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 18:06:17 +00:00
Add support for positional args in pipeline.get_config() (#2478)
* add support for positional args in pipeline.get_config() * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
7d6b3fe954
commit
509944f47d
@ -27,11 +27,6 @@ def exportable_to_yaml(init_func):
|
|||||||
# Call the actuall __init__ function with all the arguments
|
# Call the actuall __init__ function with all the arguments
|
||||||
init_func(self, *args, **kwargs)
|
init_func(self, *args, **kwargs)
|
||||||
|
|
||||||
# Warn for unnamed input params - should be rare
|
|
||||||
if args:
|
|
||||||
logger.warning(
|
|
||||||
"Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!"
|
|
||||||
)
|
|
||||||
# Create the configuration dictionary if it doesn't exist yet
|
# Create the configuration dictionary if it doesn't exist yet
|
||||||
if not self._component_config:
|
if not self._component_config:
|
||||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||||
@ -46,6 +41,14 @@ def exportable_to_yaml(init_func):
|
|||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
self._component_config["params"][k] = v
|
self._component_config["params"][k] = v
|
||||||
|
|
||||||
|
# Store unnamed input parameters in self._component_config too by inferring their names
|
||||||
|
sig = inspect.signature(init_func)
|
||||||
|
parameter_names = list(sig.parameters.keys())
|
||||||
|
# we can be sure that the first one is always "self"
|
||||||
|
arg_names = parameter_names[1 : 1 + len(args)]
|
||||||
|
for arg, arg_name in zip(args, arg_names):
|
||||||
|
self._component_config["params"][arg_name] = arg
|
||||||
|
|
||||||
return wrapper_exportable_to_yaml
|
return wrapper_exportable_to_yaml
|
||||||
|
|
||||||
|
|
||||||
|
@ -386,21 +386,17 @@ def test_get_config_custom_node_with_params():
|
|||||||
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
||||||
|
|
||||||
|
|
||||||
def test_get_config_custom_node_with_positional_params(caplog):
|
def test_get_config_custom_node_with_positional_params():
|
||||||
class CustomNode(MockNode):
|
class CustomNode(MockNode):
|
||||||
def __init__(self, param: int = 1):
|
def __init__(self, param: int = 1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.param = param
|
self.param = param
|
||||||
|
|
||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
with caplog.at_level(logging.WARNING):
|
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
|
||||||
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
|
|
||||||
assert (
|
|
||||||
"Unnamed __init__ parameters will not be saved to YAML "
|
|
||||||
"if Pipeline.save_to_yaml() is called" in caplog.text
|
|
||||||
)
|
|
||||||
assert len(pipeline.get_config()["components"]) == 1
|
assert len(pipeline.get_config()["components"]) == 1
|
||||||
assert pipeline.get_config()["components"][0]["params"] == {}
|
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
||||||
|
|
||||||
|
|
||||||
def test_generate_code_simple_pipeline():
|
def test_generate_code_simple_pipeline():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user