diff --git a/haystack/nodes/base.py b/haystack/nodes/base.py index fc3a0aea2..636c3e5d5 100644 --- a/haystack/nodes/base.py +++ b/haystack/nodes/base.py @@ -37,7 +37,10 @@ def exportable_to_yaml(init_func): self._component_config = {"params": {}, "type": type(self).__name__} # Make sure it runs only on the __init__of the implementations, not in superclasses - if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}": + # NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class' + # name, like: ParentClass.InnerClass.__init__. + # Inner classes are heavily used in tests. + if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"): # Store all the named input parameters in self._component_config for k, v in kwargs.items(): diff --git a/haystack/utils/import_utils.py b/haystack/utils/import_utils.py index 0152365d2..9c9d4bbac 100644 --- a/haystack/utils/import_utils.py +++ b/haystack/utils/import_utils.py @@ -54,7 +54,8 @@ def _optional_component_not_installed(component: str, dep_group: str, source_err f"Failed to import '{component}', " "which is an optional component in Haystack.\n" f"Run 'pip install 'farm-haystack[{dep_group}]'' " - "to install the required dependencies and make this component available." + "to install the required dependencies and make this component available.\n" + f"(Original error: {str(source_error)})" ) from source_error diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 30b36cafe..28f7d2dd3 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -35,6 +35,7 @@ from .conftest import ( SAMPLES_PATH, MockDocumentStore, MockRetriever, + MockNode, deepset_cloud_fixture, ) @@ -350,6 +351,36 @@ def test_get_config_component_with_superclass_arguments(): assert pipeline.get_document_store().base_parameter == "something" +def test_get_config_custom_node_with_params(): + class CustomNode(MockNode): + def __init__(self, param: int): + super().__init__() + self.param = param + + pipeline = Pipeline() + pipeline.add_node(CustomNode(param=10), name="custom_node", inputs=["Query"]) + + assert len(pipeline.get_config()["components"]) == 1 + assert pipeline.get_config()["components"][0]["params"] == {"param": 10} + + +def test_get_config_custom_node_with_positional_params(caplog): + class CustomNode(MockNode): + def __init__(self, param: int = 1): + super().__init__() + self.param = param + + pipeline = Pipeline() + with caplog.at_level(logging.WARNING): + pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"]) + assert ( + "Unnamed __init__ parameters will not be saved to YAML " + "if Pipeline.save_to_yaml() is called" in caplog.text + ) + assert len(pipeline.get_config()["components"]) == 1 + assert pipeline.get_config()["components"][0]["params"] == {} + + def test_generate_code_simple_pipeline(): config = { "version": "unstable", diff --git a/test/test_pipeline_yaml.py b/test/test_pipeline_yaml.py index 534262663..105fdc1e4 100644 --- a/test/test_pipeline_yaml.py +++ b/test/test_pipeline_yaml.py @@ -1,6 +1,8 @@ from abc import abstractmethod +from numpy import mat import pytest import json +import logging import inspect import networkx as nx from enum import Enum @@ -11,6 +13,7 @@ from haystack import Pipeline from haystack.nodes import _json_schema from haystack.nodes import FileTypeClassifier from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError +from haystack.nodes.base import BaseComponent from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever from . import conftest @@ -41,7 +44,9 @@ def mock_json_schema(request, monkeypatch, tmp_path): # Generate mock schema in tmp_path filename = f"haystack-pipeline-unstable.schema.json" - test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"]) + test_schema = _json_schema.get_json_schema( + filename=filename, compatible_versions=["unstable", haystack.__version__] + ) with open(tmp_path / filename, "w") as schema_file: json.dump(test_schema, schema_file, indent=4) @@ -274,11 +279,65 @@ def test_load_yaml_custom_component(tmp_path): assert pipeline.get_node("custom_node").param == 1 +def test_load_yaml_custom_component_with_no_init(tmp_path): + class CustomNode(MockNode): + pass + + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: unstable + components: + - name: custom_node + type: CustomNode + pipelines: + - name: my_pipeline + nodes: + - name: custom_node + inputs: + - Query + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + assert isinstance(pipeline.get_node("custom_node"), CustomNode) + + +def test_load_yaml_custom_component_neednt_call_super(tmp_path): + """This is a side-effect. Here for behavior documentation only""" + + class CustomNode(BaseComponent): + outgoing_edges = 1 + + def __init__(self, param: int): + self.param = param + + def run(self, *a, **k): + pass + + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: unstable + components: + - name: custom_node + type: CustomNode + params: + param: 1 + pipelines: + - name: my_pipeline + nodes: + - name: custom_node + inputs: + - Query + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + assert isinstance(pipeline.get_node("custom_node"), CustomNode) + assert pipeline.get_node("custom_node").param == 1 + + def test_load_yaml_custom_component_cant_be_abstract(tmp_path): class CustomNode(MockNode): - def __init__(self): - super().__init__() - @abstractmethod def abstract_method(self): pass @@ -575,7 +634,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path): ) pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") node = pipeline.get_node("custom_node") - node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT" + assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT" def test_load_yaml_custom_component_with_superclass(tmp_path):