mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 13:06:29 +00:00
feat: Deprecate @component decorator is_greedy argument (#8400)
* Deprecate @component decorator is_greedy argument * Fix some typos and docstrings * Add _is_lazy_variadic test
This commit is contained in:
parent
2cc76beacd
commit
0df379e6a2
@ -5,13 +5,13 @@
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.core.component.types import Variadic
|
||||
from haystack.core.component.types import GreedyVariadic
|
||||
from haystack.utils import deserialize_type, serialize_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component(is_greedy=True)
|
||||
@component()
|
||||
class BranchJoiner:
|
||||
"""
|
||||
A component to join different branches of a pipeline into one single output.
|
||||
@ -100,7 +100,7 @@ class BranchJoiner:
|
||||
"""
|
||||
self.type_ = type_
|
||||
# type_'s type can't be determined statically
|
||||
component.set_input_types(self, value=Variadic[type_]) # type: ignore
|
||||
component.set_input_types(self, value=GreedyVariadic[type_]) # type: ignore
|
||||
component.set_output_types(self, value=type_)
|
||||
|
||||
def to_dict(self):
|
||||
|
||||
@ -71,6 +71,7 @@ method decorated with `@component.input`. This dataclass contains:
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
@ -292,17 +293,6 @@ class ComponentMeta(type):
|
||||
# We use this flag to check that.
|
||||
instance.__haystack_added_to_pipeline__ = None
|
||||
|
||||
# Only Components with variadic inputs can be greedy. If the user set the greedy flag
|
||||
# to True, but the component doesn't have a variadic input, we set it to False.
|
||||
# We can have this information only at instance creation time, so we do it here.
|
||||
is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
|
||||
if not is_variadic and cls.__haystack_is_greedy__:
|
||||
logger.warning(
|
||||
"Component '{component}' has no variadic input, but it's marked as greedy. "
|
||||
"This is not supported and can lead to unexpected behavior.",
|
||||
component=cls.__name__,
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
@ -497,12 +487,21 @@ class _Component:
|
||||
|
||||
return output_types_decorator
|
||||
|
||||
def _component(self, cls, is_greedy: bool = False):
|
||||
def _component(self, cls, is_greedy: Optional[bool] = None):
|
||||
"""
|
||||
Decorator validating the structure of the component and registering it in the components registry.
|
||||
"""
|
||||
logger.debug("Registering {component} as a component", component=cls)
|
||||
|
||||
if is_greedy is not None:
|
||||
msg = (
|
||||
"The 'is_greedy' argument is deprecated and will be removed in version '2.7.0'. "
|
||||
"Change the 'Variadic' input of your Component to 'GreedyVariadic' instead."
|
||||
)
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
else:
|
||||
is_greedy = False
|
||||
|
||||
# Check for required methods and fail as soon as possible
|
||||
if not hasattr(cls, "run"):
|
||||
raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
|
||||
@ -542,15 +541,9 @@ class _Component:
|
||||
# Override the __repr__ method with a default one
|
||||
cls.__repr__ = _component_repr
|
||||
|
||||
# The greedy flag can be True only if the component has a variadic input.
|
||||
# At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
|
||||
# So we set it to whatever the user specified, during the instance creation we'll change it if needed
|
||||
# since we'll have access to the input sockets and check if any of them is variadic.
|
||||
setattr(cls, "__haystack_is_greedy__", is_greedy)
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
|
||||
def __call__(self, cls: Optional[type] = None, is_greedy: Optional[bool] = None):
|
||||
# We must wrap the call to the decorator in a function for it to work
|
||||
# correctly with or without parens
|
||||
def wrap(cls):
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, Iterable, List, Type, TypeVar, get_args
|
||||
from typing_extensions import Annotated, TypeAlias # Python 3.8 compatibility
|
||||
|
||||
HAYSTACK_VARIADIC_ANNOTATION = "__haystack__variadic_t"
|
||||
HAYSTACK_GREEDY_VARIADIC_ANNOTATION = "__haystack__greedy_variadic_t"
|
||||
|
||||
# Generic type variable used in the Variadic container
|
||||
T = TypeVar("T")
|
||||
@ -16,9 +17,17 @@ T = TypeVar("T")
|
||||
# Variadic is a custom annotation type we use to mark input types.
|
||||
# This type doesn't do anything else than "marking" the contained
|
||||
# type so it can be used in the `InputSocket` creation where we
|
||||
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION
|
||||
# check that its annotation equals HAYSTACK_VARIADIC_ANNOTATION
|
||||
Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION]
|
||||
|
||||
# GreedyVariadic type is similar to Variadic.
|
||||
# The only difference is the way it's treated by the Pipeline when input is received
|
||||
# in a socket with this type.
|
||||
# Instead of waiting for other inputs to be received, Components that have a GreedyVariadic
|
||||
# input will be run right after receiving the first input.
|
||||
# Even if there are multiple connections to that socket.
|
||||
GreedyVariadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_GREEDY_VARIADIC_ANNOTATION]
|
||||
|
||||
|
||||
class _empty:
|
||||
"""Custom object for marking InputSocket.default_value as not set."""
|
||||
@ -37,6 +46,8 @@ class InputSocket:
|
||||
The default value of the input. If not set, the input is mandatory.
|
||||
:param is_variadic:
|
||||
Whether the input is variadic or not.
|
||||
:param is_greedy:
|
||||
Whether the input is a greedy variadic or not.
|
||||
:param senders:
|
||||
The list of components that send data to this input.
|
||||
"""
|
||||
@ -45,6 +56,7 @@ class InputSocket:
|
||||
type: Type
|
||||
default_value: Any = _empty
|
||||
is_variadic: bool = field(init=False)
|
||||
is_greedy: bool = field(init=False)
|
||||
senders: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
@ -55,9 +67,14 @@ class InputSocket:
|
||||
def __post_init__(self):
|
||||
try:
|
||||
# __metadata__ is a tuple
|
||||
self.is_variadic = self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION
|
||||
self.is_variadic = self.type.__metadata__[0] in [
|
||||
HAYSTACK_VARIADIC_ANNOTATION,
|
||||
HAYSTACK_GREEDY_VARIADIC_ANNOTATION,
|
||||
]
|
||||
self.is_greedy = self.type.__metadata__[0] == HAYSTACK_GREEDY_VARIADIC_ANNOTATION
|
||||
except AttributeError:
|
||||
self.is_variadic = False
|
||||
self.is_greedy = False
|
||||
if self.is_variadic:
|
||||
# We need to "unpack" the type inside the Variadic annotation,
|
||||
# otherwise the pipeline connection api will try to match
|
||||
|
||||
@ -978,9 +978,8 @@ class PipelineBase:
|
||||
receiver = self.graph.nodes[receiver_name]["instance"]
|
||||
pair = (receiver_name, receiver)
|
||||
|
||||
is_greedy = getattr(receiver, "__haystack_is_greedy__", False)
|
||||
if receiver_socket.is_variadic:
|
||||
if is_greedy:
|
||||
if receiver_socket.is_greedy:
|
||||
# If the receiver is greedy, we can run it as soon as possible.
|
||||
# First we remove it from the status lists it's in if it's there or
|
||||
# we risk running it multiple times.
|
||||
@ -1214,7 +1213,7 @@ def _connections_status(
|
||||
|
||||
def _is_lazy_variadic(c: Component) -> bool:
|
||||
"""
|
||||
Small utility function to check if a Component has a Variadic input that is not greedy
|
||||
Small utility function to check if a Component has at least one Variadic input and no GreedyVariadic input.
|
||||
"""
|
||||
is_variadic = any(
|
||||
socket.is_variadic
|
||||
@ -1222,7 +1221,10 @@ def _is_lazy_variadic(c: Component) -> bool:
|
||||
)
|
||||
if not is_variadic:
|
||||
return False
|
||||
return not getattr(c, "__haystack_is_greedy__", False)
|
||||
return not any(
|
||||
socket.is_greedy
|
||||
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _has_all_inputs_with_defaults(c: Component) -> bool:
|
||||
|
||||
@ -0,0 +1,13 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Add new `GreedyVariadic` input type. This has a similar behaviour to `Variadic` input type
|
||||
as it can be connected to multiple output sockets, though the Pipeline will run it as soon
|
||||
as it receives an input without waiting for others.
|
||||
This replaces the `is_greedy` argument in the `@component` decorator.
|
||||
If you had a Component with a `Variadic` input type and `@component(is_greedy=True)` you need
|
||||
to change the type to `GreedyVariadic` and remove `is_greedy=True` from `@component`.
|
||||
deprecations:
|
||||
- |
|
||||
`@component` decorator `is_greedy` argument is deprecated and will be removed in version `2.7.0`.
|
||||
Use `GreedyVariadic` type instead.
|
||||
@ -430,57 +430,6 @@ def test_repr_added_to_pipeline():
|
||||
assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n - value: int\nOutputs:\n - value: int"
|
||||
|
||||
|
||||
def test_is_greedy_default_with_variadic_input():
|
||||
@component
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: Variadic[int]):
|
||||
return {"value": value}
|
||||
|
||||
assert not MockComponent.__haystack_is_greedy__
|
||||
assert not MockComponent().__haystack_is_greedy__
|
||||
|
||||
|
||||
def test_is_greedy_default_without_variadic_input():
|
||||
@component
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: int):
|
||||
return {"value": value}
|
||||
|
||||
assert not MockComponent.__haystack_is_greedy__
|
||||
assert not MockComponent().__haystack_is_greedy__
|
||||
|
||||
|
||||
def test_is_greedy_flag_with_variadic_input():
|
||||
@component(is_greedy=True)
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: Variadic[int]):
|
||||
return {"value": value}
|
||||
|
||||
assert MockComponent.__haystack_is_greedy__
|
||||
assert MockComponent().__haystack_is_greedy__
|
||||
|
||||
|
||||
def test_is_greedy_flag_without_variadic_input(caplog):
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
@component(is_greedy=True)
|
||||
class MockComponent:
|
||||
@component.output_types(value=int)
|
||||
def run(self, value: int):
|
||||
return {"value": value}
|
||||
|
||||
assert MockComponent.__haystack_is_greedy__
|
||||
assert caplog.text == ""
|
||||
assert MockComponent().__haystack_is_greedy__
|
||||
assert (
|
||||
"Component 'MockComponent' has no variadic input, but it's marked as greedy."
|
||||
" This is not supported and can lead to unexpected behavior.\n" in caplog.text
|
||||
)
|
||||
|
||||
|
||||
def test_pre_init_hooking():
|
||||
@component
|
||||
class MockComponent:
|
||||
|
||||
@ -11,7 +11,7 @@ from haystack import Document
|
||||
from haystack.components.builders import PromptBuilder, AnswerBuilder
|
||||
from haystack.components.joiners import BranchJoiner
|
||||
from haystack.core.component import component
|
||||
from haystack.core.component.types import InputSocket, OutputSocket, Variadic
|
||||
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic
|
||||
from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError
|
||||
from haystack.core.pipeline import Pipeline, PredefinedPipeline
|
||||
from haystack.core.pipeline.base import (
|
||||
@ -20,6 +20,7 @@ from haystack.core.pipeline.base import (
|
||||
_dequeue_component,
|
||||
_enqueue_waiting_component,
|
||||
_dequeue_waiting_component,
|
||||
_is_lazy_variadic,
|
||||
)
|
||||
from haystack.core.serialization import DeserializationCallbacks
|
||||
from haystack.testing.factory import component_class
|
||||
@ -1628,3 +1629,19 @@ class TestPipeline:
|
||||
waiting_queue = [("document_builder", document_builder), ("document_joiner", document_joiner)]
|
||||
_dequeue_waiting_component(("document_builder", document_builder), waiting_queue)
|
||||
assert waiting_queue == [("document_joiner", document_joiner)]
|
||||
|
||||
def test__is_lazy_variadic(self):
|
||||
VariadicAndGreedyVariadic = component_class(
|
||||
"VariadicAndGreedyVariadic", input_types={"variadic": Variadic[int], "greedy_variadic": GreedyVariadic[int]}
|
||||
)
|
||||
NonVariadic = component_class("NonVariadic", input_types={"value": int})
|
||||
VariadicNonGreedyVariadic = component_class(
|
||||
"VariadicNonGreedyVariadic", input_types={"variadic": Variadic[int]}
|
||||
)
|
||||
NonVariadicAndGreedyVariadic = component_class(
|
||||
"NonVariadicAndGreedyVariadic", input_types={"greedy_variadic": GreedyVariadic[int]}
|
||||
)
|
||||
assert not _is_lazy_variadic(VariadicAndGreedyVariadic())
|
||||
assert not _is_lazy_variadic(NonVariadic())
|
||||
assert _is_lazy_variadic(VariadicNonGreedyVariadic())
|
||||
assert not _is_lazy_variadic(NonVariadicAndGreedyVariadic())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user