mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-26 06:28:33 +00:00
Improve error message for nodes failing validation (#2313)
* Similar test case seems to pass * Update Documentation & Code Style * Improve error message * Slightly clarify info message * Fix mismatch between node and node_class in the schema generation * Remove condition that node class names cannot begin with Base and update tests * Indentation * Update Documentation & Code Style * feedback * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
5454d57bfa
commit
7261377643
@ -440,7 +440,7 @@ then be found in the dict returned by this method under the key "_debug"
|
||||
|
||||
```python
|
||||
@send_event
|
||||
def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult
|
||||
def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, sas_batch_size: int = 32, sas_use_gpu: bool = True, add_isolated_node_eval: bool = False) -> EvaluationResult
|
||||
```
|
||||
|
||||
Evaluates the pipeline by running the pipeline once per query in debug mode
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import logging
|
||||
|
||||
@ -113,7 +113,11 @@ class Config(BaseConfig):
|
||||
extra = "forbid" # type: ignore
|
||||
|
||||
|
||||
def find_subclasses_in_modules(importable_modules: List[str], include_base_classes: bool = False):
|
||||
def is_valid_component_class(class_):
|
||||
return inspect.isclass(class_) and not inspect.isabstract(class_) and issubclass(class_, BaseComponent)
|
||||
|
||||
|
||||
def find_subclasses_in_modules(importable_modules: List[str]):
|
||||
"""
|
||||
This function returns a list `(module, class)` of all the classes that can be imported
|
||||
dynamically, for example from a pipeline YAML definition or to generate documentation.
|
||||
@ -121,19 +125,14 @@ def find_subclasses_in_modules(importable_modules: List[str], include_base_class
|
||||
By default it won't include Base classes, which should be abstract.
|
||||
"""
|
||||
return [
|
||||
(module, clazz)
|
||||
(module, class_)
|
||||
for module in importable_modules
|
||||
for _, clazz in inspect.getmembers(sys.modules[module])
|
||||
if (
|
||||
inspect.isclass(clazz)
|
||||
and not inspect.isabstract(clazz)
|
||||
and issubclass(clazz, BaseComponent)
|
||||
and (include_base_classes or not clazz.__name__.startswith("Base"))
|
||||
)
|
||||
for _, class_ in inspect.getmembers(sys.modules[module])
|
||||
if is_valid_component_class(class_)
|
||||
]
|
||||
|
||||
|
||||
def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def create_schema_for_node_class(node_class: Type[BaseComponent]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
Create the JSON schema for a single BaseComponent subclass,
|
||||
including all accessory classes.
|
||||
@ -141,15 +140,17 @@ def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[st
|
||||
:returns: the schema for the node and all accessory classes,
|
||||
and a dict with the reference to the node only.
|
||||
"""
|
||||
if not hasattr(node, "__name__"):
|
||||
raise PipelineSchemaError(f"Node {node} has no __name__ attribute, cannot create a schema for it.")
|
||||
if not hasattr(node_class, "__name__"):
|
||||
raise PipelineSchemaError(
|
||||
f"Node class '{node_class}' has no '__name__' attribute, cannot create a schema for it."
|
||||
)
|
||||
|
||||
node_name = getattr(node, "__name__")
|
||||
node_name = getattr(node_class, "__name__")
|
||||
|
||||
logger.info(f"Processing node: {node_name}")
|
||||
logger.info(f"Creating schema for '{node_name}'")
|
||||
|
||||
# Read the relevant init parameters from __init__'s signature
|
||||
init_method = getattr(node, "__init__", None)
|
||||
init_method = getattr(node_class, "__init__", None)
|
||||
if not init_method:
|
||||
raise PipelineSchemaError(f"Could not read the __init__ method of {node_name} to create its schema.")
|
||||
|
||||
@ -228,11 +229,11 @@ def get_json_schema(
|
||||
node_refs = [] # References to the nodes only (accessory classes cannot be listed among the nodes in a config)
|
||||
|
||||
# List all known nodes in the given modules
|
||||
possible_nodes = find_subclasses_in_modules(importable_modules=modules)
|
||||
possible_node_classes = find_subclasses_in_modules(importable_modules=modules)
|
||||
|
||||
# Build the definitions and refs for the nodes
|
||||
for _, node in possible_nodes:
|
||||
node_definition, node_ref = create_schema_for_node(node)
|
||||
for _, node_class in possible_node_classes:
|
||||
node_definition, node_ref = create_schema_for_node_class(node_class)
|
||||
schema_definitions.update(node_definition)
|
||||
node_refs.append(node_ref)
|
||||
|
||||
@ -303,17 +304,24 @@ def get_json_schema(
|
||||
return pipeline_schema
|
||||
|
||||
|
||||
def inject_definition_in_schema(node: BaseComponent, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def inject_definition_in_schema(node_class: Type[BaseComponent], schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a node and a schema in dict form, injects the JSON schema for the new component
|
||||
so that pipelines containing such note can be validated against it.
|
||||
|
||||
:returns: the updated schema
|
||||
"""
|
||||
schema_definition, node_ref = create_schema_for_node(node)
|
||||
if not is_valid_component_class(node_class):
|
||||
raise PipelineSchemaError(
|
||||
f"Can't generate a valid schema for node of type '{node_class.__name__}'. "
|
||||
"Possible causes: \n"
|
||||
" - it has abstract methods\n"
|
||||
" - its __init__() take something else than Python primitive types or other nodes as parameter.\n"
|
||||
)
|
||||
schema_definition, node_ref = create_schema_for_node_class(node_class)
|
||||
schema["definitions"].update(schema_definition)
|
||||
schema["properties"]["components"]["items"]["anyOf"].append(node_ref)
|
||||
logger.info(f"Added definition for {getattr(node, '__name__')}")
|
||||
logger.info(f"Added definition for {getattr(node_class, '__name__')}")
|
||||
return schema
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import inspect
|
||||
import logging
|
||||
|
||||
from haystack.schema import Document, MultiLabel
|
||||
from haystack.errors import PipelineSchemaError
|
||||
from haystack.telemetry import send_custom_event
|
||||
from haystack.errors import HaystackError
|
||||
|
||||
@ -73,6 +74,7 @@ class BaseComponent(ABC):
|
||||
|
||||
# Keeps track of all available subclasses by name.
|
||||
# Enables generic load() for all specific component implementations.
|
||||
# Registers abstract classes and base classes too.
|
||||
cls._subclasses[cls.__name__] = cls
|
||||
|
||||
@property
|
||||
@ -108,7 +110,7 @@ class BaseComponent(ABC):
|
||||
@classmethod
|
||||
def get_subclass(cls, component_type: str):
|
||||
if component_type not in cls._subclasses.keys():
|
||||
raise HaystackError(f"Haystack component with the name '{component_type}' does not exist.")
|
||||
raise PipelineSchemaError(f"Haystack component with the name '{component_type}' not found.")
|
||||
subclass = cls._subclasses[component_type]
|
||||
return subclass
|
||||
|
||||
|
||||
@ -206,16 +206,18 @@ def validate_config(pipeline_config: Dict) -> None:
|
||||
logger.info(
|
||||
f"Missing definition for node of type {validation.instance['type']}. Looking into local classes..."
|
||||
)
|
||||
missing_component = BaseComponent.get_subclass(validation.instance["type"])
|
||||
schema = inject_definition_in_schema(node=missing_component, schema=schema)
|
||||
missing_component_class = BaseComponent.get_subclass(validation.instance["type"])
|
||||
schema = inject_definition_in_schema(node_class=missing_component_class, schema=schema)
|
||||
loaded_custom_nodes.append(validation.instance["type"])
|
||||
continue
|
||||
|
||||
# A node with the given name was imported, but something else is wrong with it.
|
||||
# A node with the given name was in the schema, but something else is wrong with it.
|
||||
# Probably it references unknown classes in its init parameters.
|
||||
raise PipelineSchemaError(
|
||||
f"Cannot process node of type {validation.instance['type']}. Make sure its __init__ function "
|
||||
"does not reference external classes, but uses only Python primitive types."
|
||||
f"Node of type {validation.instance['type']} found, but it failed validation. Possible causes:\n"
|
||||
" - The node is missing some mandatory parameter\n"
|
||||
" - Wrong indentation of some parameter in YAML\n"
|
||||
"See the stacktrace for more information."
|
||||
) from validation
|
||||
|
||||
# Format the error to make it as clear as possible
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
from abc import abstractmethod
|
||||
import pytest
|
||||
import json
|
||||
import numpy as np
|
||||
import inspect
|
||||
import networkx as nx
|
||||
from enum import Enum
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import haystack
|
||||
from haystack import Pipeline
|
||||
from haystack import document_stores
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.nodes import _json_schema
|
||||
from haystack.nodes import FileTypeClassifier
|
||||
from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError
|
||||
@ -251,6 +250,7 @@ def test_load_yaml_wrong_component(tmp_path):
|
||||
def test_load_yaml_custom_component(tmp_path):
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, param: int):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -270,7 +270,124 @@ def test_load_yaml_custom_component(tmp_path):
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert pipeline.get_node("custom_node").param == 1
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_cant_be_abstract(tmp_path):
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def abstract_method(self):
|
||||
pass
|
||||
|
||||
assert inspect.isabstract(CustomNode)
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineSchemaError, match="abstract"):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_name_can_include_base(tmp_path):
|
||||
class BaseCustomNode(MockNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: BaseCustomNode
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert isinstance(pipeline.get_node("custom_node"), BaseCustomNode)
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_must_subclass_basecomponent(tmp_path):
|
||||
class SomeCustomNode:
|
||||
def run(self, *a, **k):
|
||||
pass
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: custom_node
|
||||
type: SomeCustomNode
|
||||
params:
|
||||
param: 1
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
with pytest.raises(PipelineSchemaError, match="'SomeCustomNode' not found"):
|
||||
Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_referencing_other_node_in_init(tmp_path):
|
||||
class OtherNode(MockNode):
|
||||
def __init__(self, another_param: str):
|
||||
super().__init__()
|
||||
self.param = another_param
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, other_node: OtherNode):
|
||||
super().__init__()
|
||||
self.other_node = other_node
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
version: unstable
|
||||
components:
|
||||
- name: other_node
|
||||
type: OtherNode
|
||||
params:
|
||||
another_param: value
|
||||
- name: custom_node
|
||||
type: CustomNode
|
||||
params:
|
||||
other_node: other_node
|
||||
pipelines:
|
||||
- name: my_pipeline
|
||||
nodes:
|
||||
- name: custom_node
|
||||
inputs:
|
||||
- Query
|
||||
"""
|
||||
)
|
||||
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
|
||||
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
|
||||
|
||||
|
||||
def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path):
|
||||
@ -289,6 +406,7 @@ def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path):
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: HelperClass = HelperClass(1)):
|
||||
super().__init__()
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -325,6 +443,7 @@ def test_load_yaml_custom_component_with_helper_class_in_yaml(tmp_path):
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: HelperClass):
|
||||
super().__init__()
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -363,6 +482,7 @@ def test_load_yaml_custom_component_with_enum_in_init(tmp_path):
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: Flags = None):
|
||||
super().__init__()
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -399,6 +519,7 @@ def test_load_yaml_custom_component_with_enum_in_yaml(tmp_path):
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: Flags):
|
||||
super().__init__()
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -432,6 +553,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
|
||||
|
||||
class CustomNode(MockNode):
|
||||
def __init__(self, some_exotic_parameter: str):
|
||||
super().__init__()
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
@ -458,10 +580,12 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path):
|
||||
|
||||
def test_load_yaml_custom_component_with_superclass(tmp_path):
|
||||
class BaseCustomNode(MockNode):
|
||||
pass
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
class CustomNode(BaseCustomNode):
|
||||
def __init__(self, some_exotic_parameter: str):
|
||||
super().__init__()
|
||||
self.some_exotic_parameter = some_exotic_parameter
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user