mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 01:46:33 +00:00
Add support for positional args in pipeline.get_config() (#2478)
* add support for positional args in pipeline.get_config() * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
7d6b3fe954
commit
509944f47d
@ -27,11 +27,6 @@ def exportable_to_yaml(init_func):
|
||||
# Call the actuall __init__ function with all the arguments
|
||||
init_func(self, *args, **kwargs)
|
||||
|
||||
# Warn for unnamed input params - should be rare
|
||||
if args:
|
||||
logger.warning(
|
||||
"Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!"
|
||||
)
|
||||
# Create the configuration dictionary if it doesn't exist yet
|
||||
if not self._component_config:
|
||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||
@ -46,6 +41,14 @@ def exportable_to_yaml(init_func):
|
||||
for k, v in kwargs.items():
|
||||
self._component_config["params"][k] = v
|
||||
|
||||
# Store unnamed input parameters in self._component_config too by inferring their names
|
||||
sig = inspect.signature(init_func)
|
||||
parameter_names = list(sig.parameters.keys())
|
||||
# we can be sure that the first one is always "self"
|
||||
arg_names = parameter_names[1 : 1 + len(args)]
|
||||
for arg, arg_name in zip(args, arg_names):
|
||||
self._component_config["params"][arg_name] = arg
|
||||
|
||||
return wrapper_exportable_to_yaml
|
||||
|
||||
|
||||
|
@ -386,21 +386,17 @@ def test_get_config_custom_node_with_params():
|
||||
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
||||
|
||||
|
||||
def test_get_config_custom_node_with_positional_params(caplog):
|
||||
def test_get_config_custom_node_with_positional_params():
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, param: int = 1):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
pipeline = Pipeline()
|
||||
with caplog.at_level(logging.WARNING):
|
||||
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
|
||||
assert (
|
||||
"Unnamed __init__ parameters will not be saved to YAML "
|
||||
"if Pipeline.save_to_yaml() is called" in caplog.text
|
||||
)
|
||||
|
||||
assert len(pipeline.get_config()["components"]) == 1
|
||||
assert pipeline.get_config()["components"][0]["params"] == {}
|
||||
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
||||
|
||||
|
||||
def test_generate_code_simple_pipeline():
|
||||
|
Loading…
x
Reference in New Issue
Block a user