Add tests for missing __init__ and super().__init__() in custom nodes (#2350)

* Add tests for missing init and super

* Update Documentation & Code Style

* Replace `in` with `endswith`

* Move test in pipeline.py and change test in pipeline_yaml.py

* Update Documentation & Code Style

* Use caplog to test the warning

* Update Documentation & Code Style

* move tests into test_pipeline and use get_config

* Update Documentation & Code Style

* Unmock version name

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sara Zan 2022-04-13 14:29:05 +02:00 committed by GitHub
parent 73f9ab0f57
commit d98883b79d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 7 deletions

View File

@ -37,7 +37,10 @@ def exportable_to_yaml(init_func):
self._component_config = {"params": {}, "type": type(self).__name__} self._component_config = {"params": {}, "type": type(self).__name__}
# Make sure it runs only on the __init__of the implementations, not in superclasses # Make sure it runs only on the __init__of the implementations, not in superclasses
if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}": # NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class'
# name, like: ParentClass.InnerClass.__init__.
# Inner classes are heavily used in tests.
if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"):
# Store all the named input parameters in self._component_config # Store all the named input parameters in self._component_config
for k, v in kwargs.items(): for k, v in kwargs.items():

View File

@ -54,7 +54,8 @@ def _optional_component_not_installed(component: str, dep_group: str, source_err
f"Failed to import '{component}', " f"Failed to import '{component}', "
"which is an optional component in Haystack.\n" "which is an optional component in Haystack.\n"
f"Run 'pip install 'farm-haystack[{dep_group}]'' " f"Run 'pip install 'farm-haystack[{dep_group}]'' "
"to install the required dependencies and make this component available." "to install the required dependencies and make this component available.\n"
f"(Original error: {str(source_error)})"
) from source_error ) from source_error

View File

@ -35,6 +35,7 @@ from .conftest import (
SAMPLES_PATH, SAMPLES_PATH,
MockDocumentStore, MockDocumentStore,
MockRetriever, MockRetriever,
MockNode,
deepset_cloud_fixture, deepset_cloud_fixture,
) )
@ -350,6 +351,36 @@ def test_get_config_component_with_superclass_arguments():
assert pipeline.get_document_store().base_parameter == "something" assert pipeline.get_document_store().base_parameter == "something"
def test_get_config_custom_node_with_params():
    """A keyword argument given to a custom node's __init__ is reported by get_config()."""

    class CustomNode(MockNode):
        def __init__(self, param: int):
            super().__init__()
            self.param = param

    pipeline = Pipeline()
    node = CustomNode(param=10)
    pipeline.add_node(node, name="custom_node", inputs=["Query"])

    # Exactly one component is registered ...
    components = pipeline.get_config()["components"]
    assert len(components) == 1

    # ... and its params hold the keyword argument passed to the constructor.
    stored_params = pipeline.get_config()["components"][0]["params"]
    assert stored_params == {"param": 10}
def test_get_config_custom_node_with_positional_params(caplog):
    """A positional __init__ argument logs a warning and is absent from the config params."""

    class CustomNode(MockNode):
        def __init__(self, param: int = 1):
            super().__init__()
            self.param = param

    expected_warning = (
        "Unnamed __init__ parameters will not be saved to YAML "
        "if Pipeline.save_to_yaml() is called"
    )

    pipeline = Pipeline()
    with caplog.at_level(logging.WARNING):
        # The node is built with a positional argument on purpose.
        pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
        assert expected_warning in caplog.text

    components = pipeline.get_config()["components"]
    assert len(components) == 1

    # The positional value must not appear among the stored params.
    stored_params = pipeline.get_config()["components"][0]["params"]
    assert stored_params == {}
def test_generate_code_simple_pipeline(): def test_generate_code_simple_pipeline():
config = { config = {
"version": "unstable", "version": "unstable",

View File

@ -1,6 +1,8 @@
from abc import abstractmethod from abc import abstractmethod
from numpy import mat
import pytest import pytest
import json import json
import logging
import inspect import inspect
import networkx as nx import networkx as nx
from enum import Enum from enum import Enum
@ -11,6 +13,7 @@ from haystack import Pipeline
from haystack.nodes import _json_schema from haystack.nodes import _json_schema
from haystack.nodes import FileTypeClassifier from haystack.nodes import FileTypeClassifier
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
from haystack.nodes.base import BaseComponent
from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
from . import conftest from . import conftest
@ -41,7 +44,9 @@ def mock_json_schema(request, monkeypatch, tmp_path):
# Generate mock schema in tmp_path # Generate mock schema in tmp_path
filename = f"haystack-pipeline-unstable.schema.json" filename = f"haystack-pipeline-unstable.schema.json"
test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"]) test_schema = _json_schema.get_json_schema(
filename=filename, compatible_versions=["unstable", haystack.__version__]
)
with open(tmp_path / filename, "w") as schema_file: with open(tmp_path / filename, "w") as schema_file:
json.dump(test_schema, schema_file, indent=4) json.dump(test_schema, schema_file, indent=4)
@ -274,11 +279,65 @@ def test_load_yaml_custom_component(tmp_path):
assert pipeline.get_node("custom_node").param == 1 assert pipeline.get_node("custom_node").param == 1
# Checks that a custom node class which defines no __init__ of its own can be
# referenced by type name in a YAML config and loaded with Pipeline.load_from_yaml.
def test_load_yaml_custom_component_with_no_init(tmp_path):
# Subclass with an empty body: no __init__ is defined here.
class CustomNode(MockNode):
pass
# Write a one-component, one-pipeline config into pytest's tmp_path.
# NOTE(review): the YAML nesting below appears flattened in this view; confirm
# the indentation against the real test file before relying on it.
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: unstable
components:
- name: custom_node
type: CustomNode
pipelines:
- name: my_pipeline
nodes:
- name: custom_node
inputs:
- Query
"""
)
# Loading must succeed and the node registered as "custom_node" must be a CustomNode.
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
# Records that a custom node whose __init__ does not call super().__init__()
# can still be loaded from YAML, with its `param` value taken from the config.
def test_load_yaml_custom_component_neednt_call_super(tmp_path):
"""This is a side-effect. Here for behavior documentation only"""
# Derives from BaseComponent directly and supplies outgoing_edges and run() itself.
class CustomNode(BaseComponent):
outgoing_edges = 1
# Deliberately omits the super().__init__() call.
def __init__(self, param: int):
self.param = param
def run(self, *a, **k):
pass
# Write a config that passes `param: 1` to the custom node.
# NOTE(review): the YAML nesting below appears flattened in this view; confirm
# the indentation against the real test file before relying on it.
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: unstable
components:
- name: custom_node
type: CustomNode
params:
param: 1
pipelines:
- name: my_pipeline
nodes:
- name: custom_node
inputs:
- Query
"""
)
# The node must load as a CustomNode and carry the param value from the YAML.
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
assert pipeline.get_node("custom_node").param == 1
def test_load_yaml_custom_component_cant_be_abstract(tmp_path): def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
class CustomNode(MockNode): class CustomNode(MockNode):
def __init__(self):
super().__init__()
@abstractmethod @abstractmethod
def abstract_method(self): def abstract_method(self):
pass pass
@ -575,7 +634,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
) )
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
node = pipeline.get_node("custom_node") node = pipeline.get_node("custom_node")
node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT" assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
def test_load_yaml_custom_component_with_superclass(tmp_path): def test_load_yaml_custom_component_with_superclass(tmp_path):