Add tests for missing __init__ and super().__init__() in custom nodes (#2350)

* Add tests for missing init and super

* Update Documentation & Code Style

* change in with endswith

* Move test in pipeline.py and change test in pipeline_yaml.py

* Update Documentation & Code Style

* Use caplog to test the warning

* Update Documentation & Code Style

* move tests into test_pipeline and use get_config

* Update Documentation & Code Style

* Unmock version name

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sara Zan 2022-04-13 14:29:05 +02:00 committed by GitHub
parent 73f9ab0f57
commit d98883b79d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 7 deletions

View File

@ -37,7 +37,10 @@ def exportable_to_yaml(init_func):
self._component_config = {"params": {}, "type": type(self).__name__}
# Make sure it runs only on the __init__of the implementations, not in superclasses
if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}":
# NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class'
# name, like: ParentClass.InnerClass.__init__.
# Inner classes are heavily used in tests.
if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"):
# Store all the named input parameters in self._component_config
for k, v in kwargs.items():

View File

@ -54,7 +54,8 @@ def _optional_component_not_installed(component: str, dep_group: str, source_err
f"Failed to import '{component}', "
"which is an optional component in Haystack.\n"
f"Run 'pip install 'farm-haystack[{dep_group}]'' "
"to install the required dependencies and make this component available."
"to install the required dependencies and make this component available.\n"
f"(Original error: {str(source_error)})"
) from source_error

View File

@ -35,6 +35,7 @@ from .conftest import (
SAMPLES_PATH,
MockDocumentStore,
MockRetriever,
MockNode,
deepset_cloud_fixture,
)
@ -350,6 +351,36 @@ def test_get_config_component_with_superclass_arguments():
assert pipeline.get_document_store().base_parameter == "something"
def test_get_config_custom_node_with_params():
class CustomNode(MockNode):
def __init__(self, param: int):
super().__init__()
self.param = param
pipeline = Pipeline()
pipeline.add_node(CustomNode(param=10), name="custom_node", inputs=["Query"])
assert len(pipeline.get_config()["components"]) == 1
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
def test_get_config_custom_node_with_positional_params(caplog):
class CustomNode(MockNode):
def __init__(self, param: int = 1):
super().__init__()
self.param = param
pipeline = Pipeline()
with caplog.at_level(logging.WARNING):
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
assert (
"Unnamed __init__ parameters will not be saved to YAML "
"if Pipeline.save_to_yaml() is called" in caplog.text
)
assert len(pipeline.get_config()["components"]) == 1
assert pipeline.get_config()["components"][0]["params"] == {}
def test_generate_code_simple_pipeline():
config = {
"version": "unstable",

View File

@ -1,6 +1,8 @@
from abc import abstractmethod
from numpy import mat
import pytest
import json
import logging
import inspect
import networkx as nx
from enum import Enum
@ -11,6 +13,7 @@ from haystack import Pipeline
from haystack.nodes import _json_schema
from haystack.nodes import FileTypeClassifier
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
from haystack.nodes.base import BaseComponent
from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
from . import conftest
@ -41,7 +44,9 @@ def mock_json_schema(request, monkeypatch, tmp_path):
# Generate mock schema in tmp_path
filename = f"haystack-pipeline-unstable.schema.json"
test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"])
test_schema = _json_schema.get_json_schema(
filename=filename, compatible_versions=["unstable", haystack.__version__]
)
with open(tmp_path / filename, "w") as schema_file:
json.dump(test_schema, schema_file, indent=4)
@ -274,11 +279,65 @@ def test_load_yaml_custom_component(tmp_path):
assert pipeline.get_node("custom_node").param == 1
def test_load_yaml_custom_component_with_no_init(tmp_path):
class CustomNode(MockNode):
pass
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: unstable
components:
- name: custom_node
type: CustomNode
pipelines:
- name: my_pipeline
nodes:
- name: custom_node
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
def test_load_yaml_custom_component_neednt_call_super(tmp_path):
"""This is a side-effect. Here for behavior documentation only"""
class CustomNode(BaseComponent):
outgoing_edges = 1
def __init__(self, param: int):
self.param = param
def run(self, *a, **k):
pass
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
f"""
version: unstable
components:
- name: custom_node
type: CustomNode
params:
param: 1
pipelines:
- name: my_pipeline
nodes:
- name: custom_node
inputs:
- Query
"""
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
assert pipeline.get_node("custom_node").param == 1
def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
class CustomNode(MockNode):
def __init__(self):
super().__init__()
@abstractmethod
def abstract_method(self):
pass
@ -575,7 +634,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
node = pipeline.get_node("custom_node")
node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
def test_load_yaml_custom_component_with_superclass(tmp_path):