mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-04 03:39:31 +00:00 
			
		
		
		
	Add tests for missing __init__ and super().__init__() in custom nodes (#2350)
				
					
				
			* Add tests for missing init and super * Update Documentation & Code Style * change in with endswith * Move test in pipeline.py and change test in pipeline_yaml.py * Update Documentation & Code Style * Use caplog to test the warning * Update Documentation & Code Style * move tests into test_pipeline and use get_config * Update Documentation & Code Style * Unmock version name * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									73f9ab0f57
								
							
						
					
					
						commit
						d98883b79d
					
				@ -37,7 +37,10 @@ def exportable_to_yaml(init_func):
 | 
			
		||||
            self._component_config = {"params": {}, "type": type(self).__name__}
 | 
			
		||||
 | 
			
		||||
        # Make sure it runs only on the __init__of the implementations, not in superclasses
 | 
			
		||||
        if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}":
 | 
			
		||||
        # NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class'
 | 
			
		||||
        #   name, like: ParentClass.InnerClass.__init__.
 | 
			
		||||
        #   Inner classes are heavily used in tests.
 | 
			
		||||
        if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"):
 | 
			
		||||
 | 
			
		||||
            # Store all the named input parameters in self._component_config
 | 
			
		||||
            for k, v in kwargs.items():
 | 
			
		||||
 | 
			
		||||
@ -54,7 +54,8 @@ def _optional_component_not_installed(component: str, dep_group: str, source_err
 | 
			
		||||
        f"Failed to import '{component}', "
 | 
			
		||||
        "which is an optional component in Haystack.\n"
 | 
			
		||||
        f"Run 'pip install 'farm-haystack[{dep_group}]'' "
 | 
			
		||||
        "to install the required dependencies and make this component available."
 | 
			
		||||
        "to install the required dependencies and make this component available.\n"
 | 
			
		||||
        f"(Original error: {str(source_error)})"
 | 
			
		||||
    ) from source_error
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -35,6 +35,7 @@ from .conftest import (
 | 
			
		||||
    SAMPLES_PATH,
 | 
			
		||||
    MockDocumentStore,
 | 
			
		||||
    MockRetriever,
 | 
			
		||||
    MockNode,
 | 
			
		||||
    deepset_cloud_fixture,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -350,6 +351,36 @@ def test_get_config_component_with_superclass_arguments():
 | 
			
		||||
    assert pipeline.get_document_store().base_parameter == "something"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_config_custom_node_with_params():
 | 
			
		||||
    class CustomNode(MockNode):
 | 
			
		||||
        def __init__(self, param: int):
 | 
			
		||||
            super().__init__()
 | 
			
		||||
            self.param = param
 | 
			
		||||
 | 
			
		||||
    pipeline = Pipeline()
 | 
			
		||||
    pipeline.add_node(CustomNode(param=10), name="custom_node", inputs=["Query"])
 | 
			
		||||
 | 
			
		||||
    assert len(pipeline.get_config()["components"]) == 1
 | 
			
		||||
    assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_config_custom_node_with_positional_params(caplog):
 | 
			
		||||
    class CustomNode(MockNode):
 | 
			
		||||
        def __init__(self, param: int = 1):
 | 
			
		||||
            super().__init__()
 | 
			
		||||
            self.param = param
 | 
			
		||||
 | 
			
		||||
    pipeline = Pipeline()
 | 
			
		||||
    with caplog.at_level(logging.WARNING):
 | 
			
		||||
        pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
 | 
			
		||||
        assert (
 | 
			
		||||
            "Unnamed __init__ parameters will not be saved to YAML "
 | 
			
		||||
            "if Pipeline.save_to_yaml() is called" in caplog.text
 | 
			
		||||
        )
 | 
			
		||||
    assert len(pipeline.get_config()["components"]) == 1
 | 
			
		||||
    assert pipeline.get_config()["components"][0]["params"] == {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_generate_code_simple_pipeline():
 | 
			
		||||
    config = {
 | 
			
		||||
        "version": "unstable",
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,8 @@
 | 
			
		||||
from abc import abstractmethod
 | 
			
		||||
from numpy import mat
 | 
			
		||||
import pytest
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import inspect
 | 
			
		||||
import networkx as nx
 | 
			
		||||
from enum import Enum
 | 
			
		||||
@ -11,6 +13,7 @@ from haystack import Pipeline
 | 
			
		||||
from haystack.nodes import _json_schema
 | 
			
		||||
from haystack.nodes import FileTypeClassifier
 | 
			
		||||
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
 | 
			
		||||
from haystack.nodes.base import BaseComponent
 | 
			
		||||
 | 
			
		||||
from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
 | 
			
		||||
from . import conftest
 | 
			
		||||
@ -41,7 +44,9 @@ def mock_json_schema(request, monkeypatch, tmp_path):
 | 
			
		||||
 | 
			
		||||
    # Generate mock schema in tmp_path
 | 
			
		||||
    filename = f"haystack-pipeline-unstable.schema.json"
 | 
			
		||||
    test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"])
 | 
			
		||||
    test_schema = _json_schema.get_json_schema(
 | 
			
		||||
        filename=filename, compatible_versions=["unstable", haystack.__version__]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    with open(tmp_path / filename, "w") as schema_file:
 | 
			
		||||
        json.dump(test_schema, schema_file, indent=4)
 | 
			
		||||
@ -274,11 +279,65 @@ def test_load_yaml_custom_component(tmp_path):
 | 
			
		||||
    assert pipeline.get_node("custom_node").param == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_load_yaml_custom_component_with_no_init(tmp_path):
 | 
			
		||||
    class CustomNode(MockNode):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
 | 
			
		||||
        tmp_file.write(
 | 
			
		||||
            f"""
 | 
			
		||||
            version: unstable
 | 
			
		||||
            components:
 | 
			
		||||
            - name: custom_node
 | 
			
		||||
              type: CustomNode
 | 
			
		||||
            pipelines:
 | 
			
		||||
            - name: my_pipeline
 | 
			
		||||
              nodes:
 | 
			
		||||
              - name: custom_node
 | 
			
		||||
                inputs:
 | 
			
		||||
                - Query
 | 
			
		||||
        """
 | 
			
		||||
        )
 | 
			
		||||
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
 | 
			
		||||
    assert isinstance(pipeline.get_node("custom_node"), CustomNode)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_load_yaml_custom_component_neednt_call_super(tmp_path):
 | 
			
		||||
    """This is a side-effect. Here for behavior documentation only"""
 | 
			
		||||
 | 
			
		||||
    class CustomNode(BaseComponent):
 | 
			
		||||
        outgoing_edges = 1
 | 
			
		||||
 | 
			
		||||
        def __init__(self, param: int):
 | 
			
		||||
            self.param = param
 | 
			
		||||
 | 
			
		||||
        def run(self, *a, **k):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
 | 
			
		||||
        tmp_file.write(
 | 
			
		||||
            f"""
 | 
			
		||||
            version: unstable
 | 
			
		||||
            components:
 | 
			
		||||
            - name: custom_node
 | 
			
		||||
              type: CustomNode
 | 
			
		||||
              params:
 | 
			
		||||
                param: 1
 | 
			
		||||
            pipelines:
 | 
			
		||||
            - name: my_pipeline
 | 
			
		||||
              nodes:
 | 
			
		||||
              - name: custom_node
 | 
			
		||||
                inputs:
 | 
			
		||||
                - Query
 | 
			
		||||
        """
 | 
			
		||||
        )
 | 
			
		||||
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
 | 
			
		||||
    assert isinstance(pipeline.get_node("custom_node"), CustomNode)
 | 
			
		||||
    assert pipeline.get_node("custom_node").param == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
 | 
			
		||||
    class CustomNode(MockNode):
 | 
			
		||||
        def __init__(self):
 | 
			
		||||
            super().__init__()
 | 
			
		||||
 | 
			
		||||
        @abstractmethod
 | 
			
		||||
        def abstract_method(self):
 | 
			
		||||
            pass
 | 
			
		||||
@ -575,7 +634,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
 | 
			
		||||
        )
 | 
			
		||||
    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
 | 
			
		||||
    node = pipeline.get_node("custom_node")
 | 
			
		||||
    node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
 | 
			
		||||
    assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_load_yaml_custom_component_with_superclass(tmp_path):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user