mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-30 01:09:43 +00:00
Add tests for missing __init__ and super().__init__() in custom nodes (#2350)
* Add tests for missing init and super * Update Documentation & Code Style * change in with endswith * Move test in pipeline.py and change test in pipeline_yaml.py * Update Documentation & Code Style * Use caplog to test the warning * Update Documentation & Code Style * move tests into test_pipeline and use get_config * Update Documentation & Code Style * Unmock version name * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
73f9ab0f57
commit
d98883b79d
@ -37,7 +37,10 @@ def exportable_to_yaml(init_func):
|
||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||
|
||||
# Make sure it runs only on the __init__of the implementations, not in superclasses
|
||||
if init_func.__qualname__ == f"{self.__class__.__name__}.{init_func.__name__}":
|
||||
# NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class'
|
||||
# name, like: ParentClass.InnerClass.__init__.
|
||||
# Inner classes are heavily used in tests.
|
||||
if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"):
|
||||
|
||||
# Store all the named input parameters in self._component_config
|
||||
for k, v in kwargs.items():
|
||||
|
||||
@ -54,7 +54,8 @@ def _optional_component_not_installed(component: str, dep_group: str, source_err
|
||||
f"Failed to import '{component}', "
|
||||
"which is an optional component in Haystack.\n"
|
||||
f"Run 'pip install 'farm-haystack[{dep_group}]'' "
|
||||
"to install the required dependencies and make this component available."
|
||||
"to install the required dependencies and make this component available.\n"
|
||||
f"(Original error: {str(source_error)})"
|
||||
) from source_error
|
||||
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ from .conftest import (
|
||||
SAMPLES_PATH,
|
||||
MockDocumentStore,
|
||||
MockRetriever,
|
||||
MockNode,
|
||||
deepset_cloud_fixture,
|
||||
)
|
||||
|
||||
@ -350,6 +351,36 @@ def test_get_config_component_with_superclass_arguments():
|
||||
assert pipeline.get_document_store().base_parameter == "something"
|
||||
|
||||
|
||||
def test_get_config_custom_node_with_params():
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, param: int):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(CustomNode(param=10), name="custom_node", inputs=["Query"])
|
||||
|
||||
assert len(pipeline.get_config()["components"]) == 1
|
||||
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}
|
||||
|
||||
|
||||
def test_get_config_custom_node_with_positional_params(caplog):
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, param: int = 1):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
pipeline = Pipeline()
|
||||
with caplog.at_level(logging.WARNING):
|
||||
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
|
||||
assert (
|
||||
"Unnamed __init__ parameters will not be saved to YAML "
|
||||
"if Pipeline.save_to_yaml() is called" in caplog.text
|
||||
)
|
||||
assert len(pipeline.get_config()["components"]) == 1
|
||||
assert pipeline.get_config()["components"][0]["params"] == {}
|
||||
|
||||
|
||||
def test_generate_code_simple_pipeline():
|
||||
config = {
|
||||
"version": "unstable",
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from abc import abstractmethod
|
||||
from numpy import mat
|
||||
import pytest
|
||||
import json
|
||||
import logging
|
||||
import inspect
|
||||
import networkx as nx
|
||||
from enum import Enum
|
||||
@ -11,6 +13,7 @@ from haystack import Pipeline
|
||||
from haystack.nodes import _json_schema
|
||||
from haystack.nodes import FileTypeClassifier
|
||||
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
|
||||
from haystack.nodes.base import BaseComponent
|
||||
|
||||
from .conftest import SAMPLES_PATH, MockNode, MockDocumentStore, MockReader, MockRetriever
|
||||
from . import conftest
|
||||
@ -41,7 +44,9 @@ def mock_json_schema(request, monkeypatch, tmp_path):
|
||||
|
||||
# Generate mock schema in tmp_path
|
||||
filename = f"haystack-pipeline-unstable.schema.json"
|
||||
test_schema = _json_schema.get_json_schema(filename=filename, compatible_versions=["unstable"])
|
||||
test_schema = _json_schema.get_json_schema(
|
||||
filename=filename, compatible_versions=["unstable", haystack.__version__]
|
||||
)
|
||||
|
||||
with open(tmp_path / filename, "w") as schema_file:
|
||||
json.dump(test_schema, schema_file, indent=4)
|
||||
@ -274,11 +279,65 @@ def test_load_yaml_custom_component(tmp_path):
|
||||
assert pipeline.get_node("custom_node").param == 1
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_no_init(tmp_path):
|
||||
class CustomNode(MockNode):
|
||||
pass
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_neednt_call_super(tmp_path):
|
||||
"""This is a side-effect. Here for behavior documentation only"""
|
||||
|
||||
class CustomNode(BaseComponent):
|
||||
outgoing_edges = 1
|
||||
|
||||
def __init__(self, param: int):
|
||||
self.param = param
|
||||
|
||||
def run(self, *a, **k):
|
||||
pass
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
param: 1
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
|
||||
assert pipeline.get_node("custom_node").param == 1
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def abstract_method(self):
|
||||
pass
|
||||
@ -575,7 +634,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
node = pipeline.get_node("custom_node")
|
||||
node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
|
||||
assert node.some_exotic_parameter == "AnotherClass.CLASS_CONSTANT"
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_superclass(tmp_path):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user