From f0b45c873fa3da5e45abd8b3db3d78e930c2d53c Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Tue, 27 Aug 2024 14:20:13 +0200 Subject: [PATCH] feat: Extend core component machinery to support an async run method (experimental) (#8279) * feat: Extend core component machinery to support an async run method * Add reno * Fix incorrect docstring * Make `async_run` a coroutine * Make `supports_async` a dunder field --- haystack/core/component/component.py | 111 ++++++++++++----- haystack/core/component/sockets.py | 10 ++ ...nt-support-machinery-6ea4496241aeb3b2.yaml | 6 + test/core/component/test_component.py | 117 ++++++++++++++++++ 4 files changed, 213 insertions(+), 31 deletions(-) create mode 100644 releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index 12e3643c1..e5036554f 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -77,7 +77,7 @@ from contextvars import ContextVar from copy import deepcopy from dataclasses import dataclass from types import new_class -from typing import Any, Dict, Optional, Protocol, runtime_checkable +from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable from haystack import logging from haystack.core.errors import ComponentError @@ -166,7 +166,7 @@ class Component(Protocol): class ComponentMeta(type): @staticmethod - def positional_to_kwargs(cls_type, args) -> Dict[str, Any]: + def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]: """ Convert positional arguments to keyword arguments based on the signature of the `__init__` method. """ @@ -184,6 +184,66 @@ class ComponentMeta(type): out[name] = arg return out + @staticmethod + def _parse_and_set_output_sockets(instance: Any): + has_async_run = hasattr(instance, "async_run") + + # If `component.set_output_types()` was called in the component constructor, + # `__haystack_output__` is already populated, no need to do anything. + if not hasattr(instance, "__haystack_output__"): + # If that's not the case, we need to populate `__haystack_output__` + # + # If either of the run methods were decorated, they'll have a field assigned that + # stores the output specification. If both run methods were decorated, we ensure that + # outputs are the same. We deepcopy the content of the cache to transfer ownership from + # the class method to the actual instance, so that different instances of the same class + # won't share this data. + + run_output_types = getattr(instance.run, "_output_types_cache", {}) + async_run_output_types = getattr(instance.async_run, "_output_types_cache", {}) if has_async_run else {} + + if has_async_run and run_output_types != async_run_output_types: + raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same") + output_types_cache = run_output_types + + instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket) + + @staticmethod + def _parse_and_set_input_sockets(component_cls: Type, instance: Any): + def inner(method, sockets): + run_signature = inspect.signature(method) + + # First is 'self' and it doesn't matter. + for param in list(run_signature.parameters)[1:]: + if run_signature.parameters[param].kind not in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): # ignore variable args + socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation} + if run_signature.parameters[param].default != inspect.Parameter.empty: + socket_kwargs["default_value"] = run_signature.parameters[param].default + sockets[param] = InputSocket(**socket_kwargs) + + # Create the sockets if set_input_types() wasn't called in the constructor. + # If it was called and there are some parameters also in the `run()` method, these take precedence. + if not hasattr(instance, "__haystack_input__"): + instance.__haystack_input__ = Sockets(instance, {}, InputSocket) + + inner(getattr(component_cls, "run"), instance.__haystack_input__) + + # Ensure that the sockets are the same for the async method, if it exists. + async_run = getattr(component_cls, "async_run", None) + if async_run is not None: + run_sockets = Sockets(instance, {}, InputSocket) + async_run_sockets = Sockets(instance, {}, InputSocket) + + # Can't use the sockets from above as they might contain + # values set with set_input_types(). + inner(getattr(component_cls, "run"), run_sockets) + inner(async_run, async_run_sockets) + if async_run_sockets != run_sockets: + raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same") + def __call__(cls, *args, **kwargs): """ This method is called when clients instantiate a Component and runs before __new__ and __init__. @@ -195,7 +255,7 @@ class ComponentMeta(type): else: try: pre_init_hook.in_progress = True - named_positional_args = ComponentMeta.positional_to_kwargs(cls, args) + named_positional_args = ComponentMeta._positional_to_kwargs(cls, args) assert ( set(named_positional_args.keys()).intersection(kwargs.keys()) == set() ), "positional and keyword arguments overlap" @@ -207,34 +267,13 @@ class ComponentMeta(type): # Before returning, we have the chance to modify the newly created # Component instance, so we take the chance and set up the I/O sockets + has_async_run = hasattr(instance, "async_run") + if has_async_run and not inspect.iscoroutinefunction(instance.async_run): + raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine") + instance.__haystack_supports_async__ = has_async_run - # If `component.set_output_types()` was called in the component constructor, - # `__haystack_output__` is already populated, no need to do anything. - if not hasattr(instance, "__haystack_output__"): - # If that's not the case, we need to populate `__haystack_output__` - # - # If the `run` method was decorated, it has a `_output_types_cache` field assigned - # that stores the output specification. - # We deepcopy the content of the cache to transfer ownership from the class method - # to the actual instance, so that different instances of the same class won't share this data. - instance.__haystack_output__ = Sockets( - instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket - ) - - # Create the sockets if set_input_types() wasn't called in the constructor. - # If it was called and there are some parameters also in the `run()` method, these take precedence. - if not hasattr(instance, "__haystack_input__"): - instance.__haystack_input__ = Sockets(instance, {}, InputSocket) - run_signature = inspect.signature(getattr(cls, "run")) - for param in list(run_signature.parameters)[1:]: # First is 'self' and it doesn't matter. - if run_signature.parameters[param].kind not in ( - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - ): # ignore variable args - socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation} - if run_signature.parameters[param].default != inspect.Parameter.empty: - socket_kwargs["default_value"] = run_signature.parameters[param].default - instance.__haystack_input__[param] = InputSocket(**socket_kwargs) + ComponentMeta._parse_and_set_input_sockets(cls, instance) + ComponentMeta._parse_and_set_output_sockets(instance) # Since a Component can't be used in multiple Pipelines at the same time # we need to know if it's already owned by a Pipeline when adding it to one. @@ -290,7 +329,13 @@ class _Component: def __init__(self): self.registry = {} - def set_input_type(self, instance, name: str, type: Any, default: Any = _empty): # noqa: A002 + def set_input_type( + self, + instance, + name: str, + type: Any, # noqa: A002 + default: Any = _empty, + ): """ Add a single input socket to the component instance. @@ -395,6 +440,10 @@ class _Component: the decorated method. The ComponentMeta metaclass will use this data to create sockets at instance creation time. """ + method_name = run_method.__name__ + if method_name not in ("run", "async_run"): + raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods") + setattr( run_method, "_output_types_cache", diff --git a/haystack/core/component/sockets.py b/haystack/core/component/sockets.py index 19a3f798f..22289ae9f 100644 --- a/haystack/core/component/sockets.py +++ b/haystack/core/component/sockets.py @@ -80,6 +80,16 @@ class Sockets: self._sockets_dict = sockets_dict self.__dict__.update(sockets_dict) + def __eq__(self, value: object) -> bool: + if not isinstance(value, Sockets): + return False + + return ( + self._sockets_io_type == value._sockets_io_type + and self._component == value._component + and self._sockets_dict == value._sockets_dict + ) + def __setitem__(self, key: str, socket: Union[InputSocket, OutputSocket]): """ Adds a new socket to this Sockets object. diff --git a/releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml b/releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml new file mode 100644 index 000000000..bc4b8f55d --- /dev/null +++ b/releasenotes/notes/async-component-support-machinery-6ea4496241aeb3b2.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Extend core component machinery to support an optional asynchronous `async_run` method in components. + If it's present, it should have the same parameters (and output types) as the run method and must be + implemented as a coroutine. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index 42d2cdb1a..c81fbf8ae 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -31,6 +31,33 @@ def test_correct_declaration(): # Verifies also instantiation works with no issues assert MockComponent() assert component.registry["test_component.MockComponent"] == MockComponent + assert isinstance(MockComponent(), Component) + assert MockComponent().__haystack_supports_async__ is False + + +def test_correct_declaration_with_async(): + @component + class MockComponent: + def to_dict(self): + return {} + + @classmethod + def from_dict(cls, data): + return cls() + + @component.output_types(output_value=int) + def run(self, input_value: int): + return {"output_value": input_value} + + @component.output_types(output_value=int) + async def async_run(self, input_value: int): + return {"output_value": input_value} + + # Verifies also instantiation works with no issues + assert MockComponent() + assert component.registry["test_component.MockComponent"] == MockComponent + assert isinstance(MockComponent(), Component) + assert MockComponent().__haystack_supports_async__ is True def test_correct_declaration_with_additional_readonly_property(): @@ -95,6 +122,50 @@ def test_missing_run(): return {"output_value": input_value} +def test_async_run_not_async(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=int) + def async_run(self, value: int): + return {"value": 1} + + with pytest.raises(ComponentError): + comp = MockComponent() + + +def test_async_run_not_coroutine(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=int) + async def async_run(self, value: int): + yield {"value": 1} + + with pytest.raises(ComponentError): + comp = MockComponent() + + +def test_parameters_mismatch_run_and_async_run(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + async def async_run(self, value: str): + yield {"value": "1"} + + with pytest.raises(ComponentError): + comp = MockComponent() + + def test_set_input_types(): @component class MockComponent: @@ -155,6 +226,52 @@ def test_output_types_decorator_with_compatible_type(): assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)} +def test_output_types_decorator_wrong_method(): + with pytest.raises(ComponentError): + + @component + class MockComponent: + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=int) + def to_dict(self): + return {} + + @classmethod + def from_dict(cls, data): + return cls() + + +def test_output_types_decorator_mismatch_run_async_run(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + @component.output_types(value=str) + async def async_run(self, value: int): + return {"value": "1"} + + with pytest.raises(ComponentError): + comp = MockComponent() + + +def test_output_types_decorator_missing_async_run(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + async def async_run(self, value: int): + return {"value": "1"} + + with pytest.raises(ComponentError): + comp = MockComponent() + + def test_component_decorator_set_it_as_component(): @component class MockComponent: