mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-16 09:38:07 +00:00
feat: Add ConditionalRouter Haystack 2.x component (#6147)
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
70e40eae5c
commit
b557f3035e
@ -1,7 +1,7 @@
|
|||||||
from haystack.preview.components.routers.document_joiner import DocumentJoiner
|
from haystack.preview.components.routers.document_joiner import DocumentJoiner
|
||||||
from haystack.preview.components.routers.file_type_router import FileTypeRouter
|
from haystack.preview.components.routers.file_type_router import FileTypeRouter
|
||||||
from haystack.preview.components.routers.metadata_router import MetadataRouter
|
from haystack.preview.components.routers.metadata_router import MetadataRouter
|
||||||
|
from haystack.preview.components.routers.conditional_router import ConditionalRouter
|
||||||
from haystack.preview.components.routers.text_language_router import TextLanguageRouter
|
from haystack.preview.components.routers.text_language_router import TextLanguageRouter
|
||||||
|
|
||||||
|
__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"]
|
||||||
__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter"]
|
|
||||||
|
|||||||
347
haystack/preview/components/routers/conditional_router.py
Normal file
347
haystack/preview/components/routers/conditional_router.py
Normal file
@ -0,0 +1,347 @@
|
|||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import List, Dict, Any, Set, get_origin
|
||||||
|
|
||||||
|
from jinja2 import meta, Environment, TemplateSyntaxError
|
||||||
|
from jinja2.nativetypes import NativeEnvironment
|
||||||
|
|
||||||
|
from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NoRouteSelectedException(Exception):
|
||||||
|
"""Exception raised when no route is selected in ConditionalRouter."""
|
||||||
|
|
||||||
|
|
||||||
|
class RouteConditionException(Exception):
|
||||||
|
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_type(target: Any) -> str:
|
||||||
|
"""
|
||||||
|
Serializes a type or an instance to its string representation, including the module name.
|
||||||
|
|
||||||
|
This function handles types, instances of types, and special typing objects.
|
||||||
|
It assumes that non-typing objects will have a '__name__' attribute and raises
|
||||||
|
an error if a type cannot be serialized.
|
||||||
|
|
||||||
|
:param target: The object to serialize, can be an instance or a type.
|
||||||
|
:type target: Any
|
||||||
|
:return: The string representation of the type.
|
||||||
|
:raises ValueError: If the type cannot be serialized.
|
||||||
|
"""
|
||||||
|
# If the target is a string and contains a dot, treat it as an already serialized type
|
||||||
|
if isinstance(target, str) and "." in target:
|
||||||
|
return target
|
||||||
|
|
||||||
|
# Determine if the target is a type or an instance of a typing object
|
||||||
|
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
|
||||||
|
type_obj = target if is_type_or_typing else type(target)
|
||||||
|
module = inspect.getmodule(type_obj)
|
||||||
|
type_obj_repr = repr(type_obj)
|
||||||
|
|
||||||
|
if type_obj_repr.startswith("typing."):
|
||||||
|
# e.g., typing.List[int] -> List[int], we'll add the module below
|
||||||
|
type_name = type_obj_repr.split(".", 1)[1]
|
||||||
|
elif hasattr(type_obj, "__name__"):
|
||||||
|
type_name = type_obj.__name__
|
||||||
|
else:
|
||||||
|
# If type cannot be serialized, raise an error
|
||||||
|
raise ValueError(f"Could not serialize type: {type_obj_repr}")
|
||||||
|
|
||||||
|
# Construct the full path with module name if available
|
||||||
|
if module and hasattr(module, "__name__"):
|
||||||
|
if module.__name__ == "builtins":
|
||||||
|
# omit the module name for builtins, it just clutters the output
|
||||||
|
# e.g. instead of 'builtins.str', we'll just return 'str'
|
||||||
|
full_path = type_name
|
||||||
|
else:
|
||||||
|
full_path = f"{module.__name__}.{type_name}"
|
||||||
|
else:
|
||||||
|
full_path = type_name
|
||||||
|
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_type(type_str: str) -> Any:
|
||||||
|
"""
|
||||||
|
Deserializes a type given its full import path as a string, including nested generic types.
|
||||||
|
|
||||||
|
This function will dynamically import the module if it's not already imported
|
||||||
|
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
|
||||||
|
|
||||||
|
:param type_str: The string representation of the type's full import path.
|
||||||
|
:return: The deserialized type object.
|
||||||
|
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_generic_args(args_str):
|
||||||
|
args = []
|
||||||
|
bracket_count = 0
|
||||||
|
current_arg = ""
|
||||||
|
|
||||||
|
for char in args_str:
|
||||||
|
if char == "[":
|
||||||
|
bracket_count += 1
|
||||||
|
elif char == "]":
|
||||||
|
bracket_count -= 1
|
||||||
|
|
||||||
|
if char == "," and bracket_count == 0:
|
||||||
|
args.append(current_arg.strip())
|
||||||
|
current_arg = ""
|
||||||
|
else:
|
||||||
|
current_arg += char
|
||||||
|
|
||||||
|
if current_arg:
|
||||||
|
args.append(current_arg.strip())
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
if "[" in type_str and type_str.endswith("]"):
|
||||||
|
# Handle generics
|
||||||
|
main_type_str, generics_str = type_str.split("[", 1)
|
||||||
|
generics_str = generics_str[:-1]
|
||||||
|
|
||||||
|
main_type = deserialize_type(main_type_str)
|
||||||
|
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
|
||||||
|
|
||||||
|
# Reconstruct
|
||||||
|
return main_type[generic_args]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Handle non-generics
|
||||||
|
parts = type_str.split(".")
|
||||||
|
module_name = ".".join(parts[:-1]) or "builtins"
|
||||||
|
type_name = parts[-1]
|
||||||
|
|
||||||
|
module = sys.modules.get(module_name)
|
||||||
|
if not module:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
except ImportError as e:
|
||||||
|
raise DeserializationError(f"Could not import the module: {module_name}") from e
|
||||||
|
|
||||||
|
deserialized_type = getattr(module, type_name, None)
|
||||||
|
if not deserialized_type:
|
||||||
|
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
|
||||||
|
|
||||||
|
return deserialized_type
|
||||||
|
|
||||||
|
|
||||||
|
@component
|
||||||
|
class ConditionalRouter:
|
||||||
|
"""
|
||||||
|
ConditionalRouter in Haystack 2.x pipelines is designed to manage data routing based on specific conditions.
|
||||||
|
This is achieved by defining a list named 'routes'. Each element in this list is a dictionary representing a
|
||||||
|
single route.
|
||||||
|
|
||||||
|
A route dictionary comprises four key elements:
|
||||||
|
- 'condition': A Jinja2 string expression that determines if the route is selected.
|
||||||
|
- 'output': A Jinja2 expression defining the route's output value.
|
||||||
|
- 'output_type': The type of the output data (e.g., str, List[int]).
|
||||||
|
- 'output_name': The name under which the `output` value of the route is published. This name is used to connect
|
||||||
|
the router to other components in the pipeline.
|
||||||
|
|
||||||
|
Here's an example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from haystack.preview.components.routers import ConditionalRouter
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length > 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_name": "enough_streams",
|
||||||
|
"output_type": List[int],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length <= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_name": "insufficient_streams",
|
||||||
|
"output_type": List[int],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
# When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
|
||||||
|
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
|
||||||
|
result = router.run(**kwargs)
|
||||||
|
assert result == {"enough_streams": [1, 2, 3]}
|
||||||
|
```
|
||||||
|
|
||||||
|
In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the
|
||||||
|
stream count exceeds two. Conversely, the second route directs 'streams' to 'insufficient_streams' when there
|
||||||
|
are two or fewer streams.
|
||||||
|
|
||||||
|
In the pipeline setup, the router is connected to other components using the output names. For example, the
|
||||||
|
'enough_streams' output might be connected to another component that processes the streams, while the
|
||||||
|
'insufficient_streams' output might be connected to a component that fetches more streams, and so on.
|
||||||
|
|
||||||
|
Here is a pseudocode example of a pipeline that uses the ConditionalRouter and routes fetched ByteStreams to
|
||||||
|
different components depending on the number of streams fetched:
|
||||||
|
|
||||||
|
```
|
||||||
|
from typing import List
|
||||||
|
from haystack import Pipeline
|
||||||
|
from haystack.preview.dataclasses import ByteStream
|
||||||
|
from haystack.preview.components.routers import ConditionalRouter
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length > 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_name": "enough_streams",
|
||||||
|
"output_type": List[ByteStream],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length <= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_name": "insufficient_streams",
|
||||||
|
"output_type": List[ByteStream],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_component("router", router)
|
||||||
|
...
|
||||||
|
pipe.connect("router.enough_streams", "some_component_a.streams")
|
||||||
|
pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
|
||||||
|
...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, routes: List[Dict]):
|
||||||
|
"""
|
||||||
|
Initializes the ConditionalRouter with a list of routes detailing the conditions for routing.
|
||||||
|
|
||||||
|
:param routes: A list of dictionaries, each defining a route with a boolean condition expression
|
||||||
|
('condition'), an output value ('output'), the output type ('output_type') and
|
||||||
|
('output_name') that defines the output name for the variable defined in 'output'.
|
||||||
|
"""
|
||||||
|
self._validate_routes(routes)
|
||||||
|
self.routes: List[dict] = routes
|
||||||
|
|
||||||
|
# Create a Jinja native environment to inspect variables in the condition templates
|
||||||
|
env = NativeEnvironment()
|
||||||
|
|
||||||
|
# Inspect the routes to determine input and output types.
|
||||||
|
input_types: Set[str] = set() # let's just store the name, type will always be Any
|
||||||
|
output_types: Dict[str, str] = {}
|
||||||
|
|
||||||
|
for route in routes:
|
||||||
|
# extract inputs
|
||||||
|
route_input_names = self._extract_variables(env, [route["output"], route["condition"]])
|
||||||
|
input_types.update(route_input_names)
|
||||||
|
|
||||||
|
# extract outputs
|
||||||
|
output_types.update({route["output_name"]: route["output_type"]})
|
||||||
|
|
||||||
|
component.set_input_types(self, **{var: Any for var in input_types})
|
||||||
|
component.set_output_types(self, **output_types)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
for route in self.routes:
|
||||||
|
# output_type needs to be serialized to a string
|
||||||
|
route["output_type"] = serialize_type(route["output_type"])
|
||||||
|
|
||||||
|
return default_to_dict(self, routes=self.routes)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
|
||||||
|
init_params = data.get("init_parameters", {})
|
||||||
|
routes = init_params.get("routes")
|
||||||
|
for route in routes:
|
||||||
|
# output_type needs to be deserialized from a string to a type
|
||||||
|
route["output_type"] = deserialize_type(route["output_type"])
|
||||||
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
def run(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Executes the routing logic by evaluating the specified boolean condition expressions
|
||||||
|
for each route in the order they are listed. The method directs the flow
|
||||||
|
of data to the output specified in the first route, whose expression
|
||||||
|
evaluates to True. If no route's expression evaluates to True, an exception
|
||||||
|
is raised.
|
||||||
|
|
||||||
|
:param kwargs: A dictionary containing the pipeline variables, which should
|
||||||
|
include all variables used in the "condition" templates.
|
||||||
|
|
||||||
|
:return: A dictionary containing the output and the corresponding result,
|
||||||
|
based on the first route whose expression evaluates to True.
|
||||||
|
|
||||||
|
:raises NoRouteSelectedException: If no route's expression evaluates to True.
|
||||||
|
"""
|
||||||
|
# Create a Jinja native environment to evaluate the condition templates as Python expressions
|
||||||
|
env = NativeEnvironment()
|
||||||
|
|
||||||
|
for route in self.routes:
|
||||||
|
try:
|
||||||
|
t = env.from_string(route["condition"])
|
||||||
|
if t.render(**kwargs):
|
||||||
|
# We now evaluate the `output` expression to determine the route output
|
||||||
|
t_output = env.from_string(route["output"])
|
||||||
|
output = t_output.render(**kwargs)
|
||||||
|
# and return the output as a dictionary under the output_name key
|
||||||
|
return {route["output_name"]: output}
|
||||||
|
except Exception as e:
|
||||||
|
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e
|
||||||
|
|
||||||
|
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
|
||||||
|
|
||||||
|
def _validate_routes(self, routes: List[Dict]):
|
||||||
|
"""
|
||||||
|
Validates a list of routes.
|
||||||
|
|
||||||
|
:param routes: A list of routes.
|
||||||
|
:type routes: List[Dict]
|
||||||
|
"""
|
||||||
|
env = NativeEnvironment()
|
||||||
|
for route in routes:
|
||||||
|
try:
|
||||||
|
keys = set(route.keys())
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(f"Route must be a dictionary, got: {route}")
|
||||||
|
|
||||||
|
mandatory_fields = {"condition", "output", "output_type", "output_name"}
|
||||||
|
has_all_mandatory_fields = mandatory_fields.issubset(keys)
|
||||||
|
if not has_all_mandatory_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
|
||||||
|
)
|
||||||
|
for field in ["condition", "output"]:
|
||||||
|
if not self._validate_template(env, route[field]):
|
||||||
|
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
|
||||||
|
|
||||||
|
def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Extracts all variables from a list of Jinja template strings.
|
||||||
|
|
||||||
|
:param env: A Jinja environment.
|
||||||
|
:type env: Environment
|
||||||
|
:param templates: A list of Jinja template strings.
|
||||||
|
:type templates: List[str]
|
||||||
|
:return: A set of variable names.
|
||||||
|
"""
|
||||||
|
variables = set()
|
||||||
|
for template in templates:
|
||||||
|
ast = env.parse(template)
|
||||||
|
variables.update(meta.find_undeclared_variables(ast))
|
||||||
|
return variables
|
||||||
|
|
||||||
|
def _validate_template(self, env: Environment, template_text: str):
|
||||||
|
"""
|
||||||
|
Validates a template string by parsing it with Jinja.
|
||||||
|
|
||||||
|
:param env: A Jinja environment.
|
||||||
|
:type env: Environment
|
||||||
|
:param template_text: A Jinja template string.
|
||||||
|
:type template_text: str
|
||||||
|
:return: True if the template is valid, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
env.parse(template_text)
|
||||||
|
return True
|
||||||
|
except TemplateSyntaxError:
|
||||||
|
return False
|
||||||
6
releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml
Normal file
6
releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
preview:
|
||||||
|
- |
|
||||||
|
Add `ConditionalRouter` component to enhance the conditional pipeline routing capabilities.
|
||||||
|
The `ConditionalRouter` component orchestrates the flow of data by evaluating specified route conditions
|
||||||
|
to determine the appropriate route among a set of provided route alternatives.
|
||||||
324
test/preview/components/routers/test_conditional_router.py
Normal file
324
test/preview/components/routers/test_conditional_router.py
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
import copy
|
||||||
|
import typing
|
||||||
|
from typing import List, Dict
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.preview.components.routers import ConditionalRouter
|
||||||
|
from haystack.preview.components.routers.conditional_router import (
|
||||||
|
NoRouteSelectedException,
|
||||||
|
serialize_type,
|
||||||
|
deserialize_type,
|
||||||
|
)
|
||||||
|
from haystack.preview.dataclasses import ChatMessage
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouter:
|
||||||
|
@pytest.fixture
|
||||||
|
def routes(self):
|
||||||
|
return [
|
||||||
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length >= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_type": List[int],
|
||||||
|
"output_name": "streams",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router(self, routes):
|
||||||
|
return ConditionalRouter(routes)
|
||||||
|
|
||||||
|
def test_missing_mandatory_fields(self):
|
||||||
|
"""
|
||||||
|
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys
|
||||||
|
"""
|
||||||
|
routes = [
|
||||||
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}"},
|
||||||
|
{"condition": "{{streams|length < 2}}", "output_type": str},
|
||||||
|
]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConditionalRouter(routes)
|
||||||
|
|
||||||
|
def test_invalid_condition_field(self):
|
||||||
|
"""
|
||||||
|
ConditionalRouter init raises a ValueError if one of the routes contains invalid condition
|
||||||
|
"""
|
||||||
|
# invalid condition field
|
||||||
|
routes = [{"condition": "{{streams|length < 2", "output": "query", "output_type": str, "output_name": "test"}]
|
||||||
|
with pytest.raises(ValueError, match="Invalid template"):
|
||||||
|
ConditionalRouter(routes)
|
||||||
|
|
||||||
|
def test_no_vars_in_output_route_but_with_output_name(self):
|
||||||
|
"""
|
||||||
|
Router can't accept a route with no variables used in the output field
|
||||||
|
"""
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length > 2}}",
|
||||||
|
"output": "This is a constant",
|
||||||
|
"output_name": "enough_streams",
|
||||||
|
"output_type": str,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
|
||||||
|
result = router.run(**kwargs)
|
||||||
|
assert result == {"enough_streams": "This is a constant"}
|
||||||
|
|
||||||
|
def test_mandatory_and_optional_fields_with_extra_fields(self):
|
||||||
|
"""
|
||||||
|
Router accepts a list of routes with mandatory and optional fields but not if some new field is added
|
||||||
|
"""
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length < 2}}",
|
||||||
|
"output": "{{query}}",
|
||||||
|
"output_type": str,
|
||||||
|
"output_name": "test",
|
||||||
|
"bla": "bla",
|
||||||
|
},
|
||||||
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str},
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConditionalRouter(routes)
|
||||||
|
|
||||||
|
def test_router_initialized(self, routes):
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
|
||||||
|
assert router.routes == routes
|
||||||
|
assert set(router.__canals_input__.keys()) == {"query", "streams"}
|
||||||
|
assert set(router.__canals_output__.keys()) == {"query", "streams"}
|
||||||
|
|
||||||
|
def test_router_evaluate_condition_expressions(self, router):
|
||||||
|
# first route should be selected
|
||||||
|
kwargs = {"streams": [1, 2, 3], "query": "test"}
|
||||||
|
result = router.run(**kwargs)
|
||||||
|
assert result == {"streams": [1, 2, 3]}
|
||||||
|
|
||||||
|
# second route should be selected
|
||||||
|
kwargs = {"streams": [1], "query": "test"}
|
||||||
|
result = router.run(**kwargs)
|
||||||
|
assert result == {"query": "test"}
|
||||||
|
|
||||||
|
def test_router_evaluate_condition_expressions_using_output_slot(self):
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length > 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_name": "enough_streams",
|
||||||
|
"output_type": List[int],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length <= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_name": "insufficient_streams",
|
||||||
|
"output_type": List[int],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
# enough_streams output slot will be selected with [1, 2, 3] list being outputted
|
||||||
|
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
|
||||||
|
result = router.run(**kwargs)
|
||||||
|
assert result == {"enough_streams": [1, 2, 3]}
|
||||||
|
|
||||||
|
def test_complex_condition(self):
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{messages[-1].metadata.finish_reason == 'function_call'}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_type": List[int],
|
||||||
|
"output_name": "streams",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{True}}",
|
||||||
|
"output": "{{query}}",
|
||||||
|
"output_type": str,
|
||||||
|
"output_name": "query",
|
||||||
|
}, # catch-all condition
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
message = mock.MagicMock()
|
||||||
|
message.metadata.finish_reason = "function_call"
|
||||||
|
result = router.run(messages=[message], streams=[1, 2, 3], query="my query")
|
||||||
|
assert result == {"streams": [1, 2, 3]}
|
||||||
|
|
||||||
|
def test_router_no_route(self, router):
|
||||||
|
# should raise an exception
|
||||||
|
router = ConditionalRouter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length < 2}}",
|
||||||
|
"output": "{{query}}",
|
||||||
|
"output_type": str,
|
||||||
|
"output_name": "query",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length >= 5}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_type": List[int],
|
||||||
|
"output_name": "streams",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {"streams": [1, 2, 3], "query": "test"}
|
||||||
|
with pytest.raises(NoRouteSelectedException):
|
||||||
|
router.run(**kwargs)
|
||||||
|
|
||||||
|
def test_router_raises_value_error_if_route_not_dictionary(self):
|
||||||
|
"""
|
||||||
|
Router raises a ValueError if each route is not a dictionary
|
||||||
|
"""
|
||||||
|
routes = [
|
||||||
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||||
|
["{{streams|length >= 2}}", "streams", List[int]],
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConditionalRouter(routes)
|
||||||
|
|
||||||
|
def test_router_raises_value_error_if_route_missing_keys(self):
|
||||||
|
"""
|
||||||
|
Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys
|
||||||
|
"""
|
||||||
|
routes = [
|
||||||
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}"},
|
||||||
|
{"condition": "{{streams|length < 2}}", "output_type": str},
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ConditionalRouter(routes)
|
||||||
|
|
||||||
|
def test_output_type_serialization(self):
|
||||||
|
assert serialize_type(str) == "str"
|
||||||
|
assert serialize_type(List[int]) == "typing.List[int]"
|
||||||
|
assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||||
|
assert serialize_type(ChatMessage) == "haystack.preview.dataclasses.chat_message.ChatMessage"
|
||||||
|
assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||||
|
assert serialize_type(List[ChatMessage]) == "typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]"
|
||||||
|
assert (
|
||||||
|
serialize_type(typing.Dict[int, ChatMessage])
|
||||||
|
== "typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]"
|
||||||
|
)
|
||||||
|
assert serialize_type(int) == "int"
|
||||||
|
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.preview.dataclasses.chat_message.ChatMessage"
|
||||||
|
|
||||||
|
def test_output_type_deserialization(self):
|
||||||
|
assert deserialize_type("str") == str
|
||||||
|
assert deserialize_type("typing.List[int]") == typing.List[int]
|
||||||
|
assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]]
|
||||||
|
assert deserialize_type("typing.Dict[str, int]") == Dict[str, int]
|
||||||
|
assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]
|
||||||
|
assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]]
|
||||||
|
assert (
|
||||||
|
deserialize_type("typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]")
|
||||||
|
== typing.List[ChatMessage]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
deserialize_type("typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]")
|
||||||
|
== typing.Dict[int, ChatMessage]
|
||||||
|
)
|
||||||
|
assert deserialize_type("haystack.preview.dataclasses.chat_message.ChatMessage") == ChatMessage
|
||||||
|
assert deserialize_type("int") == int
|
||||||
|
|
||||||
|
def test_router_de_serialization(self):
|
||||||
|
routes = [
|
||||||
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length >= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_type": List[int],
|
||||||
|
"output_name": "streams",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
router_dict = router.to_dict()
|
||||||
|
|
||||||
|
# assert that the router dict is correct, with all keys and values being strings
|
||||||
|
for route in router_dict["init_parameters"]["routes"]:
|
||||||
|
for key in route.keys():
|
||||||
|
assert isinstance(key, str)
|
||||||
|
assert isinstance(route[key], str)
|
||||||
|
|
||||||
|
new_router = ConditionalRouter.from_dict(router_dict)
|
||||||
|
assert router.routes == new_router.routes
|
||||||
|
|
||||||
|
# now use both routers with the same input
|
||||||
|
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
|
||||||
|
result1 = router.run(**kwargs)
|
||||||
|
result2 = new_router.run(**kwargs)
|
||||||
|
|
||||||
|
# check that the result is the same and correct
|
||||||
|
assert result1 == result2 and result1 == {"streams": [1, 2, 3]}
|
||||||
|
|
||||||
|
def test_router_de_serialization_user_type(self):
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length < 2}}",
|
||||||
|
"output": "{{message}}",
|
||||||
|
"output_type": ChatMessage,
|
||||||
|
"output_name": "message",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length >= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_type": List[int],
|
||||||
|
"output_name": "streams",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
router_dict = router.to_dict()
|
||||||
|
|
||||||
|
# assert that the router dict is correct, with all keys and values being strings
|
||||||
|
for route in router_dict["init_parameters"]["routes"]:
|
||||||
|
for key in route.keys():
|
||||||
|
assert isinstance(key, str)
|
||||||
|
assert isinstance(route[key], str)
|
||||||
|
|
||||||
|
# check that the output_type is a string and a proper class name
|
||||||
|
assert (
|
||||||
|
router_dict["init_parameters"]["routes"][0]["output_type"]
|
||||||
|
== "haystack.preview.dataclasses.chat_message.ChatMessage"
|
||||||
|
)
|
||||||
|
|
||||||
|
# deserialize the router
|
||||||
|
new_router = ConditionalRouter.from_dict(router_dict)
|
||||||
|
|
||||||
|
# check that the output_type is the right class
|
||||||
|
assert new_router.routes[0]["output_type"] == ChatMessage
|
||||||
|
assert router.routes == new_router.routes
|
||||||
|
|
||||||
|
# now use both routers to run the same message
|
||||||
|
message = ChatMessage.from_user("ciao")
|
||||||
|
kwargs = {"streams": [1], "message": message}
|
||||||
|
result1 = router.run(**kwargs)
|
||||||
|
result2 = new_router.run(**kwargs)
|
||||||
|
|
||||||
|
# check that the result is the same and correct
|
||||||
|
assert result1 == result2 and result1["message"].content == message.content
|
||||||
|
|
||||||
|
def test_router_serialization_idempotence(self):
|
||||||
|
routes = [
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length < 2}}",
|
||||||
|
"output": "{{message}}",
|
||||||
|
"output_type": ChatMessage,
|
||||||
|
"output_name": "message",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "{{streams|length >= 2}}",
|
||||||
|
"output": "{{streams}}",
|
||||||
|
"output_type": List[int],
|
||||||
|
"output_name": "streams",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = ConditionalRouter(routes)
|
||||||
|
# invoke to_dict twice and check that the result is the same
|
||||||
|
router_dict_first_invocation = copy.deepcopy(router.to_dict())
|
||||||
|
router_dict_second_invocation = router.to_dict()
|
||||||
|
assert router_dict_first_invocation == router_dict_second_invocation
|
||||||
Loading…
x
Reference in New Issue
Block a user