feat: Extend core component machinery to support an async run method (experimental) (#8279)

* feat: Extend core component machinery to support an async run method

* Add reno

* Fix incorrect docstring

* Make `async_run` a coroutine

* Make `supports_async` a dunder field
This commit is contained in:
Madeesh Kannan 2024-08-27 14:20:13 +02:00 committed by GitHub
parent 1fa30d4aaa
commit f0b45c873f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 213 additions and 31 deletions

View File

@ -77,7 +77,7 @@ from contextvars import ContextVar
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from types import new_class from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable
from haystack import logging from haystack import logging
from haystack.core.errors import ComponentError from haystack.core.errors import ComponentError
@ -166,7 +166,7 @@ class Component(Protocol):
class ComponentMeta(type): class ComponentMeta(type):
@staticmethod @staticmethod
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]: def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
""" """
Convert positional arguments to keyword arguments based on the signature of the `__init__` method. Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
""" """
@ -184,6 +184,66 @@ class ComponentMeta(type):
out[name] = arg out[name] = arg
return out return out
@staticmethod
def _parse_and_set_output_sockets(instance: Any):
    """
    Make sure the component instance carries a `__haystack_output__` attribute.

    When the attribute is missing, it is built from the `_output_types_cache` field that
    the output-types decorator stores on the `run` method (an empty specification is used
    when the field is absent). If the instance also exposes `async_run`, its
    specification has to be equal to the one of `run`.

    :param instance: The newly created component instance.
    :raises ComponentError: If `run` and `async_run` carry different output specifications.
    """
    has_async_run = hasattr(instance, "async_run")

    # `component.set_output_types()` was called in the component constructor:
    # `__haystack_output__` is already populated, nothing to do.
    if hasattr(instance, "__haystack_output__"):
        return

    sync_spec = getattr(instance.run, "_output_types_cache", {})
    if has_async_run:
        async_spec = getattr(instance.async_run, "_output_types_cache", {})
        if sync_spec != async_spec:
            raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same")

    # The cache lives on the class method; deepcopy it to transfer ownership to this
    # instance, so that different instances of the same class won't share the data.
    instance.__haystack_output__ = Sockets(instance, deepcopy(sync_spec), OutputSocket)
@staticmethod
def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
    """
    Create the input sockets of a component instance from the signature of its `run` method.

    Every named parameter of `run` (variable positional/keyword arguments are ignored)
    becomes an entry of `instance.__haystack_input__`. If the class also defines
    `async_run`, the sockets derived from its signature must be equal to the ones
    derived from `run`.

    :param component_cls: The component class; `run` and `async_run` are looked up on it.
    :param instance: The newly created component instance that receives the sockets.
    :raises ComponentError: If the parameters of `run` and `async_run` differ.
    """

    def inner(method, sockets):
        """Add one input socket to `sockets` for each named parameter of `method`."""
        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        parameters = inspect.signature(method).parameters

        # First is 'self' and it doesn't matter.
        for name, param in list(parameters.items())[1:]:
            if param.kind in variadic_kinds:
                continue  # ignore variable args
            socket_kwargs = {"name": name, "type": param.annotation}
            # `Parameter.empty` is a sentinel: compare by identity so that defaults with
            # a custom or element-wise `__ne__` (e.g. array-like objects) can't break
            # or distort the check.
            if param.default is not inspect.Parameter.empty:
                socket_kwargs["default_value"] = param.default
            sockets[name] = InputSocket(**socket_kwargs)

    run_method = component_cls.run

    # Create the sockets if set_input_types() wasn't called in the constructor.
    # If it was called and there are some parameters also in the `run()` method, these take precedence.
    if not hasattr(instance, "__haystack_input__"):
        instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
    inner(run_method, instance.__haystack_input__)

    # Ensure that the sockets are the same for the async method, if it exists.
    async_run = getattr(component_cls, "async_run", None)
    if async_run is not None:
        # Can't use the sockets from above as they might contain
        # values set with set_input_types().
        run_sockets = Sockets(instance, {}, InputSocket)
        async_run_sockets = Sockets(instance, {}, InputSocket)
        inner(run_method, run_sockets)
        inner(async_run, async_run_sockets)
        if async_run_sockets != run_sockets:
            raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same")
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
""" """
This method is called when clients instantiate a Component and runs before __new__ and __init__. This method is called when clients instantiate a Component and runs before __new__ and __init__.
@ -195,7 +255,7 @@ class ComponentMeta(type):
else: else:
try: try:
pre_init_hook.in_progress = True pre_init_hook.in_progress = True
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args) named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
assert ( assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set() set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap" ), "positional and keyword arguments overlap"
@ -207,34 +267,13 @@ class ComponentMeta(type):
# Before returning, we have the chance to modify the newly created # Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets # Component instance, so we take the chance and set up the I/O sockets
has_async_run = hasattr(instance, "async_run")
if has_async_run and not inspect.iscoroutinefunction(instance.async_run):
raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine")
instance.__haystack_supports_async__ = has_async_run
# If `component.set_output_types()` was called in the component constructor, ComponentMeta._parse_and_set_input_sockets(cls, instance)
# `__haystack_output__` is already populated, no need to do anything. ComponentMeta._parse_and_set_output_sockets(instance)
if not hasattr(instance, "__haystack_output__"):
# If that's not the case, we need to populate `__haystack_output__`
#
# If the `run` method was decorated, it has a `_output_types_cache` field assigned
# that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data.
instance.__haystack_output__ = Sockets(
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
)
# Create the sockets if set_input_types() wasn't called in the constructor.
# If it was called and there are some parameters also in the `run()` method, these take precedence.
if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
run_signature = inspect.signature(getattr(cls, "run"))
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
if run_signature.parameters[param].kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
): # ignore variable args
socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
if run_signature.parameters[param].default != inspect.Parameter.empty:
socket_kwargs["default_value"] = run_signature.parameters[param].default
instance.__haystack_input__[param] = InputSocket(**socket_kwargs)
# Since a Component can't be used in multiple Pipelines at the same time # Since a Component can't be used in multiple Pipelines at the same time
# we need to know if it's already owned by a Pipeline when adding it to one. # we need to know if it's already owned by a Pipeline when adding it to one.
@ -290,7 +329,13 @@ class _Component:
def __init__(self): def __init__(self):
self.registry = {} self.registry = {}
def set_input_type(self, instance, name: str, type: Any, default: Any = _empty): # noqa: A002 def set_input_type(
self,
instance,
name: str,
type: Any, # noqa: A002
default: Any = _empty,
):
""" """
Add a single input socket to the component instance. Add a single input socket to the component instance.
@ -395,6 +440,10 @@ class _Component:
the decorated method. The ComponentMeta metaclass will use this data to create the decorated method. The ComponentMeta metaclass will use this data to create
sockets at instance creation time. sockets at instance creation time.
""" """
method_name = run_method.__name__
if method_name not in ("run", "async_run"):
raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods")
setattr( setattr(
run_method, run_method,
"_output_types_cache", "_output_types_cache",

View File

@ -80,6 +80,16 @@ class Sockets:
self._sockets_dict = sockets_dict self._sockets_dict = sockets_dict
self.__dict__.update(sockets_dict) self.__dict__.update(sockets_dict)
def __eq__(self, value: object) -> bool:
    """
    Compare two `Sockets` objects.

    Two instances are equal when their socket I/O type, their owning component and
    their socket dictionaries are all equal.

    :param value: The object to compare against.
    :returns: The comparison result, or `NotImplemented` when `value` is not a `Sockets`
        instance, so that Python can try the reflected comparison before falling back
        to its default (for which `==` still evaluates to `False`).
    """
    # NOTE: a class that defines `__eq__` without `__hash__` is unhashable in Python.
    if not isinstance(value, Sockets):
        return NotImplemented

    return (
        self._sockets_io_type == value._sockets_io_type
        and self._component == value._component
        and self._sockets_dict == value._sockets_dict
    )
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]): def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
""" """
Adds a new socket to this Sockets object. Adds a new socket to this Sockets object.

View File

@ -0,0 +1,6 @@
---
features:
- |
Extend core component machinery to support an optional asynchronous `async_run` method in components.
If it's present, it should have the same parameters (and output types) as the run method and must be
implemented as a coroutine.

View File

@ -31,6 +31,33 @@ def test_correct_declaration():
# Verifies also instantiation works with no issues # Verifies also instantiation works with no issues
assert MockComponent() assert MockComponent()
assert component.registry["test_component.MockComponent"] == MockComponent assert component.registry["test_component.MockComponent"] == MockComponent
assert isinstance(MockComponent(), Component)
assert MockComponent().__haystack_supports_async__ is False
def test_correct_declaration_with_async():
    """A component with matching `run` and `async_run` methods is registered and reports async support."""

    @component
    class MockComponent:
        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

        @component.output_types(output_value=int)
        async def async_run(self, input_value: int):
            return {"output_value": input_value}

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

    # Verifies also instantiation works with no issues
    assert MockComponent()

    registered_cls = component.registry["test_component.MockComponent"]
    assert registered_cls == MockComponent

    assert isinstance(MockComponent(), Component)
    assert MockComponent().__haystack_supports_async__ is True
def test_correct_declaration_with_additional_readonly_property(): def test_correct_declaration_with_additional_readonly_property():
@ -95,6 +122,50 @@ def test_missing_run():
return {"output_value": input_value} return {"output_value": input_value}
def test_async_run_not_async():
    """Instantiating a component whose `async_run` is a plain (non-async) function raises `ComponentError`."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        @component.output_types(value=int)
        def async_run(self, value: int):
            return {"value": 1}

    with pytest.raises(ComponentError):
        MockComponent()
def test_async_run_not_coroutine():
    """Instantiating a component whose `async_run` is an async generator raises `ComponentError`."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # `async def` + `yield` defines an async generator function, not a coroutine function.
        @component.output_types(value=int)
        async def async_run(self, value: int):
            yield {"value": 1}

    with pytest.raises(ComponentError):
        MockComponent()
def test_parameters_mismatch_run_and_async_run():
    """
    Instantiating a component whose `run` and `async_run` parameters differ raises `ComponentError`.

    `async_run` must be a real coroutine function here (it uses `return`, not `yield`):
    an async generator would already be rejected by the "must be a coroutine" check,
    and the parameter comparison would never be reached. The output types of both
    methods are declared identically so that only the parameter annotation differs,
    and the expected message is matched to pin the error to the parameter check.
    """

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        @component.output_types(value=int)
        async def async_run(self, value: str):
            return {"value": 1}

    with pytest.raises(ComponentError, match="Parameters of 'run' and 'async_run' methods must be the same"):
        MockComponent()
def test_set_input_types(): def test_set_input_types():
@component @component
class MockComponent: class MockComponent:
@ -155,6 +226,52 @@ def test_output_types_decorator_with_compatible_type():
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)} assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
def test_output_types_decorator_wrong_method():
    """Applying `component.output_types` to a method other than `run`/`async_run` raises `ComponentError`."""
    # The decorator runs while the class body executes, so the whole
    # class definition has to live inside the `raises` context.
    with pytest.raises(ComponentError):

        @component
        class MockComponent:
            def run(self, value: int):
                return {"value": 1}

            @component.output_types(value=int)
            def to_dict(self):
                return {}

            @classmethod
            def from_dict(cls, data):
                return cls()
def test_output_types_decorator_mismatch_run_async_run():
    """Different output type declarations on `run` and `async_run` make instantiation raise `ComponentError`."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        @component.output_types(value=str)
        async def async_run(self, value: int):
            return {"value": "1"}

    with pytest.raises(ComponentError):
        MockComponent()
def test_output_types_decorator_missing_async_run():
    """Decorating `run` with output types but leaving `async_run` undecorated raises `ComponentError`."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # Deliberately not decorated with `component.output_types`.
        async def async_run(self, value: int):
            return {"value": "1"}

    with pytest.raises(ComponentError):
        MockComponent()
def test_component_decorator_set_it_as_component(): def test_component_decorator_set_it_as_component():
@component @component
class MockComponent: class MockComponent: