diff --git a/haystack/components/joiners/branch.py b/haystack/components/joiners/branch.py index 788d78a91..49404ac49 100644 --- a/haystack/components/joiners/branch.py +++ b/haystack/components/joiners/branch.py @@ -5,13 +5,13 @@ from typing import Any, Dict, Type from haystack import component, default_from_dict, default_to_dict, logging -from haystack.core.component.types import Variadic +from haystack.core.component.types import GreedyVariadic from haystack.utils import deserialize_type, serialize_type logger = logging.getLogger(__name__) -@component(is_greedy=True) +@component() class BranchJoiner: """ A component to join different branches of a pipeline into one single output. @@ -100,7 +100,7 @@ class BranchJoiner: """ self.type_ = type_ # type_'s type can't be determined statically - component.set_input_types(self, value=Variadic[type_]) # type: ignore + component.set_input_types(self, value=GreedyVariadic[type_]) # type: ignore component.set_output_types(self, value=type_) def to_dict(self): diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index 72c444fdd..e1dbf2c5f 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -71,6 +71,7 @@ method decorated with `@component.input`. This dataclass contains: import inspect import sys +import warnings from collections.abc import Callable from contextlib import contextmanager from contextvars import ContextVar @@ -292,17 +293,6 @@ class ComponentMeta(type): # We use this flag to check that. instance.__haystack_added_to_pipeline__ = None - # Only Components with variadic inputs can be greedy. If the user set the greedy flag - # to True, but the component doesn't have a variadic input, we set it to False. - # We can have this information only at instance creation time, so we do it here. - is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values()) - if not is_variadic and cls.__haystack_is_greedy__: - logger.warning( - "Component '{component}' has no variadic input, but it's marked as greedy. " - "This is not supported and can lead to unexpected behavior.", - component=cls.__name__, - ) - return instance @@ -497,12 +487,21 @@ class _Component: return output_types_decorator - def _component(self, cls, is_greedy: bool = False): + def _component(self, cls, is_greedy: Optional[bool] = None): """ Decorator validating the structure of the component and registering it in the components registry. """ logger.debug("Registering {component} as a component", component=cls) + if is_greedy is not None: + msg = ( + "The 'is_greedy' argument is deprecated and will be removed in version '2.7.0'. " + "Change the 'Variadic' input of your Component to 'GreedyVariadic' instead." + ) + warnings.warn(msg, DeprecationWarning) + else: + is_greedy = False + # Check for required methods and fail as soon as possible if not hasattr(cls, "run"): raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.") @@ -542,15 +541,9 @@ class _Component: # Override the __repr__ method with a default one cls.__repr__ = _component_repr - # The greedy flag can be True only if the component has a variadic input. - # At this point of the lifetime of the component, we can't reliably know if it has a variadic input. - # So we set it to whatever the user specified, during the instance creation we'll change it if needed - # since we'll have access to the input sockets and check if any of them is variadic. - setattr(cls, "__haystack_is_greedy__", is_greedy) - return cls - def __call__(self, cls: Optional[type] = None, is_greedy: bool = False): + def __call__(self, cls: Optional[type] = None, is_greedy: Optional[bool] = None): # We must wrap the call to the decorator in a function for it to work # correctly with or without parens def wrap(cls): diff --git a/haystack/core/component/types.py b/haystack/core/component/types.py index f7aa8d087..b08681a75 100644 --- a/haystack/core/component/types.py +++ b/haystack/core/component/types.py @@ -8,6 +8,7 @@ from typing import Any, Iterable, List, Type, TypeVar, get_args from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility HAYSTACK_VARIADIC_ANNOTATION = "__haystack__variadic_t" +HAYSTACK_GREEDY_VARIADIC_ANNOTATION = "__haystack__greedy_variadic_t" # # Generic type variable used in the Variadic container T = TypeVar("T") @@ -16,9 +17,17 @@ T = TypeVar("T") # Variadic is a custom annotation type we use to mark input types. # This type doesn't do anything else than "marking" the contained # type so it can be used in the `InputSocket` creation where we -# check that its annotation equals to CANALS_VARIADIC_ANNOTATION +# check that its annotation equals to HAYSTACK_VARIADIC_ANNOTATION Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION] +# GreedyVariadic type is similar to Variadic. +# The only difference is the way it's treated by the Pipeline when input is received +# in a socket with this type. +# Instead of waiting for other inputs to be received, Components that have a GreedyVariadic +# input will be run right after receiving the first input. +# Even if there are multiple connections to that socket. +GreedyVariadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_GREEDY_VARIADIC_ANNOTATION] + class _empty: """Custom object for marking InputSocket.default_value as not set.""" @@ -37,6 +46,8 @@ class InputSocket: The default value of the input. If not set, the input is mandatory. :param is_variadic: Whether the input is variadic or not. + :param is_greedy + Whether the input is a greedy variadic or not. :param senders: The list of components that send data to this input. """ @@ -45,6 +56,7 @@ class InputSocket: type: Type default_value: Any = _empty is_variadic: bool = field(init=False) + is_greedy: bool = field(init=False) senders: List[str] = field(default_factory=list) @property @@ -55,9 +67,14 @@ class InputSocket: def __post_init__(self): try: # __metadata__ is a tuple - self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION + self.is_variadic = self.type.__metadata__[0] in [ + HAYSTACK_VARIADIC_ANNOTATION, + HAYSTACK_GREEDY_VARIADIC_ANNOTATION, + ] + self.is_greedy = self.type.__metadata__[0] == HAYSTACK_GREEDY_VARIADIC_ANNOTATION except AttributeError: self.is_variadic = False + self.is_greedy = False if self.is_variadic: # We need to "unpack" the type inside the Variadic annotation, # otherwise the pipeline connection api will try to match diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 7044c9a36..9478e7dae 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -978,9 +978,8 @@ class PipelineBase: receiver = self.graph.nodes[receiver_name]["instance"] pair = (receiver_name, receiver) - is_greedy = getattr(receiver, "__haystack_is_greedy__", False) if receiver_socket.is_variadic: - if is_greedy: + if receiver_socket.is_greedy: # If the receiver is greedy, we can run it as soon as possible. # First we remove it from the status lists it's in if it's there or # we risk running it multiple times. @@ -1214,7 +1213,7 @@ def _connections_status( def _is_lazy_variadic(c: Component) -> bool: """ - Small utility function to check if a Component has a Variadic input that is not greedy + Small utility function to check if a Component has at least a Variadic input and no GreedyVariadic input. """ is_variadic = any( socket.is_variadic @@ -1222,7 +1221,10 @@ def _is_lazy_variadic(c: Component) -> bool: ) if not is_variadic: return False - return not getattr(c, "__haystack_is_greedy__", False) + return not any( + socket.is_greedy + for socket in c.__haystack_input__._sockets_dict.values() # type: ignore + ) def _has_all_inputs_with_defaults(c: Component) -> bool: diff --git a/releasenotes/notes/deprecate-greedy-argument-4b8c39572f5df25c.yaml b/releasenotes/notes/deprecate-greedy-argument-4b8c39572f5df25c.yaml new file mode 100644 index 000000000..c868a9a19 --- /dev/null +++ b/releasenotes/notes/deprecate-greedy-argument-4b8c39572f5df25c.yaml @@ -0,0 +1,13 @@ +--- +enhancements: + - | + Add new `GreedyVariadic` input type. This has a similar behaviour to `Variadic` input type + as it can be connected to multiple output sockets, though the Pipeline will run it as soon + as it receives an input without waiting for others. + This replaces the `is_greedy` argument in the `@component` decorator. + If you had a Component with a `Variadic` input type and `@component(is_greedy=True)` you need + to change the type to `GreedyVariadic` and remove `is_greedy=true` from `@component`. +deprecations: + - | + `@component` decorator `is_greedy` argument is deprecated and will be removed in version `2.7.0`. + Use `GreedyVariadic` type instead. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index 8b4266dbb..49d3a8bd5 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -430,57 +430,6 @@ def test_repr_added_to_pipeline(): assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int" -def test_is_greedy_default_with_variadic_input(): - @component - class MockComponent: - @component.output_types(value=int) - def run(self, value: Variadic[int]): - return {"value": value} - - assert not MockComponent.__haystack_is_greedy__ - assert not MockComponent().__haystack_is_greedy__ - - -def test_is_greedy_default_without_variadic_input(): - @component - class MockComponent: - @component.output_types(value=int) - def run(self, value: int): - return {"value": value} - - assert not MockComponent.__haystack_is_greedy__ - assert not MockComponent().__haystack_is_greedy__ - - -def test_is_greedy_flag_with_variadic_input(): - @component(is_greedy=True) - class MockComponent: - @component.output_types(value=int) - def run(self, value: Variadic[int]): - return {"value": value} - - assert MockComponent.__haystack_is_greedy__ - assert MockComponent().__haystack_is_greedy__ - - -def test_is_greedy_flag_without_variadic_input(caplog): - caplog.set_level(logging.WARNING) - - @component(is_greedy=True) - class MockComponent: - @component.output_types(value=int) - def run(self, value: int): - return {"value": value} - - assert MockComponent.__haystack_is_greedy__ - assert caplog.text == "" - assert MockComponent().__haystack_is_greedy__ - assert ( - "Component 'MockComponent' has no variadic input, but it's marked as greedy." - " This is not supported and can lead to unexpected behavior.\n" in caplog.text - ) - - def test_pre_init_hooking(): @component class MockComponent: diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index e33d45804..9e8c46efd 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -11,7 +11,7 @@ from haystack import Document from haystack.components.builders import PromptBuilder, AnswerBuilder from haystack.components.joiners import BranchJoiner from haystack.core.component import component -from haystack.core.component.types import InputSocket, OutputSocket, Variadic +from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError from haystack.core.pipeline import Pipeline, PredefinedPipeline from haystack.core.pipeline.base import ( @@ -20,6 +20,7 @@ from haystack.core.pipeline.base import ( _dequeue_component, _enqueue_waiting_component, _dequeue_waiting_component, + _is_lazy_variadic, ) from haystack.core.serialization import DeserializationCallbacks from haystack.testing.factory import component_class @@ -1628,3 +1629,19 @@ class TestPipeline: waiting_queue = [("document_builder", document_builder), ("document_joiner", document_joiner)] _dequeue_waiting_component(("document_builder", document_builder), waiting_queue) assert waiting_queue == [("document_joiner", document_joiner)] + + def test__is_lazy_variadic(self): + VariadicAndGreedyVariadic = component_class( + "VariadicAndGreedyVariadic", input_types={"variadic": Variadic[int], "greedy_variadic": GreedyVariadic[int]} + ) + NonVariadic = component_class("NonVariadic", input_types={"value": int}) + VariadicNonGreedyVariadic = component_class( + "VariadicNonGreedyVariadic", input_types={"variadic": Variadic[int]} + ) + NonVariadicAndGreedyVariadic = component_class( + "NonVariadicAndGreedyVariadic", input_types={"greedy_variadic": GreedyVariadic[int]} + ) + assert not _is_lazy_variadic(VariadicAndGreedyVariadic()) + assert not _is_lazy_variadic(NonVariadic()) + assert _is_lazy_variadic(VariadicNonGreedyVariadic()) + assert not _is_lazy_variadic(NonVariadicAndGreedyVariadic())