mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-13 07:47:26 +00:00
feat: Add is_greedy argument in @component decorator (#7016)
* Add is_greedy argument in @component decorator * Log warning if Component is greedy and non variadic
This commit is contained in:
parent
5f97e08feb
commit
f1a6b2a78a
@ -15,9 +15,8 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
@component(is_greedy=True)
|
||||
class Multiplexer:
|
||||
is_greedy = True
|
||||
"""
|
||||
This component is used to distribute a single value to many components that may need it.
|
||||
It can take such value from different sources (the user's input, or another component), so
|
||||
|
||||
@ -72,7 +72,7 @@ import inspect
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from types import new_class
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from typing import Any, Optional, Protocol, runtime_checkable
|
||||
|
||||
from haystack.core.errors import ComponentError
|
||||
|
||||
@ -157,6 +157,17 @@ class ComponentMeta(type):
|
||||
# We use this flag to check that.
|
||||
instance.__haystack_added_to_pipeline__ = None
|
||||
|
||||
# Only Components with variadic inputs can be greedy. If the user set the greedy flag
|
||||
# to True, but the component doesn't have a variadic input, we log a warning (the flag itself is left unchanged).
|
||||
# We can have this information only at instance creation time, so we do it here.
|
||||
is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
|
||||
if not is_variadic and cls.__haystack_is_greedy__:
|
||||
logging.warning(
|
||||
"Component '%s' has no variadic input, but it's marked as greedy. "
|
||||
"This is not supported and can lead to unexpected behavior.",
|
||||
cls.__name__,
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
@ -307,15 +318,15 @@ class _Component:
|
||||
|
||||
return output_types_decorator
|
||||
|
||||
def _component(self, class_):
|
||||
def _component(self, cls, is_greedy: bool = False):
|
||||
"""
|
||||
Decorator validating the structure of the component and registering it in the components registry.
|
||||
"""
|
||||
logger.debug("Registering %s as a component", class_)
|
||||
logger.debug("Registering %s as a component", cls)
|
||||
|
||||
# Check for required methods and fail as soon as possible
|
||||
if not hasattr(class_, "run"):
|
||||
raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
|
||||
if not hasattr(cls, "run"):
|
||||
raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
|
||||
|
||||
def copy_class_namespace(namespace):
|
||||
"""
|
||||
@ -323,37 +334,54 @@ class _Component:
|
||||
to populate the newly created class. We just copy
|
||||
the whole namespace from the decorated class.
|
||||
"""
|
||||
for key, val in dict(class_.__dict__).items():
|
||||
for key, val in dict(cls.__dict__).items():
|
||||
# __dict__ and __weakref__ are class-bound, we should let Python recreate them.
|
||||
if key in ("__dict__", "__weakref__"):
|
||||
continue
|
||||
namespace[key] = val
|
||||
|
||||
# Recreate the decorated component class so it uses our metaclass
|
||||
class_: class_.__name__ = new_class(
|
||||
class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
|
||||
)
|
||||
# Recreate the decorated component class so it uses our metaclass.
|
||||
# We must explicitly redefine the type of the class to make sure language servers
|
||||
# and type checkers understand that the class is of the correct type.
|
||||
# mypy doesn't like that we do this though so we explicitly ignore the type check.
|
||||
cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) # type: ignore[no-redef]
|
||||
|
||||
# Save the component in the class registry (for deserialization)
|
||||
class_path = f"{class_.__module__}.{class_.__name__}"
|
||||
class_path = f"{cls.__module__}.{cls.__name__}"
|
||||
if class_path in self.registry:
|
||||
# Corner case, but it may occur easily in notebooks when re-running cells.
|
||||
logger.debug(
|
||||
"Component %s is already registered. Previous imported from '%s', new imported from '%s'",
|
||||
class_path,
|
||||
self.registry[class_path],
|
||||
class_,
|
||||
cls,
|
||||
)
|
||||
self.registry[class_path] = class_
|
||||
logger.debug("Registered Component %s", class_)
|
||||
self.registry[class_path] = cls
|
||||
logger.debug("Registered Component %s", cls)
|
||||
|
||||
# Override the __repr__ method with a default one
|
||||
class_.__repr__ = _component_repr
|
||||
cls.__repr__ = _component_repr
|
||||
|
||||
return class_
|
||||
# The greedy flag can be True only if the component has a variadic input.
|
||||
# At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
|
||||
# So we set it to whatever the user specified; during instance creation we'll log a warning if it's inconsistent,
|
||||
# since we'll have access to the input sockets and check if any of them is variadic.
|
||||
setattr(cls, "__haystack_is_greedy__", is_greedy)
|
||||
|
||||
def __call__(self, class_):
|
||||
return self._component(class_)
|
||||
return cls
|
||||
|
||||
def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
|
||||
# We must wrap the call to the decorator in a function for it to work
|
||||
# correctly with or without parens
|
||||
def wrap(cls):
|
||||
return self._component(cls, is_greedy=is_greedy)
|
||||
|
||||
if cls:
|
||||
# Decorator is called without parens
|
||||
return wrap(cls)
|
||||
|
||||
# Decorator is called with parens
|
||||
return wrap
|
||||
|
||||
|
||||
component = _Component()
|
||||
|
||||
@ -842,7 +842,7 @@ class Pipeline:
|
||||
# we're stuck for real and we can't make any progress.
|
||||
for name, comp in waiting_for_input:
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
||||
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
|
||||
break
|
||||
else:
|
||||
# We're stuck in a loop for real, we can't make any progress.
|
||||
@ -873,14 +873,17 @@ class Pipeline:
|
||||
|
||||
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
||||
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
|
||||
there_are_only_lazy_variadics = True
|
||||
for other_name, other_comp in waiting_for_input:
|
||||
if name == other_name:
|
||||
continue
|
||||
there_are_only_lazy_variadics &= any(
|
||||
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
) and not getattr(other_comp, "is_greedy", False)
|
||||
there_are_only_lazy_variadics &= (
|
||||
any(
|
||||
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
)
|
||||
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if not there_are_only_lazy_variadics:
|
||||
continue
|
||||
|
||||
12
releasenotes/notes/component-greedy-d6630af901e96a4c.yaml
Normal file
12
releasenotes/notes/component-greedy-d6630af901e96a4c.yaml
Normal file
@ -0,0 +1,12 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add `is_greedy` argument to `@component` decorator.
|
||||
This flag will change the behaviour of `Component`s with inputs that have a `Variadic` type
|
||||
when running inside a `Pipeline`.
|
||||
|
||||
Variadic `Component`s that are marked as greedy will run as soon as they receive their first input.
|
||||
If they are not marked as greedy, they'll instead wait as long as possible before running to make sure they
|
||||
receive as many inputs as possible from their senders.
|
||||
|
||||
It will be ignored for all other `Component`s even if set explicitly.
|
||||
@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.core.component import Component, InputSocket, OutputSocket, component
|
||||
from haystack.core.component.types import Variadic
|
||||
from haystack.core.errors import ComponentError
|
||||
from haystack.core.pipeline import Pipeline
|
||||
|
||||
@ -218,3 +220,55 @@ def test_repr_added_to_pipeline():
|
||||
comp = MockComponent()
|
||||
pipe.add_component("my_component", comp)
|
||||
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
|
||||
|
||||
|
||||
def test_is_greedy_default_with_variadic_input():
|
||||
@component
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: Variadic[int]):
|
||||
return {"value": value}
|
||||
|
||||
assert not MockComponent.__haystack_is_greedy__
|
||||
assert not MockComponent().__haystack_is_greedy__
|
||||
|
||||
|
||||
def test_is_greedy_default_without_variadic_input():
|
||||
@component
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: int):
|
||||
return {"value": value}
|
||||
|
||||
assert not MockComponent.__haystack_is_greedy__
|
||||
assert not MockComponent().__haystack_is_greedy__
|
||||
|
||||
|
||||
def test_is_greedy_flag_with_variadic_input():
|
||||
@component(is_greedy=True)
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: Variadic[int]):
|
||||
return {"value": value}
|
||||
|
||||
assert MockComponent.__haystack_is_greedy__
|
||||
assert MockComponent().__haystack_is_greedy__
|
||||
|
||||
|
||||
def test_is_greedy_flag_without_variadic_input(caplog):
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
@component(is_greedy=True)
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: int):
|
||||
return {"value": value}
|
||||
|
||||
assert MockComponent.__haystack_is_greedy__
|
||||
assert caplog.text == ""
|
||||
assert MockComponent().__haystack_is_greedy__
|
||||
assert (
|
||||
caplog.text
|
||||
== "WARNING root:component.py:165 Component 'MockComponent' has no variadic input, but it's marked as greedy."
|
||||
" This is not supported and can lead to unexpected behavior.\n"
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user