feat: Add is_greedy argument in @component decorator (#7016)

* Add is_greedy argument in @component decorator

* Log warning if Component is greedy and non variadic
This commit is contained in:
Silvano Cerza 2024-02-19 12:43:40 +01:00 committed by GitHub
parent 5f97e08feb
commit f1a6b2a78a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 121 additions and 25 deletions

View File

@ -15,9 +15,8 @@ else:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@component @component(is_greedy=True)
class Multiplexer: class Multiplexer:
is_greedy = True
""" """
This component is used to distribute a single value to many components that may need it. This component is used to distribute a single value to many components that may need it.
It can take such value from different sources (the user's input, or another component), so It can take such value from different sources (the user's input, or another component), so

View File

@ -72,7 +72,7 @@ import inspect
import logging import logging
from copy import deepcopy from copy import deepcopy
from types import new_class from types import new_class
from typing import Any, Protocol, runtime_checkable from typing import Any, Optional, Protocol, runtime_checkable
from haystack.core.errors import ComponentError from haystack.core.errors import ComponentError
@ -157,6 +157,17 @@ class ComponentMeta(type):
# We use this flag to check that. # We use this flag to check that.
instance.__haystack_added_to_pipeline__ = None instance.__haystack_added_to_pipeline__ = None
# Only Components with variadic inputs can be greedy. If the user set the greedy flag
# to True, but the component doesn't have a variadic input, we log a warning and leave the flag as is.
# We can have this information only at instance creation time, so we do it here.
is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
if not is_variadic and cls.__haystack_is_greedy__:
logging.warning(
"Component '%s' has no variadic input, but it's marked as greedy. "
"This is not supported and can lead to unexpected behavior.",
cls.__name__,
)
return instance return instance
@ -307,15 +318,15 @@ class _Component:
return output_types_decorator return output_types_decorator
def _component(self, class_): def _component(self, cls, is_greedy: bool = False):
""" """
Decorator validating the structure of the component and registering it in the components registry. Decorator validating the structure of the component and registering it in the components registry.
""" """
logger.debug("Registering %s as a component", class_) logger.debug("Registering %s as a component", cls)
# Check for required methods and fail as soon as possible # Check for required methods and fail as soon as possible
if not hasattr(class_, "run"): if not hasattr(cls, "run"):
raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.") raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
def copy_class_namespace(namespace): def copy_class_namespace(namespace):
""" """
@ -323,37 +334,54 @@ class _Component:
to populate the newly created class. We just copy to populate the newly created class. We just copy
the whole namespace from the decorated class. the whole namespace from the decorated class.
""" """
for key, val in dict(class_.__dict__).items(): for key, val in dict(cls.__dict__).items():
# __dict__ and __weakref__ are class-bound, we should let Python recreate them. # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
if key in ("__dict__", "__weakref__"): if key in ("__dict__", "__weakref__"):
continue continue
namespace[key] = val namespace[key] = val
# Recreate the decorated component class so it uses our metaclass # Recreate the decorated component class so it uses our metaclass.
class_: class_.__name__ = new_class( # We must explicitly redefine the type of the class to make sure language servers
class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace # and type checkers understand that the class is of the correct type.
) # mypy doesn't like that we do this though so we explicitly ignore the type check.
cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) # type: ignore[no-redef]
# Save the component in the class registry (for deserialization) # Save the component in the class registry (for deserialization)
class_path = f"{class_.__module__}.{class_.__name__}" class_path = f"{cls.__module__}.{cls.__name__}"
if class_path in self.registry: if class_path in self.registry:
# Corner case, but it may occur easily in notebooks when re-running cells. # Corner case, but it may occur easily in notebooks when re-running cells.
logger.debug( logger.debug(
"Component %s is already registered. Previous imported from '%s', new imported from '%s'", "Component %s is already registered. Previous imported from '%s', new imported from '%s'",
class_path, class_path,
self.registry[class_path], self.registry[class_path],
class_, cls,
) )
self.registry[class_path] = class_ self.registry[class_path] = cls
logger.debug("Registered Component %s", class_) logger.debug("Registered Component %s", cls)
# Override the __repr__ method with a default one # Override the __repr__ method with a default one
class_.__repr__ = _component_repr cls.__repr__ = _component_repr
return class_ # The greedy flag can be True only if the component has a variadic input.
# At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
# So we set it to whatever the user specified; during instance creation, when we have access
# to the input sockets, we log a warning if the flag is set but none of them is variadic.
setattr(cls, "__haystack_is_greedy__", is_greedy)
def __call__(self, class_): return cls
return self._component(class_)
def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
    """
    Allow the decorator to be used both as `@component` and as `@component(is_greedy=True)`.

    :param cls: The class being decorated when the decorator is used without parens,
        `None` when it is used with parens.
    :param is_greedy: Forwarded to `_component` together with the decorated class.
    :returns: The result of `_component` when `cls` is given, otherwise a function that
        takes the class to decorate and passes it to `_component`.
    """

    # We must wrap the call to the decorator in a function for it to work
    # correctly with or without parens
    def wrap(cls):
        return self._component(cls, is_greedy=is_greedy)

    # Compare against None explicitly: the truthiness of a class is decided by its metaclass
    # (`__bool__` / `__len__`), so a falsy class would otherwise be mistaken for the
    # "called with parens" case and the decorator would return `wrap` instead of the class.
    if cls is not None:
        # Decorator is called without parens
        return wrap(cls)

    # Decorator is called with parens
    return wrap
component = _Component() component = _Component()

View File

@ -842,7 +842,7 @@ class Pipeline:
# we're stuck for real and we can't make any progress. # we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input: for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False): if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
break break
else: else:
# We're stuck in a loop for real, we can't make any progress. # We're stuck in a loop for real, we can't make any progress.
@ -873,14 +873,17 @@ class Pipeline:
# Lazy variadics must be removed only if there's nothing else to run at this stage # Lazy variadics must be removed only if there's nothing else to run at this stage
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False): if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
there_are_only_lazy_variadics = True there_are_only_lazy_variadics = True
for other_name, other_comp in waiting_for_input: for other_name, other_comp in waiting_for_input:
if name == other_name: if name == other_name:
continue continue
there_are_only_lazy_variadics &= any( there_are_only_lazy_variadics &= (
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore any(
) and not getattr(other_comp, "is_greedy", False) socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
)
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
)
if not there_are_only_lazy_variadics: if not there_are_only_lazy_variadics:
continue continue

View File

@ -0,0 +1,12 @@
---
features:
- |
Add `is_greedy` argument to `@component` decorator.
This flag will change the behaviour of `Component`s with inputs that have a `Variadic` type
when running inside a `Pipeline`.
Variadic `Component`s that are marked as greedy will run as soon as they receive their first input.
If not marked as greedy, they'll instead wait as long as possible before running to make sure they
receive as many inputs as possible from their senders.
The flag is ignored for all other `Component`s, even if set explicitly.

View File

@ -1,8 +1,10 @@
import logging
from typing import Any from typing import Any
import pytest import pytest
from haystack.core.component import Component, InputSocket, OutputSocket, component from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.component.types import Variadic
from haystack.core.errors import ComponentError from haystack.core.errors import ComponentError
from haystack.core.pipeline import Pipeline from haystack.core.pipeline import Pipeline
@ -218,3 +220,55 @@ def test_repr_added_to_pipeline():
comp = MockComponent() comp = MockComponent()
pipe.add_component("my_component", comp) pipe.add_component("my_component", comp)
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int" assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
def test_is_greedy_default_with_variadic_input():
    """A component with a `Variadic` input decorated with bare `@component` is not greedy."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: Variadic[int]):
            return {"value": value}

    # The flag is checked on the class first, then on a freshly created instance.
    assert not MockComponent.__haystack_is_greedy__
    instance = MockComponent()
    assert not instance.__haystack_is_greedy__
def test_is_greedy_default_without_variadic_input():
    """A component without `Variadic` inputs decorated with bare `@component` is not greedy."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": value}

    # The flag is checked on the class first, then on a freshly created instance.
    assert not MockComponent.__haystack_is_greedy__
    instance = MockComponent()
    assert not instance.__haystack_is_greedy__
def test_is_greedy_flag_with_variadic_input():
    """A component with a `Variadic` input decorated with `is_greedy=True` is greedy."""

    @component(is_greedy=True)
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: Variadic[int]):
            return {"value": value}

    # The flag is checked on the class first, then on a freshly created instance.
    assert MockComponent.__haystack_is_greedy__
    instance = MockComponent()
    assert instance.__haystack_is_greedy__
def test_is_greedy_flag_without_variadic_input(caplog):
    """
    A component marked as greedy that has no `Variadic` input keeps the flag,
    and a warning is logged when an instance is created.
    """
    caplog.set_level(logging.WARNING)

    @component(is_greedy=True)
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": value}

    # Decorating the class alone must not log anything.
    assert MockComponent.__haystack_is_greedy__
    assert caplog.text == ""

    # Creating an instance triggers the warning; the flag itself is left untouched.
    assert MockComponent().__haystack_is_greedy__

    # Compare the level and rendered message of the captured records instead of `caplog.text`:
    # the formatted text embeds the logger name and the source line number of the logging call,
    # so it would break whenever unrelated lines are added to or removed from component.py.
    logged = [(record.levelname, record.getMessage()) for record in caplog.records]
    assert logged == [
        (
            "WARNING",
            "Component 'MockComponent' has no variadic input, but it's marked as greedy."
            " This is not supported and can lead to unexpected behavior.",
        )
    ]