feat: Deprecate @component decorator is_greedy argument (#8400)

* Deprecate @component decorator is_greedy argument

* Fix some typos and docstrings

* Add _is_lazy_variadic test
This commit is contained in:
Silvano Cerza 2024-09-25 11:28:30 +02:00 committed by GitHub
parent 2cc76beacd
commit 0df379e6a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 71 additions and 80 deletions

View File

@ -5,13 +5,13 @@
from typing import Any, Dict, Type
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.types import Variadic
from haystack.core.component.types import GreedyVariadic
from haystack.utils import deserialize_type, serialize_type
logger = logging.getLogger(__name__)
@component(is_greedy=True)
@component()
class BranchJoiner:
"""
A component to join different branches of a pipeline into one single output.
@ -100,7 +100,7 @@ class BranchJoiner:
"""
self.type_ = type_
# type_'s type can't be determined statically
component.set_input_types(self, value=Variadic[type_]) # type: ignore
component.set_input_types(self, value=GreedyVariadic[type_]) # type: ignore
component.set_output_types(self, value=type_)
def to_dict(self):

View File

@ -71,6 +71,7 @@ method decorated with `@component.input`. This dataclass contains:
import inspect
import sys
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
@ -292,17 +293,6 @@ class ComponentMeta(type):
# We use this flag to check that.
instance.__haystack_added_to_pipeline__ = None
# Only Components with variadic inputs can be greedy. If the user set the greedy flag
# to True, but the component doesn't have a variadic input, we set it to False.
# We can have this information only at instance creation time, so we do it here.
is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
if not is_variadic and cls.__haystack_is_greedy__:
logger.warning(
"Component '{component}' has no variadic input, but it's marked as greedy. "
"This is not supported and can lead to unexpected behavior.",
component=cls.__name__,
)
return instance
@ -497,12 +487,21 @@ class _Component:
return output_types_decorator
def _component(self, cls, is_greedy: bool = False):
def _component(self, cls, is_greedy: Optional[bool] = None):
"""
Decorator validating the structure of the component and registering it in the components registry.
"""
logger.debug("Registering {component} as a component", component=cls)
if is_greedy is not None:
msg = (
"The 'is_greedy' argument is deprecated and will be removed in version '2.7.0'. "
"Change the 'Variadic' input of your Component to 'GreedyVariadic' instead."
)
warnings.warn(msg, DeprecationWarning)
else:
is_greedy = False
# Check for required methods and fail as soon as possible
if not hasattr(cls, "run"):
raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
@ -542,15 +541,9 @@ class _Component:
# Override the __repr__ method with a default one
cls.__repr__ = _component_repr
# The greedy flag can be True only if the component has a variadic input.
# At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
# So we set it to whatever the user specified, during the instance creation we'll change it if needed
# since we'll have access to the input sockets and check if any of them is variadic.
setattr(cls, "__haystack_is_greedy__", is_greedy)
return cls
def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
def __call__(self, cls: Optional[type] = None, is_greedy: Optional[bool] = None):
# We must wrap the call to the decorator in a function for it to work
# correctly with or without parens
def wrap(cls):

View File

@ -8,6 +8,7 @@ from typing import Any, Iterable, List, Type, TypeVar, get_args
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
HAYSTACK_VARIADIC_ANNOTATION = "__haystack__variadic_t"
HAYSTACK_GREEDY_VARIADIC_ANNOTATION = "__haystack__greedy_variadic_t"
# Generic type variable used in the Variadic and GreedyVariadic containers
T = TypeVar("T")
@ -16,9 +17,17 @@ T = TypeVar("T")
# Variadic is a custom annotation type we use to mark input types.
# This type doesn't do anything else than "marking" the contained
# type so it can be used in the `InputSocket` creation where we
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION
# check that its annotation equals HAYSTACK_VARIADIC_ANNOTATION
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
# GreedyVariadic type is similar to Variadic.
# The only difference is the way it's treated by the Pipeline when input is received
# in a socket with this type.
# Instead of waiting for other inputs to be received, Components that have a GreedyVariadic
# input will be run right after receiving the first input,
# even if there are multiple connections to that socket.
GreedyVariadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_GREEDY_VARIADIC_ANNOTATION]
class _empty:
"""Custom object for marking InputSocket.default_value as not set."""
@ -37,6 +46,8 @@ class InputSocket:
The default value of the input. If not set, the input is mandatory.
:param is_variadic:
Whether the input is variadic or not.
:param is_greedy:
Whether the input is a greedy variadic or not.
:param senders:
The list of components that send data to this input.
"""
@ -45,6 +56,7 @@ class InputSocket:
type: Type
default_value: Any = _empty
is_variadic: bool = field(init=False)
is_greedy: bool = field(init=False)
senders: List[str] = field(default_factory=list)
@property
@ -55,9 +67,14 @@ class InputSocket:
def __post_init__(self):
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
self.is_variadic = self.type.__metadata__[0] in [
HAYSTACK_VARIADIC_ANNOTATION,
HAYSTACK_GREEDY_VARIADIC_ANNOTATION,
]
self.is_greedy = self.type.__metadata__[0] == HAYSTACK_GREEDY_VARIADIC_ANNOTATION
except AttributeError:
self.is_variadic = False
self.is_greedy = False
if self.is_variadic:
# We need to "unpack" the type inside the Variadic annotation,
# otherwise the pipeline connection api will try to match

View File

@ -978,9 +978,8 @@ class PipelineBase:
receiver = self.graph.nodes[receiver_name]["instance"]
pair = (receiver_name, receiver)
is_greedy = getattr(receiver, "__haystack_is_greedy__", False)
if receiver_socket.is_variadic:
if is_greedy:
if receiver_socket.is_greedy:
# If the receiver is greedy, we can run it as soon as possible.
# First we remove it from the status lists it's in if it's there or
# we risk running it multiple times.
@ -1214,7 +1213,7 @@ def _connections_status(
def _is_lazy_variadic(c: Component) -> bool:
"""
Small utility function to check if a Component has a Variadic input that is not greedy
Small utility function to check if a Component has at least one Variadic input and no GreedyVariadic input.
"""
is_variadic = any(
socket.is_variadic
@ -1222,7 +1221,10 @@ def _is_lazy_variadic(c: Component) -> bool:
)
if not is_variadic:
return False
return not getattr(c, "__haystack_is_greedy__", False)
return not any(
socket.is_greedy
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
def _has_all_inputs_with_defaults(c: Component) -> bool:

View File

@ -0,0 +1,13 @@
---
enhancements:
- |
Add new `GreedyVariadic` input type. It behaves like the `Variadic` input type in that it can be
connected to multiple output sockets, but the Pipeline runs the Component as soon as it receives
an input, without waiting for the others.
This replaces the `is_greedy` argument in the `@component` decorator.
If you had a Component with a `Variadic` input type and `@component(is_greedy=True)` you need
to change the type to `GreedyVariadic` and remove `is_greedy=True` from `@component`.
deprecations:
- |
`@component` decorator `is_greedy` argument is deprecated and will be removed in version `2.7.0`.
Use `GreedyVariadic` type instead.

View File

@ -430,57 +430,6 @@ def test_repr_added_to_pipeline():
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
def test_is_greedy_default_with_variadic_input():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: Variadic[int]):
return {"value": value}
assert not MockComponent.__haystack_is_greedy__
assert not MockComponent().__haystack_is_greedy__
def test_is_greedy_default_without_variadic_input():
@component
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": value}
assert not MockComponent.__haystack_is_greedy__
assert not MockComponent().__haystack_is_greedy__
def test_is_greedy_flag_with_variadic_input():
@component(is_greedy=True)
class MockComponent:
@component.output_types(value=int)
def run(self, value: Variadic[int]):
return {"value": value}
assert MockComponent.__haystack_is_greedy__
assert MockComponent().__haystack_is_greedy__
def test_is_greedy_flag_without_variadic_input(caplog):
caplog.set_level(logging.WARNING)
@component(is_greedy=True)
class MockComponent:
@component.output_types(value=int)
def run(self, value: int):
return {"value": value}
assert MockComponent.__haystack_is_greedy__
assert caplog.text == ""
assert MockComponent().__haystack_is_greedy__
assert (
"Component 'MockComponent' has no variadic input, but it's marked as greedy."
" This is not supported and can lead to unexpected behavior.\n" in caplog.text
)
def test_pre_init_hooking():
@component
class MockComponent:

View File

@ -11,7 +11,7 @@ from haystack import Document
from haystack.components.builders import PromptBuilder, AnswerBuilder
from haystack.components.joiners import BranchJoiner
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket, Variadic
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic
from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError
from haystack.core.pipeline import Pipeline, PredefinedPipeline
from haystack.core.pipeline.base import (
@ -20,6 +20,7 @@ from haystack.core.pipeline.base import (
_dequeue_component,
_enqueue_waiting_component,
_dequeue_waiting_component,
_is_lazy_variadic,
)
from haystack.core.serialization import DeserializationCallbacks
from haystack.testing.factory import component_class
@ -1628,3 +1629,19 @@ class TestPipeline:
waiting_queue = [("document_builder", document_builder), ("document_joiner", document_joiner)]
_dequeue_waiting_component(("document_builder", document_builder), waiting_queue)
assert waiting_queue == [("document_joiner", document_joiner)]
def test__is_lazy_variadic(self):
    """
    Check `_is_lazy_variadic` against Components with different combinations of input types.

    Each case is (component class name, input types, expected result).
    """
    cases = [
        (
            "VariadicAndGreedyVariadic",
            {"variadic": Variadic[int], "greedy_variadic": GreedyVariadic[int]},
            False,
        ),
        ("NonVariadic", {"value": int}, False),
        ("VariadicNonGreedyVariadic", {"variadic": Variadic[int]}, True),
        ("NonVariadicAndGreedyVariadic", {"greedy_variadic": GreedyVariadic[int]}, False),
    ]

    for class_name, socket_types, expected in cases:
        component_type = component_class(class_name, input_types=socket_types)
        outcome = _is_lazy_variadic(component_type())
        if expected:
            assert outcome
        else:
            assert not outcome