Improve error message for nodes failing validation (#2313)

* Similar test case seems to pass

* Update Documentation & Code Style

* Improve error message

* Slightly clarify info message

* Fix mismatch between node and node_class in the schema generation

* Remove condition that node class names cannot begin with Base and update tests

* Indentation

* Update Documentation & Code Style

* feedback

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sara Zan 2022-03-21 14:47:24 +01:00 committed by GitHub
parent 5454d57bfa
commit 7261377643
5 changed files with 170 additions and 34 deletions


@@ -440,7 +440,7 @@ then be found in the dict returned by this method under the key "_debug"
 ```python
 @send_event
-def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult
+def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, sas_batch_size: int = 32, sas_use_gpu: bool = True, add_isolated_node_eval: bool = False) -> EvaluationResult
 ```
Evaluates the pipeline by running the pipeline once per query in debug mode
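The diff adds two SAS-related knobs to `Pipeline.eval()`. A minimal usage sketch, assuming a pipeline YAML and evaluation labels already exist (the file name, model choice, and labels are placeholders):

```python
from haystack import Pipeline

pipeline = Pipeline.load_from_yaml(path="my_pipeline.yml")  # hypothetical config file
eval_labels = []  # fill with MultiLabel objects built from your annotations

eval_result = pipeline.eval(
    labels=eval_labels,
    sas_model_name_or_path="cross-encoder/stsb-roberta-large",  # any SAS-capable model
    sas_batch_size=32,  # new parameter: how many predictions the SAS model scores per batch
    sas_use_gpu=True,   # new parameter: run the SAS model on GPU when one is available
)
```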


@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
import logging
@@ -113,7 +113,11 @@ class Config(BaseConfig):
    extra = "forbid" # type: ignore

-def find_subclasses_in_modules(importable_modules: List[str], include_base_classes: bool = False):
+def is_valid_component_class(class_):
+    return inspect.isclass(class_) and not inspect.isabstract(class_) and issubclass(class_, BaseComponent)
+
+def find_subclasses_in_modules(importable_modules: List[str]):
    """
    This function returns a list `(module, class)` of all the classes that can be imported
    dynamically, for example from a pipeline YAML definition or to generate documentation.
@@ -121,19 +125,14 @@ def find_subclasses_in_modules(importable_modules: List[str], include_base_class
-    By default it won't include Base classes, which should be abstract.
    """
    return [
-        (module, clazz)
+        (module, class_)
        for module in importable_modules
-        for _, clazz in inspect.getmembers(sys.modules[module])
-        if (
-            inspect.isclass(clazz)
-            and not inspect.isabstract(clazz)
-            and issubclass(clazz, BaseComponent)
-            and (include_base_classes or not clazz.__name__.startswith("Base"))
-        )
+        for _, class_ in inspect.getmembers(sys.modules[module])
+        if is_valid_component_class(class_)
    ]
-def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+def create_schema_for_node_class(node_class: Type[BaseComponent]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Create the JSON schema for a single BaseComponent subclass,
    including all accessory classes.
@@ -141,15 +140,17 @@ def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[st
    :returns: the schema for the node and all accessory classes,
              and a dict with the reference to the node only.
    """
-    if not hasattr(node, "__name__"):
-        raise PipelineSchemaError(f"Node {node} has no __name__ attribute, cannot create a schema for it.")
+    if not hasattr(node_class, "__name__"):
+        raise PipelineSchemaError(
+            f"Node class '{node_class}' has no '__name__' attribute, cannot create a schema for it."
+        )
-    node_name = getattr(node, "__name__")
+    node_name = getattr(node_class, "__name__")
-    logger.info(f"Processing node: {node_name}")
+    logger.info(f"Creating schema for '{node_name}'")
    # Read the relevant init parameters from __init__'s signature
-    init_method = getattr(node, "__init__", None)
+    init_method = getattr(node_class, "__init__", None)
    if not init_method:
        raise PipelineSchemaError(f"Could not read the __init__ method of {node_name} to create its schema.")
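The schema for each node is derived from the `__init__` signature that the code above retrieves; pydantic then turns the parameters into JSON-schema properties. A rough, hypothetical sketch of the introspection step, not Haystack's exact code:

```python
import inspect

def init_params(node_class) -> dict:
    # Map every __init__ parameter except `self` to its annotation and default
    signature = inspect.signature(node_class.__init__)
    return {
        name: (param.annotation, param.default)
        for name, param in signature.parameters.items()
        if name not in ("self", "args", "kwargs")
    }
```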
@@ -228,11 +229,11 @@ def get_json_schema(
    node_refs = []  # References to the nodes only (accessory classes cannot be listed among the nodes in a config)

    # List all known nodes in the given modules
-    possible_nodes = find_subclasses_in_modules(importable_modules=modules)
+    possible_node_classes = find_subclasses_in_modules(importable_modules=modules)
    # Build the definitions and refs for the nodes
-    for _, node in possible_nodes:
-        node_definition, node_ref = create_schema_for_node(node)
+    for _, node_class in possible_node_classes:
+        node_definition, node_ref = create_schema_for_node_class(node_class)
        schema_definitions.update(node_definition)
        node_refs.append(node_ref)
@@ -303,17 +304,24 @@ def get_json_schema(
    return pipeline_schema

-def inject_definition_in_schema(node: BaseComponent, schema: Dict[str, Any]) -> Dict[str, Any]:
+def inject_definition_in_schema(node_class: Type[BaseComponent], schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Given a node and a schema in dict form, injects the JSON schema for the new component
    so that pipelines containing such a node can be validated against it.

    :returns: the updated schema
    """
-    schema_definition, node_ref = create_schema_for_node(node)
+    if not is_valid_component_class(node_class):
+        raise PipelineSchemaError(
+            f"Can't generate a valid schema for node of type '{node_class.__name__}'. "
+            "Possible causes:\n"
+            " - it has abstract methods\n"
+            " - its __init__() takes parameters other than Python primitive types or other nodes.\n"
+        )
+    schema_definition, node_ref = create_schema_for_node_class(node_class)
    schema["definitions"].update(schema_definition)
    schema["properties"]["components"]["items"]["anyOf"].append(node_ref)
-    logger.info(f"Added definition for {getattr(node, '__name__')}")
+    logger.info(f"Added definition for {getattr(node_class, '__name__')}")
    return schema
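Taken together, the renamed helpers compose into the schema-generation flow. A minimal sketch of how they fit, assuming `haystack.nodes` as the module list and the bare-bones schema skeleton below (both illustrative; only the two keys that `inject_definition_in_schema` touches are included):

```python
import haystack.nodes  # ensure the module is in sys.modules for getmembers()
from haystack.nodes._json_schema import find_subclasses_in_modules, inject_definition_in_schema

# Collect every concrete BaseComponent subclass that can appear in a pipeline YAML
node_classes = find_subclasses_in_modules(importable_modules=["haystack.nodes"])

# Bare-bones schema skeleton (illustrative); real schemas carry many more keys.
# Note: injection can still raise PipelineSchemaError for nodes with exotic __init__ parameters.
schema = {"definitions": {}, "properties": {"components": {"items": {"anyOf": []}}}}

for _, node_class in node_classes:
    schema = inject_definition_in_schema(node_class=node_class, schema=schema)
```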


@@ -9,6 +9,7 @@ import inspect
import logging

from haystack.schema import Document, MultiLabel
+from haystack.errors import PipelineSchemaError
from haystack.telemetry import send_custom_event
from haystack.errors import HaystackError
@@ -73,6 +74,7 @@ class BaseComponent(ABC):
        # Keeps track of all available subclasses by name.
        # Enables generic load() for all specific component implementations.
+        # Registers abstract classes and base classes too.
        cls._subclasses[cls.__name__] = cls

    @property
@@ -108,7 +110,7 @@ class BaseComponent(ABC):
    @classmethod
    def get_subclass(cls, component_type: str):
        if component_type not in cls._subclasses.keys():
-            raise HaystackError(f"Haystack component with the name '{component_type}' does not exist.")
+            raise PipelineSchemaError(f"Haystack component with the name '{component_type}' not found.")
        subclass = cls._subclasses[component_type]
        return subclass
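The `_subclasses` registry that `get_subclass` consults is populated by the `__init_subclass__` hook shown above. A standalone sketch of the pattern (not Haystack's actual class, just the same idea):

```python
class Registry:
    _subclasses: dict = {}

    # Runs automatically whenever a subclass is defined,
    # so every subclass registers itself under its own name
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._subclasses[cls.__name__] = cls

    @classmethod
    def get_subclass(cls, component_type: str):
        if component_type not in cls._subclasses:
            raise KeyError(f"Component '{component_type}' not found.")
        return cls._subclasses[component_type]


class MyNode(Registry):
    pass


assert Registry.get_subclass("MyNode") is MyNode
```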


@@ -206,16 +206,18 @@ def validate_config(pipeline_config: Dict) -> None:
                logger.info(
                    f"Missing definition for node of type {validation.instance['type']}. Looking into local classes..."
                )
-                missing_component = BaseComponent.get_subclass(validation.instance["type"])
-                schema = inject_definition_in_schema(node=missing_component, schema=schema)
+                missing_component_class = BaseComponent.get_subclass(validation.instance["type"])
+                schema = inject_definition_in_schema(node_class=missing_component_class, schema=schema)
                loaded_custom_nodes.append(validation.instance["type"])
                continue
-            # A node with the given name was imported, but something else is wrong with it.
+            # A node with the given name was in the schema, but something else is wrong with it.
            # Probably it references unknown classes in its init parameters.
            raise PipelineSchemaError(
-                f"Cannot process node of type {validation.instance['type']}. Make sure its __init__ function "
-                "does not reference external classes, but uses only Python primitive types."
+                f"Node of type {validation.instance['type']} found, but it failed validation. Possible causes:\n"
+                " - The node is missing a mandatory parameter\n"
+                " - A parameter is wrongly indented in the YAML\n"
+                "See the stacktrace for more information."
            ) from validation
        # Format the error to make it as clear as possible
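The clearer message targets mistakes like the one below, where a parameter is indented as a sibling of `params` instead of a child. A hypothetical reproduction (node type, file name, and input are placeholders):

```python
from pathlib import Path
from haystack import Pipeline
from haystack.errors import PipelineSchemaError

# `param` sits at the same level as `params` instead of under it, so the
# component carries an unexpected extra field and fails schema validation
bad_config = """
version: unstable
components:
  - name: file_classifier
    type: FileTypeClassifier
    params:
    param: 1
pipelines:
  - name: my_pipeline
    nodes:
      - name: file_classifier
        inputs:
          - File
"""

config_path = Path("bad_config.yml")
config_path.write_text(bad_config)
try:
    Pipeline.load_from_yaml(path=config_path)
except PipelineSchemaError as err:
    print(err)  # "Node of type FileTypeClassifier found, but it failed validation. ..."
```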


@@ -1,14 +1,13 @@
+from abc import abstractmethod
import pytest
import json
import numpy as np
+import inspect
import networkx as nx
from enum import Enum
from pydantic.dataclasses import dataclass

import haystack
from haystack import Pipeline
from haystack import document_stores
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes import _json_schema
from haystack.nodes import FileTypeClassifier
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
@@ -251,6 +250,7 @@ def test_load_yaml_wrong_component(tmp_path):
def test_load_yaml_custom_component(tmp_path):
    class CustomNode(MockNode):
        def __init__(self, param: int):
+            super().__init__()
            self.param = param

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@@ -270,7 +270,124 @@ def test_load_yaml_custom_component(tmp_path):
              - Query
        """
        )
-    Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    assert pipeline.get_node("custom_node").param == 1

+def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
+    class CustomNode(MockNode):
+        def __init__(self):
+            super().__init__()
+
+        @abstractmethod
+        def abstract_method(self):
+            pass
+
+    assert inspect.isabstract(CustomNode)
+
+    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: unstable
+            components:
+            - name: custom_node
+              type: CustomNode
+            pipelines:
+            - name: my_pipeline
+              nodes:
+              - name: custom_node
+                inputs:
+                - Query
+        """
+        )
+    with pytest.raises(PipelineSchemaError, match="abstract"):
+        Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")

+def test_load_yaml_custom_component_name_can_include_base(tmp_path):
+    class BaseCustomNode(MockNode):
+        def __init__(self):
+            super().__init__()
+
+    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: unstable
+            components:
+            - name: custom_node
+              type: BaseCustomNode
+            pipelines:
+            - name: my_pipeline
+              nodes:
+              - name: custom_node
+                inputs:
+                - Query
+        """
+        )
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    assert isinstance(pipeline.get_node("custom_node"), BaseCustomNode)

+def test_load_yaml_custom_component_must_subclass_basecomponent(tmp_path):
+    class SomeCustomNode:
+        def run(self, *a, **k):
+            pass
+
+    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: unstable
+            components:
+            - name: custom_node
+              type: SomeCustomNode
+              params:
+                param: 1
+            pipelines:
+            - name: my_pipeline
+              nodes:
+              - name: custom_node
+                inputs:
+                - Query
+        """
+        )
+    with pytest.raises(PipelineSchemaError, match="'SomeCustomNode' not found"):
+        Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")

+def test_load_yaml_custom_component_referencing_other_node_in_init(tmp_path):
+    class OtherNode(MockNode):
+        def __init__(self, another_param: str):
+            super().__init__()
+            self.param = another_param
+
+    class CustomNode(MockNode):
+        def __init__(self, other_node: OtherNode):
+            super().__init__()
+            self.other_node = other_node
+
+    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
+        tmp_file.write(
+            f"""
+            version: unstable
+            components:
+            - name: other_node
+              type: OtherNode
+              params:
+                another_param: value
+            - name: custom_node
+              type: CustomNode
+              params:
+                other_node: other_node
+            pipelines:
+            - name: my_pipeline
+              nodes:
+              - name: custom_node
+                inputs:
+                - Query
+        """
+        )
+    pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
+    assert isinstance(pipeline.get_node("custom_node"), CustomNode)

def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path):
@@ -289,6 +406,7 @@ def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path):
    class CustomNode(MockNode):
        def __init__(self, some_exotic_parameter: HelperClass = HelperClass(1)):
+            super().__init__()
            self.some_exotic_parameter = some_exotic_parameter

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@@ -325,6 +443,7 @@ def test_load_yaml_custom_component_with_helper_class_in_yaml(tmp_path):
    class CustomNode(MockNode):
        def __init__(self, some_exotic_parameter: HelperClass):
+            super().__init__()
            self.some_exotic_parameter = some_exotic_parameter

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@@ -363,6 +482,7 @@ def test_load_yaml_custom_component_with_enum_in_init(tmp_path):
    class CustomNode(MockNode):
        def __init__(self, some_exotic_parameter: Flags = None):
+            super().__init__()
            self.some_exotic_parameter = some_exotic_parameter

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@@ -399,6 +519,7 @@ def test_load_yaml_custom_component_with_enum_in_yaml(tmp_path):
    class CustomNode(MockNode):
        def __init__(self, some_exotic_parameter: Flags):
+            super().__init__()
            self.some_exotic_parameter = some_exotic_parameter

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@@ -432,6 +553,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
    class CustomNode(MockNode):
        def __init__(self, some_exotic_parameter: str):
+            super().__init__()
            self.some_exotic_parameter = some_exotic_parameter

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@@ -458,10 +580,12 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
def test_load_yaml_custom_component_with_superclass(tmp_path):
    class BaseCustomNode(MockNode):
-        pass
+        def __init__(self):
+            super().__init__()

    class CustomNode(BaseCustomNode):
        def __init__(self, some_exotic_parameter: str):
+            super().__init__()
            self.some_exotic_parameter = some_exotic_parameter

    with open(tmp_path / "tmp_config.yml", "w") as tmp_file: