mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-17 10:34:10 +00:00
feat: Extend core component machinery to support an async run method (experimental) (#8279)
* feat: Extend core component machinery to support an async run method * Add reno * Fix incorrect docstring * Make `async_run` a coroutine * Make `supports_async` a dunder field
This commit is contained in:
parent
1fa30d4aaa
commit
f0b45c873f
@ -77,7 +77,7 @@ from contextvars import ContextVar
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from types import new_class
|
from types import new_class
|
||||||
from typing import Any, Dict, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable
|
||||||
|
|
||||||
from haystack import logging
|
from haystack import logging
|
||||||
from haystack.core.errors import ComponentError
|
from haystack.core.errors import ComponentError
|
||||||
@ -166,7 +166,7 @@ class Component(Protocol):
|
|||||||
|
|
||||||
class ComponentMeta(type):
|
class ComponentMeta(type):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
|
def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
|
Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
|
||||||
"""
|
"""
|
||||||
@ -184,6 +184,66 @@ class ComponentMeta(type):
|
|||||||
out[name] = arg
|
out[name] = arg
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@staticmethod
def _parse_and_set_output_sockets(instance: Any):
    """
    Populate `__haystack_output__` on a component instance, unless it is already set.

    The output specification is read from the `_output_types_cache` attribute of the
    `run` method (an empty specification is used when the attribute is absent). If the
    instance also has an `async_run` method, its specification must be equal to the one
    of `run`.

    :param instance:
        The component instance to set the output sockets on.
    :raises ComponentError:
        If `async_run` exists and its output specification differs from the one of `run`.
    """
    # `component.set_output_types()` was called in the component constructor:
    # `__haystack_output__` is already populated, nothing left to do.
    if hasattr(instance, "__haystack_output__"):
        return

    run_output_types = getattr(instance.run, "_output_types_cache", {})

    # When both run methods exist, their output specifications must agree.
    if hasattr(instance, "async_run"):
        async_run_output_types = getattr(instance.async_run, "_output_types_cache", {})
        if run_output_types != async_run_output_types:
            raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same")

    # Deepcopy the cached specification to transfer ownership from the class method to
    # this instance, so that different instances of the same class won't share this data.
    instance.__haystack_output__ = Sockets(instance, deepcopy(run_output_types), OutputSocket)
|
||||||
|
|
||||||
|
@staticmethod
def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
    """
    Create the input sockets of a component instance from the signature of its `run` method.

    If the component class also defines `async_run`, the sockets derived from its signature
    must be equal to the ones derived from `run`.

    :param component_cls:
        The component class; `run` and `async_run` are looked up on it.
    :param instance:
        The component instance to set the input sockets on.
    :raises ComponentError:
        If `async_run` exists and its parameters differ from the ones of `run`.
    """

    def inner(method, sockets):
        """Add one input socket to `sockets` for every non-variadic parameter of `method`."""
        run_signature = inspect.signature(method)

        # First is 'self' and it doesn't matter.
        for name, param in list(run_signature.parameters.items())[1:]:
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue  # ignore variable args

            socket_kwargs = {"name": name, "type": param.annotation}
            # Compare by identity: `Parameter.empty` is a sentinel, and `!=` would invoke the
            # default value's own `__ne__`, which may be overloaded and not return a plain bool.
            if param.default is not inspect.Parameter.empty:
                socket_kwargs["default_value"] = param.default
            sockets[name] = InputSocket(**socket_kwargs)

    # Create the sockets if set_input_types() wasn't called in the constructor.
    # If it was called and there are some parameters also in the `run()` method, these take precedence.
    if not hasattr(instance, "__haystack_input__"):
        instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

    inner(getattr(component_cls, "run"), instance.__haystack_input__)

    # Ensure that the sockets are the same for the async method, if it exists.
    async_run = getattr(component_cls, "async_run", None)
    if async_run is not None:
        run_sockets = Sockets(instance, {}, InputSocket)
        async_run_sockets = Sockets(instance, {}, InputSocket)

        # Can't use the sockets from above as they might contain
        # values set with set_input_types().
        inner(getattr(component_cls, "run"), run_sockets)
        inner(async_run, async_run_sockets)
        if async_run_sockets != run_sockets:
            raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same")
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
def __call__(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method is called when clients instantiate a Component and runs before __new__ and __init__.
|
This method is called when clients instantiate a Component and runs before __new__ and __init__.
|
||||||
@ -195,7 +255,7 @@ class ComponentMeta(type):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
pre_init_hook.in_progress = True
|
pre_init_hook.in_progress = True
|
||||||
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
|
named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
|
||||||
assert (
|
assert (
|
||||||
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
|
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
|
||||||
), "positional and keyword arguments overlap"
|
), "positional and keyword arguments overlap"
|
||||||
@ -207,34 +267,13 @@ class ComponentMeta(type):
|
|||||||
|
|
||||||
# Before returning, we have the chance to modify the newly created
|
# Before returning, we have the chance to modify the newly created
|
||||||
# Component instance, so we take the chance and set up the I/O sockets
|
# Component instance, so we take the chance and set up the I/O sockets
|
||||||
|
has_async_run = hasattr(instance, "async_run")
|
||||||
|
if has_async_run and not inspect.iscoroutinefunction(instance.async_run):
|
||||||
|
raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine")
|
||||||
|
instance.__haystack_supports_async__ = has_async_run
|
||||||
|
|
||||||
# If `component.set_output_types()` was called in the component constructor,
|
ComponentMeta._parse_and_set_input_sockets(cls, instance)
|
||||||
# `__haystack_output__` is already populated, no need to do anything.
|
ComponentMeta._parse_and_set_output_sockets(instance)
|
||||||
if not hasattr(instance, "__haystack_output__"):
|
|
||||||
# If that's not the case, we need to populate `__haystack_output__`
|
|
||||||
#
|
|
||||||
# If the `run` method was decorated, it has a `_output_types_cache` field assigned
|
|
||||||
# that stores the output specification.
|
|
||||||
# We deepcopy the content of the cache to transfer ownership from the class method
|
|
||||||
# to the actual instance, so that different instances of the same class won't share this data.
|
|
||||||
instance.__haystack_output__ = Sockets(
|
|
||||||
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the sockets if set_input_types() wasn't called in the constructor.
|
|
||||||
# If it was called and there are some parameters also in the `run()` method, these take precedence.
|
|
||||||
if not hasattr(instance, "__haystack_input__"):
|
|
||||||
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
|
|
||||||
run_signature = inspect.signature(getattr(cls, "run"))
|
|
||||||
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
|
|
||||||
if run_signature.parameters[param].kind not in (
|
|
||||||
inspect.Parameter.VAR_POSITIONAL,
|
|
||||||
inspect.Parameter.VAR_KEYWORD,
|
|
||||||
): # ignore variable args
|
|
||||||
socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
|
|
||||||
if run_signature.parameters[param].default != inspect.Parameter.empty:
|
|
||||||
socket_kwargs["default_value"] = run_signature.parameters[param].default
|
|
||||||
instance.__haystack_input__[param] = InputSocket(**socket_kwargs)
|
|
||||||
|
|
||||||
# Since a Component can't be used in multiple Pipelines at the same time
|
# Since a Component can't be used in multiple Pipelines at the same time
|
||||||
# we need to know if it's already owned by a Pipeline when adding it to one.
|
# we need to know if it's already owned by a Pipeline when adding it to one.
|
||||||
@ -290,7 +329,13 @@ class _Component:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.registry = {}
|
self.registry = {}
|
||||||
|
|
||||||
def set_input_type(self, instance, name: str, type: Any, default: Any = _empty): # noqa: A002
|
def set_input_type(
|
||||||
|
self,
|
||||||
|
instance,
|
||||||
|
name: str,
|
||||||
|
type: Any, # noqa: A002
|
||||||
|
default: Any = _empty,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Add a single input socket to the component instance.
|
Add a single input socket to the component instance.
|
||||||
|
|
||||||
@ -395,6 +440,10 @@ class _Component:
|
|||||||
the decorated method. The ComponentMeta metaclass will use this data to create
|
the decorated method. The ComponentMeta metaclass will use this data to create
|
||||||
sockets at instance creation time.
|
sockets at instance creation time.
|
||||||
"""
|
"""
|
||||||
|
method_name = run_method.__name__
|
||||||
|
if method_name not in ("run", "async_run"):
|
||||||
|
raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods")
|
||||||
|
|
||||||
setattr(
|
setattr(
|
||||||
run_method,
|
run_method,
|
||||||
"_output_types_cache",
|
"_output_types_cache",
|
||||||
|
|||||||
@ -80,6 +80,16 @@ class Sockets:
|
|||||||
self._sockets_dict = sockets_dict
|
self._sockets_dict = sockets_dict
|
||||||
self.__dict__.update(sockets_dict)
|
self.__dict__.update(sockets_dict)
|
||||||
|
|
||||||
|
def __eq__(self, value: object) -> bool:
    """
    Compare this object with another `Sockets` instance.

    Two instances are equal when their socket I/O type, their component and
    their sockets dictionary are all equal.

    :param value:
        The object to compare with.
    :returns:
        `True` if `value` is an equal `Sockets` instance, `False` if it is a different
        `Sockets` instance, `NotImplemented` if `value` is not a `Sockets` instance.
    """
    if not isinstance(value, Sockets):
        # Let Python try the reflected comparison of the other operand; if that is not
        # available either, `==` still evaluates to False.
        return NotImplemented

    return (
        self._sockets_io_type == value._sockets_io_type
        and self._component == value._component
        and self._sockets_dict == value._sockets_dict
    )
|
||||||
|
|
||||||
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
|
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
|
||||||
"""
|
"""
|
||||||
Adds a new socket to this Sockets object.
|
Adds a new socket to this Sockets object.
|
||||||
|
|||||||
@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Extend core component machinery to support an optional asynchronous `async_run` method in components.
|
||||||
|
If it's present, it should have the same parameters (and output types) as the run method and must be
|
||||||
|
implemented as a coroutine.
|
||||||
@ -31,6 +31,33 @@ def test_correct_declaration():
|
|||||||
# Verifies also instantiation works with no issues
|
# Verifies also instantiation works with no issues
|
||||||
assert MockComponent()
|
assert MockComponent()
|
||||||
assert component.registry["test_component.MockComponent"] == MockComponent
|
assert component.registry["test_component.MockComponent"] == MockComponent
|
||||||
|
assert isinstance(MockComponent(), Component)
|
||||||
|
assert MockComponent().__haystack_supports_async__ is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_correct_declaration_with_async():
    """A component with a `run` method and a coroutine `async_run` method is accepted and reports async support."""

    @component
    class MockComponent:
        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

        @component.output_types(output_value=int)
        async def async_run(self, input_value: int):
            return {"output_value": input_value}

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

    # Instantiation must work with no issues
    assert MockComponent()

    registered_cls = component.registry["test_component.MockComponent"]
    assert registered_cls == MockComponent

    assert isinstance(MockComponent(), Component)
    assert MockComponent().__haystack_supports_async__ is True
|
||||||
|
|
||||||
|
|
||||||
def test_correct_declaration_with_additional_readonly_property():
|
def test_correct_declaration_with_additional_readonly_property():
|
||||||
@ -95,6 +122,50 @@ def test_missing_run():
|
|||||||
return {"output_value": input_value}
|
return {"output_value": input_value}
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_run_not_async():
    """Instantiating a component whose `async_run` is a plain synchronous method must fail."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        @component.output_types(value=int)
        def async_run(self, value: int):
            return {"value": 1}

    # Pin the failure reason so that an unrelated ComponentError can't make this test pass.
    with pytest.raises(ComponentError, match="must be a coroutine"):
        MockComponent()
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_run_not_coroutine():
    """Instantiating a component whose `async_run` is an async generator (not a coroutine) must fail."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        @component.output_types(value=int)
        async def async_run(self, value: int):
            # `yield` turns this into an async generator function, which is not a coroutine function.
            yield {"value": 1}

    # Pin the failure reason so that an unrelated ComponentError can't make this test pass.
    with pytest.raises(ComponentError, match="must be a coroutine"):
        MockComponent()
|
||||||
|
|
||||||
|
|
||||||
|
def test_parameters_mismatch_run_and_async_run():
    """Instantiating a component whose `run` and `async_run` parameters differ must fail."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # `async_run` must be a real coroutine (`return`, not `yield`): an async generator would
        # be rejected as "not a coroutine" and the parameter comparison would never be exercised.
        @component.output_types(value=int)
        async def async_run(self, value: str):
            return {"value": "1"}

    # Pin the failure reason to the parameter mismatch.
    with pytest.raises(ComponentError, match="Parameters of 'run' and 'async_run' methods must be the same"):
        MockComponent()
|
||||||
|
|
||||||
|
|
||||||
def test_set_input_types():
|
def test_set_input_types():
|
||||||
@component
|
@component
|
||||||
class MockComponent:
|
class MockComponent:
|
||||||
@ -155,6 +226,52 @@ def test_output_types_decorator_with_compatible_type():
|
|||||||
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
|
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_types_decorator_wrong_method():
    """Applying `component.output_types` to a method other than `run` or `async_run` must fail."""
    # The error is expected while the class body is evaluated, hence the class is declared
    # inside the `raises` context. The message is matched to pin the failure reason.
    with pytest.raises(ComponentError, match="decorator can only be used on"):

        @component
        class MockComponent:
            def run(self, value: int):
                return {"value": 1}

            @component.output_types(value=int)
            def to_dict(self):
                return {}

            @classmethod
            def from_dict(cls, data):
                return cls()
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_types_decorator_mismatch_run_async_run():
    """Instantiating a component whose `run` and `async_run` declare different output types must fail."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        @component.output_types(value=str)
        async def async_run(self, value: int):
            return {"value": "1"}

    # Pin the failure reason to the output type mismatch.
    with pytest.raises(ComponentError, match="Output type specifications of 'run' and 'async_run'"):
        MockComponent()
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_types_decorator_missing_async_run():
    """Instantiating a component must fail when only `run` declares output types and `async_run` does not."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # Deliberately not decorated with `component.output_types`.
        async def async_run(self, value: int):
            return {"value": "1"}

    # Pin the failure reason to the output type mismatch.
    with pytest.raises(ComponentError, match="Output type specifications of 'run' and 'async_run'"):
        MockComponent()
|
||||||
|
|
||||||
|
|
||||||
def test_component_decorator_set_it_as_component():
|
def test_component_decorator_set_it_as_component():
|
||||||
@component
|
@component
|
||||||
class MockComponent:
|
class MockComponent:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user