feat: Add is_greedy argument in @component decorator (#7016)

* Add is_greedy argument in @component decorator

* Log warning if Component is greedy and non variadic
This commit is contained in:
Silvano Cerza 2024-02-19 12:43:40 +01:00 committed by GitHub
parent 5f97e08feb
commit f1a6b2a78a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 121 additions and 25 deletions

View File

@ -15,9 +15,8 @@ else:
logger = logging.getLogger(__name__)
@component
@component(is_greedy=True)
class Multiplexer:
is_greedy = True
"""
This component is used to distribute a single value to many components that may need it.
It can take such value from different sources (the user's input, or another component), so

View File

@ -72,7 +72,7 @@ import inspect
import logging
from copy import deepcopy
from types import new_class
from typing import Any, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, runtime_checkable
from haystack.core.errors import ComponentError
@ -157,6 +157,17 @@ class ComponentMeta(type):
# We use this flag to check that.
instance.__haystack_added_to_pipeline__ = None
# Only Components with variadic inputs can be greedy. If the user set the greedy flag
# to True, but the component doesn't have a variadic input, we set it to False.
# We can have this information only at instance creation time, so we do it here.
is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
if not is_variadic and cls.__haystack_is_greedy__:
logging.warning(
"Component '%s' has no variadic input, but it's marked as greedy. "
"This is not supported and can lead to unexpected behavior.",
cls.__name__,
)
return instance
@ -307,15 +318,15 @@ class _Component:
return output_types_decorator
def _component(self, class_):
def _component(self, cls, is_greedy: bool = False):
"""
Decorator validating the structure of the component and registering it in the components registry.
"""
logger.debug("Registering %s as a component", class_)
logger.debug("Registering %s as a component", cls)
# Check for required methods and fail as soon as possible
if not hasattr(class_, "run"):
raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
if not hasattr(cls, "run"):
raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
def copy_class_namespace(namespace):
"""
@ -323,37 +334,54 @@ class _Component:
to populate the newly created class. We just copy
the whole namespace from the decorated class.
"""
for key, val in dict(class_.__dict__).items():
for key, val in dict(cls.__dict__).items():
# __dict__ and __weakref__ are class-bound, we should let Python recreate them.
if key in ("__dict__", "__weakref__"):
continue
namespace[key] = val
# Recreate the decorated component class so it uses our metaclass
class_: class_.__name__ = new_class(
class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
)
# Recreate the decorated component class so it uses our metaclass.
# We must explicitly redefine the type of the class to make sure language servers
# and type checkers understand that the class is of the correct type.
# mypy doesn't like that we do this though so we explicitly ignore the type check.
cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) # type: ignore[no-redef]
# Save the component in the class registry (for deserialization)
class_path = f"{class_.__module__}.{class_.__name__}"
class_path = f"{cls.__module__}.{cls.__name__}"
if class_path in self.registry:
# Corner case, but it may occur easily in notebooks when re-running cells.
logger.debug(
"Component %s is already registered. Previous imported from '%s', new imported from '%s'",
class_path,
self.registry[class_path],
class_,
cls,
)
self.registry[class_path] = class_
logger.debug("Registered Component %s", class_)
self.registry[class_path] = cls
logger.debug("Registered Component %s", cls)
# Override the __repr__ method with a default one
class_.__repr__ = _component_repr
cls.__repr__ = _component_repr
return class_
# The greedy flag can be True only if the component has a variadic input.
# At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
# So we set it to whatever the user specified, during the instance creation we'll change it if needed
# since we'll have access to the input sockets and check if any of them is variadic.
setattr(cls, "__haystack_is_greedy__", is_greedy)
def __call__(self, class_):
return self._component(class_)
return cls
def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
# We must wrap the call to the decorator in a function for it to work
# correctly with or without parens
def wrap(cls):
return self._component(cls, is_greedy=is_greedy)
if cls:
# Decorator is called without parens
return wrap(cls)
# Decorator is called with parens
return wrap
component = _Component()

View File

@ -842,7 +842,7 @@ class Pipeline:
# we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False):
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
break
else:
# We're stuck in a loop for real, we can't make any progress.
@ -873,14 +873,17 @@ class Pipeline:
# Lazy variadics must be removed only if there's nothing else to run at this stage
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not getattr(comp, "is_greedy", False):
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
there_are_only_lazy_variadics = True
for other_name, other_comp in waiting_for_input:
if name == other_name:
continue
there_are_only_lazy_variadics &= any(
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
) and not getattr(other_comp, "is_greedy", False)
there_are_only_lazy_variadics &= (
any(
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
)
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
)
if not there_are_only_lazy_variadics:
continue

View File

@ -0,0 +1,12 @@
---
features:
- |
Add `is_greedy` argument to `@component` decorator.
This flag will change the behaviour of `Component`s with inputs that have a `Variadic` type
when running inside a `Pipeline`.
Variadic `Component`s that are marked as greedy will run as soon as they receive their first input.
If not marked as greedy instead they'll wait as long as possible before running to make sure they
receive as many inputs as possible from their senders.
It will be ignored for all other `Component`s even if set explicitly.

View File

@ -1,8 +1,10 @@
import logging
from typing import Any
import pytest
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.component.types import Variadic
from haystack.core.errors import ComponentError
from haystack.core.pipeline import Pipeline
@ -218,3 +220,55 @@ def test_repr_added_to_pipeline():
comp = MockComponent()
pipe.add_component("my_component", comp)
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
def test_is_greedy_default_with_variadic_input():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: Variadic[int]):
return {"value": value}
assert not MockComponent.__haystack_is_greedy__
assert not MockComponent().__haystack_is_greedy__
def test_is_greedy_default_without_variadic_input():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": value}
assert not MockComponent.__haystack_is_greedy__
assert not MockComponent().__haystack_is_greedy__
def test_is_greedy_flag_with_variadic_input():
@component(is_greedy=True)
class MockComponent:
@component.output_types(value=int)
def run(self, value: Variadic[int]):
return {"value": value}
assert MockComponent.__haystack_is_greedy__
assert MockComponent().__haystack_is_greedy__
def test_is_greedy_flag_without_variadic_input(caplog):
caplog.set_level(logging.WARNING)
@component(is_greedy=True)
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": value}
assert MockComponent.__haystack_is_greedy__
assert caplog.text == ""
assert MockComponent().__haystack_is_greedy__
assert (
caplog.text
== "WARNING root:component.py:165 Component 'MockComponent' has no variadic input, but it's marked as greedy."
" This is not supported and can lead to unexpected behavior.\n"
)