feat: Extend core component machinery to support an async run method (experimental) (#8279)

* feat: Extend core component machinery to support an async run method

* Add reno

* Fix incorrect docstring

* Make `async_run` a coroutine

* Make `supports_async` a dunder field
This commit is contained in:
Madeesh Kannan 2024-08-27 14:20:13 +02:00 committed by GitHub
parent 1fa30d4aaa
commit f0b45c873f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 213 additions and 31 deletions

View File

@ -77,7 +77,7 @@ from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable
from haystack import logging
from haystack.core.errors import ComponentError
@ -166,7 +166,7 @@ class Component(Protocol):
class ComponentMeta(type):
@staticmethod
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
"""
Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
"""
@ -184,6 +184,66 @@ class ComponentMeta(type):
out[name] = arg
return out
@staticmethod
def _parse_and_set_output_sockets(instance: Any):
    """
    Populate `__haystack_output__` on a freshly created component instance.

    Nothing is done when the attribute already exists (e.g. `component.set_output_types()`
    was called in the constructor). Otherwise the output specification is read from the
    `_output_types_cache` field that the `output_types` decorator stores on `run` and,
    if the component defines it, on `async_run`.

    :param instance: The component instance to set the output sockets on.
    :raises ComponentError: If `async_run` exists and its output specification differs from that of `run`.
    """
    has_async_run = hasattr(instance, "async_run")
    if hasattr(instance, "__haystack_output__"):
        # Output sockets were already declared explicitly; leave them untouched.
        return

    # An undecorated method has no cache, which counts as an empty specification.
    sync_spec = getattr(instance.run, "_output_types_cache", {})
    if has_async_run:
        async_spec = getattr(instance.async_run, "_output_types_cache", {})
        if sync_spec != async_spec:
            raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same")

    # Deepcopy transfers ownership of the specification from the class method to this
    # instance, so different instances of the same class don't share the data.
    instance.__haystack_output__ = Sockets(instance, deepcopy(sync_spec), OutputSocket)
@staticmethod
def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
    """
    Populate `__haystack_input__` on a component instance from the signature of `run`.

    If `set_input_types()` was called in the constructor, the existing sockets are kept and
    the parameters of `run` are added on top of them (taking precedence on name clashes).
    If the class also defines `async_run`, its parameters must produce the same sockets
    as those of `run`.

    :param component_cls: The component class whose `run`/`async_run` signatures are inspected.
    :param instance: The component instance to set the input sockets on.
    :raises ComponentError: If `async_run` exists and its parameters differ from those of `run`.
    """

    def inner(method, sockets):
        """Add one input socket to `sockets` for every named parameter of `method`."""
        run_signature = inspect.signature(method)
        # First is 'self' and it doesn't matter.
        for name, param in list(run_signature.parameters.items())[1:]:
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue  # ignore variable args
            socket_kwargs = {"name": name, "type": param.annotation}
            # `Parameter.empty` is a sentinel: compare by identity so that default values
            # overriding `__eq__`/`__ne__` (e.g. array-like objects) are never asked to compare.
            if param.default is not inspect.Parameter.empty:
                socket_kwargs["default_value"] = param.default
            sockets[name] = InputSocket(**socket_kwargs)

    # Create the sockets if set_input_types() wasn't called in the constructor.
    # If it was called and there are some parameters also in the `run()` method, these take precedence.
    if not hasattr(instance, "__haystack_input__"):
        instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
    inner(getattr(component_cls, "run"), instance.__haystack_input__)

    # Ensure that the sockets are the same for the async method, if it exists.
    async_run = getattr(component_cls, "async_run", None)
    if async_run is not None:
        # Can't use the sockets from above as they might contain
        # values set with set_input_types().
        run_sockets = Sockets(instance, {}, InputSocket)
        async_run_sockets = Sockets(instance, {}, InputSocket)
        inner(getattr(component_cls, "run"), run_sockets)
        inner(async_run, async_run_sockets)
        if async_run_sockets != run_sockets:
            raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same")
def __call__(cls, *args, **kwargs):
"""
This method is called when clients instantiate a Component and runs before __new__ and __init__.
@ -195,7 +255,7 @@ class ComponentMeta(type):
else:
try:
pre_init_hook.in_progress = True
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
@ -207,34 +267,13 @@ class ComponentMeta(type):
# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
has_async_run = hasattr(instance, "async_run")
if has_async_run and not inspect.iscoroutinefunction(instance.async_run):
raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine")
instance.__haystack_supports_async__ = has_async_run
# If `component.set_output_types()` was called in the component constructor,
# `__haystack_output__` is already populated, no need to do anything.
if not hasattr(instance, "__haystack_output__"):
# If that's not the case, we need to populate `__haystack_output__`
#
# If the `run` method was decorated, it has a `_output_types_cache` field assigned
# that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data.
instance.__haystack_output__ = Sockets(
instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
)
# Create the sockets if set_input_types() wasn't called in the constructor.
# If it was called and there are some parameters also in the `run()` method, these take precedence.
if not hasattr(instance, "__haystack_input__"):
instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
run_signature = inspect.signature(getattr(cls, "run"))
for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter.
if run_signature.parameters[param].kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
): # ignore variable args
socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
if run_signature.parameters[param].default != inspect.Parameter.empty:
socket_kwargs["default_value"] = run_signature.parameters[param].default
instance.__haystack_input__[param] = InputSocket(**socket_kwargs)
ComponentMeta._parse_and_set_input_sockets(cls, instance)
ComponentMeta._parse_and_set_output_sockets(instance)
# Since a Component can't be used in multiple Pipelines at the same time
# we need to know if it's already owned by a Pipeline when adding it to one.
@ -290,7 +329,13 @@ class _Component:
def __init__(self):
self.registry = {}
def set_input_type(self, instance, name: str, type: Any, default: Any = _empty): # noqa: A002
def set_input_type(
self,
instance,
name: str,
type: Any, # noqa: A002
default: Any = _empty,
):
"""
Add a single input socket to the component instance.
@ -395,6 +440,10 @@ class _Component:
the decorated method. The ComponentMeta metaclass will use this data to create
sockets at instance creation time.
"""
method_name = run_method.__name__
if method_name not in ("run", "async_run"):
raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods")
setattr(
run_method,
"_output_types_cache",

View File

@ -80,6 +80,16 @@ class Sockets:
self._sockets_dict = sockets_dict
self.__dict__.update(sockets_dict)
def __eq__(self, value: object) -> bool:
    """
    Compare this Sockets object with another object.

    Two Sockets objects are equal when their socket I/O types, their owning components
    and their socket dictionaries all compare equal.

    :param value: The object to compare against.
    :returns: The comparison result for another `Sockets` object. For any other type,
        `NotImplemented` is returned so that Python can try the reflected comparison
        before falling back to its default (unequal) result.
    """
    if not isinstance(value, Sockets):
        return NotImplemented

    return (
        self._sockets_io_type == value._sockets_io_type
        and self._component == value._component
        and self._sockets_dict == value._sockets_dict
    )
def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]):
"""
Adds a new socket to this Sockets object.

View File

@ -0,0 +1,6 @@
---
features:
- |
Extend the core component machinery to support an optional asynchronous `async_run` method in components.
If it's present, it must have the same parameters (and output types) as the `run` method and must be
implemented as a coroutine.

View File

@ -31,6 +31,33 @@ def test_correct_declaration():
# Verifies also instantiation works with no issues
assert MockComponent()
assert component.registry["test_component.MockComponent"] == MockComponent
assert isinstance(MockComponent(), Component)
assert MockComponent().__haystack_supports_async__ is False
def test_correct_declaration_with_async():
    """A component with matching `run` and `async_run` methods is valid and flagged as async-capable."""

    @component
    class MockComponent:
        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

        @component.output_types(output_value=int)
        async def async_run(self, input_value: int):
            return {"output_value": input_value}

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

    # Instantiation itself must succeed without raising.
    assert MockComponent()
    # The decorator registers the class under its qualified name.
    assert component.registry["test_component.MockComponent"] == MockComponent
    assert isinstance(MockComponent(), Component)
    assert MockComponent().__haystack_supports_async__ is True
def test_correct_declaration_with_additional_readonly_property():
@ -95,6 +122,50 @@ def test_missing_run():
return {"output_value": input_value}
def test_async_run_not_async():
    """Instantiation fails when `async_run` is declared as a plain synchronous method."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # Deliberately missing the `async` keyword.
        @component.output_types(value=int)
        def async_run(self, value: int):
            return {"value": 1}

    with pytest.raises(ComponentError):
        MockComponent()
def test_async_run_not_coroutine():
    """Instantiation fails when `async_run` is an async generator rather than a coroutine."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # `yield` inside `async def` produces an async generator function.
        @component.output_types(value=int)
        async def async_run(self, value: int):
            yield {"value": 1}

    with pytest.raises(ComponentError):
        MockComponent()
def test_parameters_mismatch_run_and_async_run():
    """Instantiation fails when `run` and `async_run` declare different parameters."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # `async_run` must be a genuine coroutine (`return`, not `yield`) with the same
        # output types as `run`; otherwise the coroutine check would raise the error first
        # and the parameter comparison would never be reached. The annotation of `value`
        # is the only difference between the two methods.
        @component.output_types(value=int)
        async def async_run(self, value: str):
            return {"value": 1}

    with pytest.raises(ComponentError, match="Parameters of 'run' and 'async_run' methods must be the same"):
        MockComponent()
def test_set_input_types():
@component
class MockComponent:
@ -155,6 +226,52 @@ def test_output_types_decorator_with_compatible_type():
assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}
def test_output_types_decorator_wrong_method():
    """Using `output_types` on a method other than `run`/`async_run` is rejected."""
    # The decorator is applied while the class body executes, so the class
    # definition itself has to live inside the `raises` context.
    with pytest.raises(ComponentError):

        @component
        class MockComponent:
            def run(self, value: int):
                return {"value": 1}

            @component.output_types(value=int)
            def to_dict(self):
                return {}

            @classmethod
            def from_dict(cls, data):
                return cls()
def test_output_types_decorator_mismatch_run_async_run():
    """Instantiation fails when `run` and `async_run` declare different output types."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # Same parameters as `run`, but the declared output type differs.
        @component.output_types(value=str)
        async def async_run(self, value: int):
            return {"value": "1"}

    with pytest.raises(ComponentError):
        MockComponent()
def test_output_types_decorator_missing_async_run():
    """Instantiation fails when only `run` carries the `output_types` decorator."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        # No `output_types` decorator here, so no output specification is attached.
        async def async_run(self, value: int):
            return {"value": "1"}

    with pytest.raises(ComponentError):
        MockComponent()
def test_component_decorator_set_it_as_component():
@component
class MockComponent: