mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-17 01:58:23 +00:00
feat: Add is_greedy argument in @component decorator (#7016)
* Add is_greedy argument in @component decorator * Log warning if Component is greedy and non variadic
This commit is contained in:
parent
5f97e08feb
commit
f1a6b2a78a
@ -15,9 +15,8 @@ else:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@component
|
@component(is_greedy=True)
|
||||||
class Multiplexer:
|
class Multiplexer:
|
||||||
is_greedy = True
|
|
||||||
"""
|
"""
|
||||||
This component is used to distribute a single value to many components that may need it.
|
This component is used to distribute a single value to many components that may need it.
|
||||||
It can take such value from different sources (the user's input, or another component), so
|
It can take such value from different sources (the user's input, or another component), so
|
||||||
|
|||||||
@ -72,7 +72,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from types import new_class
|
from types import new_class
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Any, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from haystack.core.errors import ComponentError
|
from haystack.core.errors import ComponentError
|
||||||
|
|
||||||
@ -157,6 +157,17 @@ class ComponentMeta(type):
|
|||||||
# We use this flag to check that.
|
# We use this flag to check that.
|
||||||
instance.__haystack_added_to_pipeline__ = None
|
instance.__haystack_added_to_pipeline__ = None
|
||||||
|
|
||||||
|
# Only Components with variadic inputs can be greedy. If the user set the greedy flag
|
||||||
|
# to True, but the component doesn't have a variadic input, we log a warning (the flag itself is left unchanged).
|
||||||
|
# We can have this information only at instance creation time, so we do it here.
|
||||||
|
is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
|
||||||
|
if not is_variadic and cls.__haystack_is_greedy__:
|
||||||
|
logging.warning(
|
||||||
|
"Component '%s' has no variadic input, but it's marked as greedy. "
|
||||||
|
"This is not supported and can lead to unexpected behavior.",
|
||||||
|
cls.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
@ -307,15 +318,15 @@ class _Component:
|
|||||||
|
|
||||||
return output_types_decorator
|
return output_types_decorator
|
||||||
|
|
||||||
def _component(self, class_):
|
def _component(self, cls, is_greedy: bool = False):
|
||||||
"""
|
"""
|
||||||
Decorator validating the structure of the component and registering it in the components registry.
|
Decorator validating the structure of the component and registering it in the components registry.
|
||||||
"""
|
"""
|
||||||
logger.debug("Registering %s as a component", class_)
|
logger.debug("Registering %s as a component", cls)
|
||||||
|
|
||||||
# Check for required methods and fail as soon as possible
|
# Check for required methods and fail as soon as possible
|
||||||
if not hasattr(class_, "run"):
|
if not hasattr(cls, "run"):
|
||||||
raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
|
raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
|
||||||
|
|
||||||
def copy_class_namespace(namespace):
|
def copy_class_namespace(namespace):
|
||||||
"""
|
"""
|
||||||
@ -323,37 +334,54 @@ class _Component:
|
|||||||
to populate the newly created class. We just copy
|
to populate the newly created class. We just copy
|
||||||
the whole namespace from the decorated class.
|
the whole namespace from the decorated class.
|
||||||
"""
|
"""
|
||||||
for key, val in dict(class_.__dict__).items():
|
for key, val in dict(cls.__dict__).items():
|
||||||
# __dict__ and __weakref__ are class-bound, we should let Python recreate them.
|
# __dict__ and __weakref__ are class-bound, we should let Python recreate them.
|
||||||
if key in ("__dict__", "__weakref__"):
|
if key in ("__dict__", "__weakref__"):
|
||||||
continue
|
continue
|
||||||
namespace[key] = val
|
namespace[key] = val
|
||||||
|
|
||||||
# Recreate the decorated component class so it uses our metaclass
|
# Recreate the decorated component class so it uses our metaclass.
|
||||||
class_: class_.__name__ = new_class(
|
# We must explicitly redefine the type of the class to make sure language servers
|
||||||
class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
|
# and type checkers understand that the class is of the correct type.
|
||||||
)
|
# mypy doesn't like that we do this though so we explicitly ignore the type check.
|
||||||
|
cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) # type: ignore[no-redef]
|
||||||
|
|
||||||
# Save the component in the class registry (for deserialization)
|
# Save the component in the class registry (for deserialization)
|
||||||
class_path = f"{class_.__module__}.{class_.__name__}"
|
class_path = f"{cls.__module__}.{cls.__name__}"
|
||||||
if class_path in self.registry:
|
if class_path in self.registry:
|
||||||
# Corner case, but it may occur easily in notebooks when re-running cells.
|
# Corner case, but it may occur easily in notebooks when re-running cells.
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Component %s is already registered. Previous imported from '%s', new imported from '%s'",
|
"Component %s is already registered. Previous imported from '%s', new imported from '%s'",
|
||||||
class_path,
|
class_path,
|
||||||
self.registry[class_path],
|
self.registry[class_path],
|
||||||
class_,
|
cls,
|
||||||
)
|
)
|
||||||
self.registry[class_path] = class_
|
self.registry[class_path] = cls
|
||||||
logger.debug("Registered Component %s", class_)
|
logger.debug("Registered Component %s", cls)
|
||||||
|
|
||||||
# Override the __repr__ method with a default one
|
# Override the __repr__ method with a default one
|
||||||
class_.__repr__ = _component_repr
|
cls.__repr__ = _component_repr
|
||||||
|
|
||||||
return class_
|
# The greedy flag can be True only if the component has a variadic input.
|
||||||
|
# At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
|
||||||
|
# So we set it to whatever the user specified; during instance creation we'll log a warning if needed,
|
||||||
|
# since we'll have access to the input sockets and check if any of them is variadic.
|
||||||
|
setattr(cls, "__haystack_is_greedy__", is_greedy)
|
||||||
|
|
||||||
def __call__(self, class_):
|
return cls
|
||||||
return self._component(class_)
|
|
||||||
|
def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
|
||||||
|
# We must wrap the call to the decorator in a function for it to work
|
||||||
|
# correctly with or without parens
|
||||||
|
def wrap(cls):
|
||||||
|
return self._component(cls, is_greedy=is_greedy)
|
||||||
|
|
||||||
|
if cls:
|
||||||
|
# Decorator is called without parens
|
||||||
|
return wrap(cls)
|
||||||
|
|
||||||
|
# Decorator is called with parens
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
component = _Component()
|
component = _Component()
|
||||||
|
|||||||
@ -842,7 +842,7 @@ class Pipeline:
|
|||||||
# we're stuck for real and we can't make any progress.
|
# we're stuck for real and we can't make any progress.
|
||||||
for name, comp in waiting_for_input:
|
for name, comp in waiting_for_input:
|
||||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# We're stuck in a loop for real, we can't make any progress.
|
# We're stuck in a loop for real, we can't make any progress.
|
||||||
@ -873,14 +873,17 @@ class Pipeline:
|
|||||||
|
|
||||||
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
||||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||||
if is_variadic and not getattr(comp, "is_greedy", False):
|
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
|
||||||
there_are_only_lazy_variadics = True
|
there_are_only_lazy_variadics = True
|
||||||
for other_name, other_comp in waiting_for_input:
|
for other_name, other_comp in waiting_for_input:
|
||||||
if name == other_name:
|
if name == other_name:
|
||||||
continue
|
continue
|
||||||
there_are_only_lazy_variadics &= any(
|
there_are_only_lazy_variadics &= (
|
||||||
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
any(
|
||||||
) and not getattr(other_comp, "is_greedy", False)
|
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||||
|
)
|
||||||
|
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
|
||||||
if not there_are_only_lazy_variadics:
|
if not there_are_only_lazy_variadics:
|
||||||
continue
|
continue
|
||||||
|
|||||||
12
releasenotes/notes/component-greedy-d6630af901e96a4c.yaml
Normal file
12
releasenotes/notes/component-greedy-d6630af901e96a4c.yaml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Add `is_greedy` argument to `@component` decorator.
|
||||||
|
This flag will change the behaviour of `Component`s with inputs that have a `Variadic` type
|
||||||
|
when running inside a `Pipeline`.
|
||||||
|
|
||||||
|
Variadic `Component`s that are marked as greedy will run as soon as they receive their first input.
|
||||||
|
If not marked as greedy, they'll instead wait as long as possible before running to make sure they
|
||||||
|
receive as many inputs as possible from their senders.
|
||||||
|
|
||||||
|
The flag is ignored for all other `Component`s even if set explicitly; in that case a warning is logged when the `Component` is instantiated.
|
||||||
@ -1,8 +1,10 @@
|
|||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.core.component import Component, InputSocket, OutputSocket, component
|
from haystack.core.component import Component, InputSocket, OutputSocket, component
|
||||||
|
from haystack.core.component.types import Variadic
|
||||||
from haystack.core.errors import ComponentError
|
from haystack.core.errors import ComponentError
|
||||||
from haystack.core.pipeline import Pipeline
|
from haystack.core.pipeline import Pipeline
|
||||||
|
|
||||||
@ -218,3 +220,55 @@ def test_repr_added_to_pipeline():
|
|||||||
comp = MockComponent()
|
comp = MockComponent()
|
||||||
pipe.add_component("my_component", comp)
|
pipe.add_component("my_component", comp)
|
||||||
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
|
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_greedy_default_with_variadic_input():
|
||||||
|
@component
|
||||||
|
class MockComponent:
|
||||||
|
@component.output_types(value=int)
|
||||||
|
def run(self, value: Variadic[int]):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
assert not MockComponent.__haystack_is_greedy__
|
||||||
|
assert not MockComponent().__haystack_is_greedy__
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_greedy_default_without_variadic_input():
|
||||||
|
@component
|
||||||
|
class MockComponent:
|
||||||
|
@component.output_types(value=int)
|
||||||
|
def run(self, value: int):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
assert not MockComponent.__haystack_is_greedy__
|
||||||
|
assert not MockComponent().__haystack_is_greedy__
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_greedy_flag_with_variadic_input():
|
||||||
|
@component(is_greedy=True)
|
||||||
|
class MockComponent:
|
||||||
|
@component.output_types(value=int)
|
||||||
|
def run(self, value: Variadic[int]):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
assert MockComponent.__haystack_is_greedy__
|
||||||
|
assert MockComponent().__haystack_is_greedy__
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_greedy_flag_without_variadic_input(caplog):
|
||||||
|
caplog.set_level(logging.WARNING)
|
||||||
|
|
||||||
|
@component(is_greedy=True)
|
||||||
|
class MockComponent:
|
||||||
|
@component.output_types(value=int)
|
||||||
|
def run(self, value: int):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
assert MockComponent.__haystack_is_greedy__
|
||||||
|
assert caplog.text == ""
|
||||||
|
assert MockComponent().__haystack_is_greedy__
|
||||||
|
assert (
|
||||||
|
caplog.text
|
||||||
|
== "WARNING root:component.py:165 Component 'MockComponent' has no variadic input, but it's marked as greedy."
|
||||||
|
" This is not supported and can lead to unexpected behavior.\n"
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user