diff --git a/haystack/components/others/multiplexer.py b/haystack/components/others/multiplexer.py index ddacdce85..e3f5d5abb 100644 --- a/haystack/components/others/multiplexer.py +++ b/haystack/components/others/multiplexer.py @@ -15,9 +15,8 @@ else: logger = logging.getLogger(__name__) -@component +@component(is_greedy=True) class Multiplexer: - is_greedy = True """ This component is used to distribute a single value to many components that may need it. It can take such value from different sources (the user's input, or another component), so diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index c231e41df..2046aea2d 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -72,7 +72,7 @@ import inspect import logging from copy import deepcopy from types import new_class -from typing import Any, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, runtime_checkable from haystack.core.errors import ComponentError @@ -157,6 +157,17 @@ class ComponentMeta(type): # We use this flag to check that. instance.__haystack_added_to_pipeline__ = None + # Only Components with variadic inputs can be greedy. If the user set the greedy flag + # to True, but the component doesn't have a variadic input, we set it to False. + # We can have this information only at instance creation time, so we do it here. + is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values()) + if not is_variadic and cls.__haystack_is_greedy__: + logging.warning( + "Component '%s' has no variadic input, but it's marked as greedy. " + "This is not supported and can lead to unexpected behavior.", + cls.__name__, + ) + return instance @@ -307,15 +318,15 @@ class _Component: return output_types_decorator - def _component(self, class_): + def _component(self, cls, is_greedy: bool = False): """ Decorator validating the structure of the component and registering it in the components registry. """ - logger.debug("Registering %s as a component", class_) + logger.debug("Registering %s as a component", cls) # Check for required methods and fail as soon as possible - if not hasattr(class_, "run"): - raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.") + if not hasattr(cls, "run"): + raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.") def copy_class_namespace(namespace): """ @@ -323,37 +334,54 @@ class _Component: to populate the newly created class. We just copy the whole namespace from the decorated class. """ - for key, val in dict(class_.__dict__).items(): + for key, val in dict(cls.__dict__).items(): # __dict__ and __weakref__ are class-bound, we should let Python recreate them. if key in ("__dict__", "__weakref__"): continue namespace[key] = val - # Recreate the decorated component class so it uses our metaclass - class_: class_.__name__ = new_class( - class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace - ) + # Recreate the decorated component class so it uses our metaclass. + # We must explicitly redefine the type of the class to make sure language servers + # and type checkers understand that the class is of the correct type. + # mypy doesn't like that we do this though so we explicitly ignore the type check. + cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) # type: ignore[no-redef] # Save the component in the class registry (for deserialization) - class_path = f"{class_.__module__}.{class_.__name__}" + class_path = f"{cls.__module__}.{cls.__name__}" if class_path in self.registry: # Corner case, but it may occur easily in notebooks when re-running cells. logger.debug( "Component %s is already registered. Previous imported from '%s', new imported from '%s'", class_path, self.registry[class_path], - class_, + cls, ) - self.registry[class_path] = class_ - logger.debug("Registered Component %s", class_) + self.registry[class_path] = cls + logger.debug("Registered Component %s", cls) # Override the __repr__ method with a default one - class_.__repr__ = _component_repr + cls.__repr__ = _component_repr - return class_ + # The greedy flag can be True only if the component has a variadic input. + # At this point of the lifetime of the component, we can't reliably know if it has a variadic input. + # So we set it to whatever the user specified, during the instance creation we'll change it if needed + # since we'll have access to the input sockets and check if any of them is variadic. + setattr(cls, "__haystack_is_greedy__", is_greedy) - def __call__(self, class_): - return self._component(class_) + return cls + + def __call__(self, cls: Optional[type] = None, is_greedy: bool = False): + # We must wrap the call to the decorator in a function for it to work + # correctly with or without parens + def wrap(cls): + return self._component(cls, is_greedy=is_greedy) + + if cls: + # Decorator is called without parens + return wrap(cls) + + # Decorator is called with parens + return wrap component = _Component() diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 2b8708b49..b4dfebe94 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -842,7 +842,7 @@ class Pipeline: # we're stuck for real and we can't make any progress. for name, comp in waiting_for_input: is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore - if is_variadic and not getattr(comp, "is_greedy", False): + if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined] break else: # We're stuck in a loop for real, we can't make any progress. @@ -873,14 +873,17 @@ class Pipeline: # Lazy variadics must be removed only if there's nothing else to run at this stage is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore - if is_variadic and not getattr(comp, "is_greedy", False): + if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined] there_are_only_lazy_variadics = True for other_name, other_comp in waiting_for_input: if name == other_name: continue - there_are_only_lazy_variadics &= any( - socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore - ) and not getattr(other_comp, "is_greedy", False) + there_are_only_lazy_variadics &= ( + any( + socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore + ) + and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined] + ) if not there_are_only_lazy_variadics: continue diff --git a/releasenotes/notes/component-greedy-d6630af901e96a4c.yaml b/releasenotes/notes/component-greedy-d6630af901e96a4c.yaml new file mode 100644 index 000000000..c973a150d --- /dev/null +++ b/releasenotes/notes/component-greedy-d6630af901e96a4c.yaml @@ -0,0 +1,12 @@ +--- +features: + - | + Add `is_greedy` argument to `@component` decorator. + This flag will change the behaviour of `Component`s with inputs that have a `Variadic` type + when running inside a `Pipeline`. + + Variadic `Component`s that are marked as greedy will run as soon as they receive their first input. + If not marked as greedy instead they'll wait as long as possible before running to make sure they + receive as many inputs as possible from their senders. + + It will be ignored for all other `Component`s even if set explicitly. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index b093c32b8..1bd971114 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -1,8 +1,10 @@ +import logging from typing import Any import pytest from haystack.core.component import Component, InputSocket, OutputSocket, component +from haystack.core.component.types import Variadic from haystack.core.errors import ComponentError from haystack.core.pipeline import Pipeline @@ -218,3 +220,55 @@ def test_repr_added_to_pipeline(): comp = MockComponent() pipe.add_component("my_component", comp) assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int" + + +def test_is_greedy_default_with_variadic_input(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: Variadic[int]): + return {"value": value} + + assert not MockComponent.__haystack_is_greedy__ + assert not MockComponent().__haystack_is_greedy__ + + +def test_is_greedy_default_without_variadic_input(): + @component + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": value} + + assert not MockComponent.__haystack_is_greedy__ + assert not MockComponent().__haystack_is_greedy__ + + +def test_is_greedy_flag_with_variadic_input(): + @component(is_greedy=True) + class MockComponent: + @component.output_types(value=int) + def run(self, value: Variadic[int]): + return {"value": value} + + assert MockComponent.__haystack_is_greedy__ + assert MockComponent().__haystack_is_greedy__ + + +def test_is_greedy_flag_without_variadic_input(caplog): + caplog.set_level(logging.WARNING) + + @component(is_greedy=True) + class MockComponent: + @component.output_types(value=int) + def run(self, value: int): + return {"value": value} + + assert MockComponent.__haystack_is_greedy__ + assert caplog.text == "" + assert MockComponent().__haystack_is_greedy__ + assert ( + caplog.text + == "WARNING root:component.py:165 Component 'MockComponent' has no variadic input, but it's marked as greedy." + " This is not supported and can lead to unexpected behavior.\n" + )