mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-22 12:44:01 +00:00
551 lines
21 KiB
Python
551 lines
21 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""
|
|
Attributes:
|
|
|
|
component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.
|
|
|
|
All components must follow the contract below. This docstring is the source of truth for components contract.
|
|
|
|
<hr>
|
|
|
|
`@component` decorator
|
|
|
|
All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.
|
|
|
|
<hr>
|
|
|
|
`__init__(self, **kwargs)`
|
|
|
|
Optional method.
|
|
|
|
Components may have an `__init__` method where they define:
|
|
|
|
- `self.init_parameters = {same parameters that the __init__ method received}`:
|
|
In this dictionary you can store any state the components wish to be persisted when they are saved.
|
|
These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
|
|
Note that by default the `@component` decorator saves the arguments automatically.
|
|
However, if a component sets its own `init_parameters` manually in `__init__()`, those will be used instead.
|
|
Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
|
|
|
|
Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
|
|
dictionaries containing only such values. Anything else (objects, functions, etc) will raise an exception at init
|
|
time. If there's the need for such values, consider serializing them to a string.
|
|
|
|
_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_
|
|
|
|
The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
|
|
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc...) refer to
|
|
the `warm_up()` method.
|
|
|
|
<hr>
|
|
|
|
`warm_up(self)`
|
|
|
|
Optional method.
|
|
|
|
This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
|
|
because Pipeline will not keep track of which components it called `warm_up()` on.
|
|
|
|
<hr>
|
|
|
|
`run(self, data)`
|
|
|
|
Mandatory method.
|
|
|
|
This is the method where the main functionality of the component should be carried out. It's called by
|
|
`Pipeline.run()`.
|
|
|
|
When the component should run, Pipeline will call this method with an instance of the dataclass returned by the
|
|
method decorated with `@component.input`. This dataclass contains:
|
|
|
|
- all the input values coming from other components connected to it,
|
|
- if any is missing, the corresponding value defined in `self.defaults`, if it exists.
|
|
|
|
`run()` must return a single instance of the dataclass declared through the method decorated with
|
|
`@component.output`.
|
|
|
|
"""
|
|
|
|
import inspect
|
|
from collections.abc import Callable
|
|
from contextlib import contextmanager
|
|
from contextvars import ContextVar
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass
|
|
from types import new_class
|
|
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, runtime_checkable
|
|
|
|
from typing_extensions import ParamSpec
|
|
|
|
from haystack import logging
|
|
from haystack.core.errors import ComponentError
|
|
|
|
from .sockets import Sockets
|
|
from .types import InputSocket, OutputSocket, _empty
|
|
|
|
# Module-level logger; the `@component` decorator uses it to report class registrations.
logger = logging.getLogger(__name__)
|
|
|
|
# Parameter specification of a decorated run method; lets `output_types` return a callable with the same parameters.
P = ParamSpec("P")
|
|
# Return type of a decorated run method; it is constrained to dictionaries keyed by strings.
R = TypeVar("R", bound=Dict[str, Any])
|
|
|
|
|
|
@dataclass
class PreInitHookPayload:
    """
    Data attached to the hook that runs right before a component instance is initialized.

    :param callback:
        Callable invoked with two inputs: the component class and a dictionary holding the
        init parameters as keyword args.
    :param in_progress:
        Tells whether the hook is being executed right now.
        It guards against recursive invocations, which would otherwise happen when the
        constructor of a component instantiates another component.
    """

    callback: Callable
    in_progress: bool = False
|
|
|
|
|
|
# Context-local slot holding the active pre-init hook payload; `None` (the default) means no hook is installed.
_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)
|
|
|
|
|
|
@contextmanager
def _hook_component_init(callback: Callable):
    """
    Install a callback that runs before the constructor of any component created inside the `with` block.

    The callback is handed the component class together with the init parameters (as keyword arguments);
    it is allowed to modify those parameters in place.

    :param callback:
        Callback function to invoke.
    """
    payload = PreInitHookPayload(callback)
    reset_token = _COMPONENT_PRE_INIT_HOOK.set(payload)
    try:
        yield
    finally:
        # Always restore whatever hook (or absence of one) was active before entering the block.
        _COMPONENT_PRE_INIT_HOOK.reset(reset_token)
|
|
|
|
|
|
@runtime_checkable
class Component(Protocol):
    """
    Structural type describing a component. Note this is only used by type checking tools.

    A custom component implements the `Component` protocol simply by exposing a `run`
    attribute. Neither the signature of the method nor its return value are verified,
    which means that a class defining

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and a class defining

        def run(self, **kwargs):
            ...

    are both accepted as respecting the protocol. Type checking is therefore much weaker,
    but other parts of the code make sure they are dealing with actual Components.

    Since the protocol is runtime checkable, assertions like the following are possible:

        isinstance(MyComponent(), Component)
    """

    # Declaring `run` as an attribute is the most reliable way to express the protocol.
    # A method definition can't be used because every Component has its own arguments,
    # and not even a `**kwargs` method would do, as the expected signature would have to
    # be identical. This form keeps most Language Servers and type checkers quiet.
    run: Callable[..., Dict[str, Any]]
|
|
|
|
|
|
class ComponentMeta(type):
    """
    Metaclass that intercepts component instantiation.

    `__call__` runs the optional pre-init hook, creates the instance and then attaches the
    input/output sockets and the bookkeeping attributes to it.
    """

    @staticmethod
    def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.

        :param cls_type: Class whose `__init__` signature provides the parameter names.
        :param args: Positional arguments passed to the constructor.
        :returns: Dictionary mapping each parameter name to the positional value it received.
        :raises ComponentError: If a positional argument lands on a variadic positional parameter (`*args`).
        :raises TypeError: If more positional arguments are given than `__init__` has parameters.
        """
        init_signature = inspect.signature(cls_type.__init__)
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items()):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg

        # `zip` stops at the shortest input, so surplus positional arguments would be dropped silently and
        # the constructor would succeed with fewer arguments than the caller passed. Fail like Python does.
        if len(args) > len(init_params):
            raise TypeError(
                f"{cls_type.__name__}.__init__() takes {len(init_params)} positional argument(s) "
                f"but {len(args)} were given"
            )
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any):
        """
        Populate `__haystack_output__` on the instance, unless the constructor already did it.

        :param instance: Component instance that was just created.
        :raises ComponentError: If `run` and `run_async` declare different output types.
        """
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either of the run methods were decorated, they'll have a field assigned that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
        """
        Create the input sockets of the instance from the parameters of the `run` method.

        :param component_cls: Class of the component; its `run` (and `run_async`, if any) signature is inspected.
        :param instance: Component instance that receives the `__haystack_input__` attribute.
        :raises ComponentError: If a `run` parameter conflicts with a socket that already exists under the
            same name, or if `run` and `run_async` don't accept the same parameters.
        """

        def inner(method, sockets):
            run_signature = inspect.signature(method)

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    continue

                socket_kwargs = {"name": param_name, "type": param_info.annotation}
                # Identity check: `!=` would call the default value's own comparison operator,
                # which may not return a plain bool (e.g. array-like defaults).
                if param_info.default is not inspect.Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                raise ComponentError("Parameters of 'run' and 'run_async' methods must be the same")

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.

        :raises ComponentError: If `run_async` exists but is not a coroutine function, or if socket
            parsing fails.
        """
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            # This will call __new__ then __init__, giving us back the Component instance
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                # Flag the hook as running so that components created by this constructor skip it.
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                assert set(named_positional_args.keys()).intersection(kwargs.keys()) == set(), (
                    "positional and keyword arguments overlap"
                )
                kwargs.update(named_positional_args)
                # The callback may modify `kwargs` in place before the instance is built from them.
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance
|
|
|
|
|
|
def _component_repr(component: Component) -> str:
    """
    Default `__repr__` installed on every Component class.

    The output starts with the plain object representation, followed by the name of the component
    (only when it has been added to a Pipeline) and by its input and output sockets.
    """
    parts = [object.__repr__(component)]

    owner = getattr(component, "__haystack_added_to_pipeline__", None)
    if owner:
        # This Component has been added in a Pipeline, let's get the name from there.
        parts.append(f"{owner.get_component_name(component)}")

    # The sockets are expected to be set at this point; the placeholders are only a safety net.
    parts.append(f"{getattr(component, '__haystack_input__', '<invalid_input_sockets>')}")
    parts.append(f"{getattr(component, '__haystack_output__', '<invalid_output_sockets>')}")
    return "\n".join(parts)
|
|
|
|
|
|
def _component_run_has_kwargs(component_cls: Type) -> bool:
|
|
run_method = getattr(component_cls, "run", None)
|
|
if run_method is None:
|
|
return False
|
|
else:
|
|
return any(
|
|
param.kind == inspect.Parameter.VAR_KEYWORD for param in inspect.signature(run_method).parameters.values()
|
|
)
|
|
|
|
|
|
class _Component:
    """
    See module's docstring.

    Args:
        cls: the class that should be used as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        # Maps "<module>.<class name>" to the registered component class (used for deserialization).
        self.registry = {}

    def set_input_type(
        self,
        instance,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ):
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        :raises ComponentError: if the `run` method of the component doesn't accept `**kwargs`.
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Any input socket previously set on the instance is discarded.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that if the `run()` method also specifies some parameters, those will take precedence.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
                return {"output_1": value_1, "output_2": kwargs["value_2"]}
        ```

        would add a mandatory `value_0` parameter, make the `value_1`
        parameter optional with a default None, and keep the `value_2`
        parameter mandatory as specified in `set_input_types`.

        :param instance: Component instance where the input types will be set.
        :param types: mapping of input socket names to their types.
        :raises ComponentError: if the `run` method of the component doesn't accept `**kwargs`.
        """
        # NOTE(review): `ComponentMeta._parse_and_set_input_sockets` raises a ComponentError when a socket
        # declared here differs from the socket derived from a `run()` parameter of the same name, so the
        # "take precedence" example in the docstring may be outdated — confirm against InputSocket equality.
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```

        :param instance: Component instance where the output types will be set.
        :param types: mapping of output socket names to their types.
        :raises ComponentError: if the `run` method is already decorated with `component.output_types`.
        """
        has_decorator = hasattr(instance.run, "_output_types_cache")
        if has_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has "
                "the 'output_types' decorator on its `run` method"
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(self, **types: Any) -> Callable[[Callable[P, R]], Callable[P, R]]:
        """
        Decorator factory that specifies the output types of a component.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```

        :param types: mapping of output socket names to their types.
        :returns: a decorator to apply to the `run` or `run_async` method.
        """

        def output_types_decorator(run_method: Callable[P, R]) -> Callable[P, R]:
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls: Any):
        """
        Decorator validating the structure of the component and registering it in the components registry.

        :param cls: the class to turn into a component.
        :returns: a copy of `cls` that uses `ComponentMeta` as its metaclass.
        :raises ComponentError: if `cls` has no `run` attribute.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `typing.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        new_cls: cls.__name__ = new_class(
            cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
        )  # type: ignore[no-redef]

        # Save the component in the class registry (for deserialization)
        class_path = f"{new_cls.__module__}.{new_cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            # The message is built with implicit string concatenation: a backslash continuation inside
            # the literal would make its content depend on the indentation of the source code.
            logger.debug(
                "Component {component} is already registered. Previous imported from '{module_name}', "
                "new imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=new_cls,
            )
        self.registry[class_path] = new_cls
        logger.debug("Registered Component {component}", component=new_cls)

        # Override the __repr__ method with a default one
        new_cls.__repr__ = _component_repr

        return new_cls

    def __call__(self, cls: Optional[type] = None):
        """
        Make the instance usable as a decorator, both as `@component` and as `@component()`.

        :param cls: the decorated class, or `None` when the decorator is used with parens.
        :returns: the component class when `cls` is given, otherwise the decorator to apply.
        """

        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls)

        # Compare against None explicitly: what matters is whether a class was passed,
        # not the truthiness of the class object.
        if cls is not None:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap
|
|
|
|
|
|
# Module-level `_Component` instance; this is the object used as the `@component` decorator.
component = _Component()
|