diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md index 59a61c701..700eca06f 100644 --- a/docs/_src/api/api/pipelines.md +++ b/docs/_src/api/api/pipelines.md @@ -440,7 +440,7 @@ then be found in the dict returned by this method under the key "_debug" ```python @send_event -def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, add_isolated_node_eval: bool = False) -> EvaluationResult +def eval(labels: List[MultiLabel], documents: Optional[List[List[Document]]] = None, params: Optional[dict] = None, sas_model_name_or_path: str = None, sas_batch_size: int = 32, sas_use_gpu: bool = True, add_isolated_node_eval: bool = False) -> EvaluationResult ``` Evaluates the pipeline by running the pipeline once per query in debug mode diff --git a/haystack/nodes/_json_schema.py b/haystack/nodes/_json_schema.py index e10e31111..e9a880476 100644 --- a/haystack/nodes/_json_schema.py +++ b/haystack/nodes/_json_schema.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type import logging @@ -113,7 +113,11 @@ class Config(BaseConfig): extra = "forbid" # type: ignore -def find_subclasses_in_modules(importable_modules: List[str], include_base_classes: bool = False): +def is_valid_component_class(class_): + return inspect.isclass(class_) and not inspect.isabstract(class_) and issubclass(class_, BaseComponent) + + +def find_subclasses_in_modules(importable_modules: List[str]): """ This function returns a list `(module, class)` of all the classes that can be imported dynamically, for example from a pipeline YAML definition or to generate documentation. @@ -121,19 +125,14 @@ def find_subclasses_in_modules(importable_modules: List[str], include_base_class By default it won't include Base classes, which should be abstract. 
""" return [ - (module, clazz) + (module, class_) for module in importable_modules - for _, clazz in inspect.getmembers(sys.modules[module]) - if ( - inspect.isclass(clazz) - and not inspect.isabstract(clazz) - and issubclass(clazz, BaseComponent) - and (include_base_classes or not clazz.__name__.startswith("Base")) - ) + for _, class_ in inspect.getmembers(sys.modules[module]) + if is_valid_component_class(class_) ] -def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[str, Any]]: +def create_schema_for_node_class(node_class: Type[BaseComponent]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Create the JSON schema for a single BaseComponent subclass, including all accessory classes. @@ -141,15 +140,17 @@ def create_schema_for_node(node: BaseComponent) -> Tuple[Dict[str, Any], Dict[st :returns: the schema for the node and all accessory classes, and a dict with the reference to the node only. """ - if not hasattr(node, "__name__"): - raise PipelineSchemaError(f"Node {node} has no __name__ attribute, cannot create a schema for it.") + if not hasattr(node_class, "__name__"): + raise PipelineSchemaError( + f"Node class '{node_class}' has no '__name__' attribute, cannot create a schema for it." 
+ ) - node_name = getattr(node, "__name__") + node_name = getattr(node_class, "__name__") - logger.info(f"Processing node: {node_name}") + logger.info(f"Creating schema for '{node_name}'") # Read the relevant init parameters from __init__'s signature - init_method = getattr(node, "__init__", None) + init_method = getattr(node_class, "__init__", None) if not init_method: raise PipelineSchemaError(f"Could not read the __init__ method of {node_name} to create its schema.") @@ -228,11 +229,11 @@ def get_json_schema( node_refs = [] # References to the nodes only (accessory classes cannot be listed among the nodes in a config) # List all known nodes in the given modules - possible_nodes = find_subclasses_in_modules(importable_modules=modules) + possible_node_classes = find_subclasses_in_modules(importable_modules=modules) # Build the definitions and refs for the nodes - for _, node in possible_nodes: - node_definition, node_ref = create_schema_for_node(node) + for _, node_class in possible_node_classes: + node_definition, node_ref = create_schema_for_node_class(node_class) schema_definitions.update(node_definition) node_refs.append(node_ref) @@ -303,17 +304,24 @@ def get_json_schema( return pipeline_schema -def inject_definition_in_schema(node: BaseComponent, schema: Dict[str, Any]) -> Dict[str, Any]: +def inject_definition_in_schema(node_class: Type[BaseComponent], schema: Dict[str, Any]) -> Dict[str, Any]: """ Given a node and a schema in dict form, injects the JSON schema for the new component so that pipelines containing such note can be validated against it. :returns: the updated schema """ - schema_definition, node_ref = create_schema_for_node(node) + if not is_valid_component_class(node_class): + raise PipelineSchemaError( + f"Can't generate a valid schema for node of type '{node_class.__name__}'. 
" + "Possible causes: \n" + " - it has abstract methods\n" + " - its __init__() take something else than Python primitive types or other nodes as parameter.\n" + ) + schema_definition, node_ref = create_schema_for_node_class(node_class) schema["definitions"].update(schema_definition) schema["properties"]["components"]["items"]["anyOf"].append(node_ref) - logger.info(f"Added definition for {getattr(node, '__name__')}") + logger.info(f"Added definition for {getattr(node_class, '__name__')}") return schema diff --git a/haystack/nodes/base.py b/haystack/nodes/base.py index c02d58e51..9c8d89499 100644 --- a/haystack/nodes/base.py +++ b/haystack/nodes/base.py @@ -9,6 +9,7 @@ import inspect import logging from haystack.schema import Document, MultiLabel +from haystack.errors import PipelineSchemaError from haystack.telemetry import send_custom_event from haystack.errors import HaystackError @@ -73,6 +74,7 @@ class BaseComponent(ABC): # Keeps track of all available subclasses by name. # Enables generic load() for all specific component implementations. + # Registers abstract classes and base classes too. cls._subclasses[cls.__name__] = cls @property @@ -108,7 +110,7 @@ class BaseComponent(ABC): @classmethod def get_subclass(cls, component_type: str): if component_type not in cls._subclasses.keys(): - raise HaystackError(f"Haystack component with the name '{component_type}' does not exist.") + raise PipelineSchemaError(f"Haystack component with the name '{component_type}' not found.") subclass = cls._subclasses[component_type] return subclass diff --git a/haystack/pipelines/config.py b/haystack/pipelines/config.py index f4df76aa2..c0f2b5f3a 100644 --- a/haystack/pipelines/config.py +++ b/haystack/pipelines/config.py @@ -206,16 +206,18 @@ def validate_config(pipeline_config: Dict) -> None: logger.info( f"Missing definition for node of type {validation.instance['type']}. Looking into local classes..." 
) - missing_component = BaseComponent.get_subclass(validation.instance["type"]) - schema = inject_definition_in_schema(node=missing_component, schema=schema) + missing_component_class = BaseComponent.get_subclass(validation.instance["type"]) + schema = inject_definition_in_schema(node_class=missing_component_class, schema=schema) loaded_custom_nodes.append(validation.instance["type"]) continue - # A node with the given name was imported, but something else is wrong with it. + # A node with the given name was in the schema, but something else is wrong with it. # Probably it references unknown classes in its init parameters. raise PipelineSchemaError( - f"Cannot process node of type {validation.instance['type']}. Make sure its __init__ function " - "does not reference external classes, but uses only Python primitive types." + f"Node of type {validation.instance['type']} found, but it failed validation. Possible causes:\n" + " - The node is missing a mandatory parameter\n" + " - Wrong indentation of a parameter in the YAML file\n" + "See the stack trace for more information." 
) from validation # Format the error to make it as clear as possible diff --git a/test/test_pipeline_yaml.py b/test/test_pipeline_yaml.py index cf32bdbcc..534262663 100644 --- a/test/test_pipeline_yaml.py +++ b/test/test_pipeline_yaml.py @@ -1,14 +1,13 @@ +from abc import abstractmethod import pytest import json -import numpy as np +import inspect import networkx as nx from enum import Enum from pydantic.dataclasses import dataclass import haystack from haystack import Pipeline -from haystack import document_stores -from haystack.document_stores.base import BaseDocumentStore from haystack.nodes import _json_schema from haystack.nodes import FileTypeClassifier from haystack.errors import HaystackError, PipelineConfigError, PipelineSchemaError @@ -251,6 +250,7 @@ def test_load_yaml_wrong_component(tmp_path): def test_load_yaml_custom_component(tmp_path): class CustomNode(MockNode): def __init__(self, param: int): + super().__init__() self.param = param with open(tmp_path / "tmp_config.yml", "w") as tmp_file: @@ -270,7 +270,124 @@ def test_load_yaml_custom_component(tmp_path): - Query """ ) - Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + assert pipeline.get_node("custom_node").param == 1 + + +def test_load_yaml_custom_component_cant_be_abstract(tmp_path): + class CustomNode(MockNode): + def __init__(self): + super().__init__() + + @abstractmethod + def abstract_method(self): + pass + + assert inspect.isabstract(CustomNode) + + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: unstable + components: + - name: custom_node + type: CustomNode + pipelines: + - name: my_pipeline + nodes: + - name: custom_node + inputs: + - Query + """ + ) + with pytest.raises(PipelineSchemaError, match="abstract"): + Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + + +def test_load_yaml_custom_component_name_can_include_base(tmp_path): + class 
BaseCustomNode(MockNode): + def __init__(self): + super().__init__() + + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: unstable + components: + - name: custom_node + type: BaseCustomNode + pipelines: + - name: my_pipeline + nodes: + - name: custom_node + inputs: + - Query + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + assert isinstance(pipeline.get_node("custom_node"), BaseCustomNode) + + +def test_load_yaml_custom_component_must_subclass_basecomponent(tmp_path): + class SomeCustomNode: + def run(self, *a, **k): + pass + + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: unstable + components: + - name: custom_node + type: SomeCustomNode + params: + param: 1 + pipelines: + - name: my_pipeline + nodes: + - name: custom_node + inputs: + - Query + """ + ) + with pytest.raises(PipelineSchemaError, match="'SomeCustomNode' not found"): + Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + + +def test_load_yaml_custom_component_referencing_other_node_in_init(tmp_path): + class OtherNode(MockNode): + def __init__(self, another_param: str): + super().__init__() + self.param = another_param + + class CustomNode(MockNode): + def __init__(self, other_node: OtherNode): + super().__init__() + self.other_node = other_node + + with open(tmp_path / "tmp_config.yml", "w") as tmp_file: + tmp_file.write( + f""" + version: unstable + components: + - name: other_node + type: OtherNode + params: + another_param: value + - name: custom_node + type: CustomNode + params: + other_node: other_node + pipelines: + - name: my_pipeline + nodes: + - name: custom_node + inputs: + - Query + """ + ) + pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml") + assert isinstance(pipeline.get_node("custom_node"), CustomNode) def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path): @@ -289,6 +406,7 @@ def 
test_load_yaml_custom_component_with_helper_class_in_init(tmp_path): class CustomNode(MockNode): def __init__(self, some_exotic_parameter: HelperClass = HelperClass(1)): + super().__init__() self.some_exotic_parameter = some_exotic_parameter with open(tmp_path / "tmp_config.yml", "w") as tmp_file: @@ -325,6 +443,7 @@ def test_load_yaml_custom_component_with_helper_class_in_yaml(tmp_path): class CustomNode(MockNode): def __init__(self, some_exotic_parameter: HelperClass): + super().__init__() self.some_exotic_parameter = some_exotic_parameter with open(tmp_path / "tmp_config.yml", "w") as tmp_file: @@ -363,6 +482,7 @@ def test_load_yaml_custom_component_with_enum_in_init(tmp_path): class CustomNode(MockNode): def __init__(self, some_exotic_parameter: Flags = None): + super().__init__() self.some_exotic_parameter = some_exotic_parameter with open(tmp_path / "tmp_config.yml", "w") as tmp_file: @@ -399,6 +519,7 @@ def test_load_yaml_custom_component_with_enum_in_yaml(tmp_path): class CustomNode(MockNode): def __init__(self, some_exotic_parameter: Flags): + super().__init__() self.some_exotic_parameter = some_exotic_parameter with open(tmp_path / "tmp_config.yml", "w") as tmp_file: @@ -432,6 +553,7 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path): class CustomNode(MockNode): def __init__(self, some_exotic_parameter: str): + super().__init__() self.some_exotic_parameter = some_exotic_parameter with open(tmp_path / "tmp_config.yml", "w") as tmp_file: @@ -458,10 +580,12 @@ def test_load_yaml_custom_component_with_external_constant(tmp_path): def test_load_yaml_custom_component_with_superclass(tmp_path): class BaseCustomNode(MockNode): - pass + def __init__(self): + super().__init__() class CustomNode(BaseCustomNode): def __init__(self, some_exotic_parameter: str): + super().__init__() self.some_exotic_parameter = some_exotic_parameter with open(tmp_path / "tmp_config.yml", "w") as tmp_file: