mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-06 21:05:33 +00:00
Add tests for missing __init__ and super().__init__() in custom nodes (#2350)
* Add tests for missing init and super * Update Documentation & Code Style * change in with endswith * Move test in pipeline.py and change test in pipeline_yaml.py * Update Documentation & Code Style * Use caplog to test the warning * Update Documentation & Code Style * move tests into test_pipeline and use get_config * Update Documentation & Code Style * Unmock version name * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
73f9ab0f57
commit
d98883b79d
@ -37,7 +37,10 @@ def exportable_to_yaml(init_func):
|
|||||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||||
|
|
||||||
# Make sure it runs only on the __init__of the implementations, not in superclasses
|
# Make sure it runs only on the __init__of the implementations, not in superclasses
|
||||||
if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}":
|
# NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class'
|
||||||
|
# name, like: ParentClass.InnerClass.__init__.
|
||||||
|
# Inner classes are heavily used in tests.
|
||||||
|
if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"):
|
||||||
|
|
||||||
# Store all the named input parameters in self._component_config
|
# Store all the named input parameters in self._component_config
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
|
|||||||
@ -54,7 +54,8 @@ def _optional_component_not_installed(component: str, dep_group: str, source_err
|
|||||||
f"Failed to import '{component}', "
|
f"Failed to import '{component}', "
|
||||||
"which is an optional component in Haystack.\n"
|
"which is an optional component in Haystack.\n"
|
||||||
f"Run 'pip install 'farm-haystack[{dep_group}]'' "
|
f"Run 'pip install 'farm-haystack[{dep_group}]'' "
|
||||||
"to install the required dependencies and make this component available."
|
"to install the required dependencies and make this component available.\n"
|
||||||
|
f"(Original error: {str(source_error)})"
|
||||||
) from source_error
|
) from source_error
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from .conftest import (
|
|||||||
SAMPLES_PATH,
|
SAMPLES_PATH,
|
||||||
MockDocumentStore,
|
MockDocumentStore,
|
||||||
MockRetriever,
|
MockRetriever,
|
||||||
|
MockNode,
|
||||||
deepset_cloud_fixture,
|
deepset_cloud_fixture,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -350,6 +351,36 @@ def test_get_config_component_with_superclass_arguments():
|
|||||||
assert pipeline.get_document_store().base_parameter == "something"
|
assert pipeline.get_document_store().base_parameter == "something"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config_custom_node_with_params():
|
||||||
|
class CustomNode(MockNode):
|
||||||
|
def __init__(self, param: int):
|
||||||
|
super().__init__()
|
||||||
|
self.param = param
|
||||||
|
|
||||||
|
pipeline = Pipeline()
|
||||||
|
pipeline.add_node(CustomNode(param=10), name="custom_node", inputs=["Query"])
|
||||||
|
|
||||||
|
assert len(pipeline.get_config()["components"]) == 1
|
||||||
|
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config_custom_node_with_positional_params(caplog):
|
||||||
|
class CustomNode(MockNode):
|
||||||
|
def __init__(self, param: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.param = param
|
||||||
|
|
||||||
|
pipeline = Pipeline()
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
|
||||||
|
assert (
|
||||||
|
"Unnamed __init__ parameters will not be saved to YAML "
|
||||||
|
"if Pipeline.save_to_yaml() is called" in caplog.text
|
||||||
|
)
|
||||||
|
assert len(pipeline.get_config()["components"]) == 1
|
||||||
|
assert pipeline.get_config()["components"][0]["params"] == {}
|
||||||
|
|
||||||
|
|
||||||
def test_generate_code_simple_pipeline():
|
def test_generate_code_simple_pipeline():
|
||||||
config = {
|
config = {
|
||||||
"version": "unstable",
|
"version": "unstable",
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from numpy import mat
|
||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import inspect
|
import inspect
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -11,6 +13,7 @@ from haystack import Pipeline
|
|||||||
from haystack.nodes import _json_schema
|
from haystack.nodes import _json_schema
|
||||||
from haystack.nodes import FileTypeClassifier
|
from haystack.nodes import FileTypeClassifier
|
||||||
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
|
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
|
||||||
|
from haystack.nodes.base import BaseComponent
|
||||||
|
|
||||||
from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
|
from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
|
||||||
from . import conftest
|
from . import conftest
|
||||||
@ -41,7 +44,9 @@ def mock_json_schema(request, monkeypatch, tmp_path):
|
|||||||
|
|
||||||
# Generate mock schema in tmp_path
|
# Generate mock schema in tmp_path
|
||||||
filename = f"haystack-pipeline-unstable.schema.json"
|
filename = f"haystack-pipeline-unstable.schema.json"
|
||||||
test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"])
|
test_schema = _json_schema.get_json_schema(
|
||||||
|
filename=filename, compatible_versions=["unstable", haystack.__version__]
|
||||||
|
)
|
||||||
|
|
||||||
with open(tmp_path / filename, "w") as schema_file:
|
with open(tmp_path / filename, "w") as schema_file:
|
||||||
json.dump(test_schema, schema_file, indent=4)
|
json.dump(test_schema, schema_file, indent=4)
|
||||||
@ -274,11 +279,65 @@ def test_load_yaml_custom_component(tmp_path):
|
|||||||
assert pipeline.get_node("custom_node").param == 1
|
assert pipeline.get_node("custom_node").param == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_yaml_custom_component_with_no_init(tmp_path):
|
||||||
|
class CustomNode(MockNode):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||||
|
tmp_file.write(
|
||||||
|
f"""
|
||||||
|
version: unstable
|
||||||
|
components:
|
||||||
|
- name: custom_node
|
||||||
|
type: CustomNode
|
||||||
|
pipelines:
|
||||||
|
- name: my_pipeline
|
||||||
|
nodes:
|
||||||
|
- name: custom_node
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||||
|
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_yaml_custom_component_neednt_call_super(tmp_path):
|
||||||
|
"""This is a side-effect. Here for behavior documentation only"""
|
||||||
|
|
||||||
|
class CustomNode(BaseComponent):
|
||||||
|
outgoing_edges = 1
|
||||||
|
|
||||||
|
def __init__(self, param: int):
|
||||||
|
self.param = param
|
||||||
|
|
||||||
|
def run(self, *a, **k):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||||
|
tmp_file.write(
|
||||||
|
f"""
|
||||||
|
version: unstable
|
||||||
|
components:
|
||||||
|
- name: custom_node
|
||||||
|
type: CustomNode
|
||||||
|
params:
|
||||||
|
param: 1
|
||||||
|
pipelines:
|
||||||
|
- name: my_pipeline
|
||||||
|
nodes:
|
||||||
|
- name: custom_node
|
||||||
|
inputs:
|
||||||
|
- Query
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||||
|
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
|
||||||
|
assert pipeline.get_node("custom_node").param == 1
|
||||||
|
|
||||||
|
|
||||||
def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
|
def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
|
||||||
class CustomNode(MockNode):
|
class CustomNode(MockNode):
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def abstract_method(self):
|
def abstract_method(self):
|
||||||
pass
|
pass
|
||||||
@ -575,7 +634,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
|
|||||||
)
|
)
|
||||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||||
node = pipeline.get_node("custom_node")
|
node = pipeline.get_node("custom_node")
|
||||||
node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
|
assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
|
||||||
|
|
||||||
|
|
||||||
def test_load_yaml_custom_component_with_superclass(tmp_path):
|
def test_load_yaml_custom_component_with_superclass(tmp_path):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user