feat: Add ConditionalRouter Haystack 2.x component (#6147)

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
Vladimir Blagojevic 2023-11-23 10:28:08 +01:00 committed by GitHub
parent 70e40eae5c
commit b557f3035e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 679 additions and 2 deletions

View File

@ -1,7 +1,7 @@
from haystack.preview.components.routers.document_joiner import DocumentJoiner
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.components.routers.metadata_router import MetadataRouter
from haystack.preview.components.routers.conditional_router import ConditionalRouter
from haystack.preview.components.routers.text_language_router import TextLanguageRouter
__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter"]
__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"]

View File

@ -0,0 +1,347 @@
import importlib
import inspect
import logging
import sys
from typing import List, Dict, Any, Set, get_origin
from jinja2 import meta, Environment, TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment
from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class NoRouteSelectedException(Exception):
    """Signals that ConditionalRouter went through its routes and selected none of them."""
class RouteConditionException(Exception):
    """Signals a failure while parsing or evaluating a route's condition expression in ConditionalRouter."""
def serialize_type(target: Any) -> str:
    """
    Turns a type, a typing construct, or an arbitrary instance into a dotted string naming its type.

    Instances are serialized as their class. Types living in ``builtins`` are emitted without a
    module prefix (``str`` rather than ``builtins.str``). A string that already contains a dot is
    assumed to be a previously serialized type and is returned unchanged.

    :param target: The object to serialize, can be an instance or a type.
    :type target: Any
    :return: The string representation of the type.
    :raises ValueError: If the type cannot be serialized.
    """
    # A dotted string is taken to be the output of an earlier serialization.
    if isinstance(target, str) and "." in target:
        return target

    # Classes and subscripted typing objects are used as-is; anything else is reduced to its class.
    if isinstance(target, type) or bool(get_origin(target)):
        type_obj = target
    else:
        type_obj = type(target)

    module = inspect.getmodule(type_obj)
    type_obj_repr = repr(type_obj)

    if type_obj_repr.startswith("typing."):
        # e.g. "typing.List[int]" -> "List[int]"; the module prefix is re-attached below
        type_name = type_obj_repr.partition(".")[2]
    elif hasattr(type_obj, "__name__"):
        type_name = type_obj.__name__
    else:
        raise ValueError(f"Could not serialize type: {type_obj_repr}")

    # No usable module, or a builtin: the bare name is the whole path.
    if not module or not hasattr(module, "__name__") or module.__name__ == "builtins":
        return type_name
    return f"{module.__name__}.{type_name}"
def deserialize_type(type_str: str) -> Any:
    """
    Deserializes a type given its full import path as a string, including nested generic types.

    The module part of the path is looked up in ``sys.modules`` and imported if it is not there
    yet; a path without a module part is resolved against ``builtins``. Generic types such as
    'typing.List[typing.Dict[int, str]]' are rebuilt recursively. A single generic argument is
    passed on its own rather than as a 1-tuple, so single-argument special forms like
    'typing.Optional[int]' can be rebuilt too.

    NOTE(security): the module named in ``type_str`` is imported, which executes that module's
    top-level code. Only deserialize strings from trusted sources.

    :param type_str: The string representation of the type's full import path.
    :return: The deserialized type object.
    :raises DeserializationError: If the type cannot be deserialized due to missing module or type.
    """

    def parse_generic_args(args_str):
        """Split a generic argument list on its top-level commas, leaving nested brackets intact."""
        args = []
        bracket_count = 0
        current_arg = ""
        for char in args_str:
            if char == "[":
                bracket_count += 1
            elif char == "]":
                bracket_count -= 1

            # Only a comma outside of any nested brackets separates two arguments.
            if char == "," and bracket_count == 0:
                args.append(current_arg.strip())
                current_arg = ""
            else:
                current_arg += char

        # Ignore a whitespace-only tail (e.g. after a trailing comma) instead of emitting an empty argument.
        last_arg = current_arg.strip()
        if last_arg:
            args.append(last_arg)
        return args

    if "[" in type_str and type_str.endswith("]"):
        # Handle generics: split "<main>[<args>]" at the first bracket and drop the closing one
        main_type_str, generics_str = type_str.split("[", 1)
        generics_str = generics_str[:-1]

        main_type = deserialize_type(main_type_str)
        generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))

        # Reconstruct. Forms such as typing.Optional reject a tuple subscript, so a lone
        # argument is passed bare.
        if len(generic_args) == 1:
            return main_type[generic_args[0]]
        return main_type[generic_args]

    # Handle non-generics
    parts = type_str.split(".")
    module_name = ".".join(parts[:-1]) or "builtins"
    type_name = parts[-1]

    module = sys.modules.get(module_name)
    if not module:
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            raise DeserializationError(f"Could not import the module: {module_name}") from e

    # Compare against None explicitly so that a falsy attribute is not mistaken for a missing one.
    deserialized_type = getattr(module, type_name, None)
    if deserialized_type is None:
        raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

    return deserialized_type
@component
class ConditionalRouter:
    """
    ConditionalRouter in Haystack 2.x pipelines is designed to manage data routing based on specific conditions.
    This is achieved by defining a list named 'routes'. Each element in this list is a dictionary representing a
    single route.

    A route dictionary comprises four key elements:
    - 'condition': A Jinja2 string expression that determines if the route is selected.
    - 'output': A Jinja2 expression defining the route's output value.
    - 'output_type': The type of the output data (e.g., str, List[int]).
    - 'output_name': The name under which the `output` value of the route is published. This name is used to connect
    the router to other components in the pipeline.

    Here's an example:

    ```python
    from typing import List
    from haystack.preview.components.routers import ConditionalRouter

    routes = [
        {
            "condition": "{{streams|length > 2}}",
            "output": "{{streams}}",
            "output_name": "enough_streams",
            "output_type": List[int],
        },
        {
            "condition": "{{streams|length <= 2}}",
            "output": "{{streams}}",
            "output_name": "insufficient_streams",
            "output_type": List[int],
        },
    ]
    router = ConditionalRouter(routes)
    # When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
    kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
    result = router.run(**kwargs)
    assert result == {"enough_streams": [1, 2, 3]}
    ```

    In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the
    stream count exceeds two. Conversely, the second route directs 'streams' to 'insufficient_streams' when there
    are two or fewer streams.

    In the pipeline setup, the router is connected to other components using the output names. For example, the
    'enough_streams' output might be connected to another component that processes the streams, while the
    'insufficient_streams' output might be connected to a component that fetches more streams, and so on.

    Here is a pseudocode example of a pipeline that uses the ConditionalRouter and routes fetched ByteStreams to
    different components depending on the number of streams fetched:

    ```
    from typing import List
    from haystack import Pipeline
    from haystack.preview.dataclasses import ByteStream
    from haystack.preview.components.routers import ConditionalRouter

    routes = [
        {
            "condition": "{{streams|length > 2}}",
            "output": "{{streams}}",
            "output_name": "enough_streams",
            "output_type": List[ByteStream],
        },
        {
            "condition": "{{streams|length <= 2}}",
            "output": "{{streams}}",
            "output_name": "insufficient_streams",
            "output_type": List[ByteStream],
        },
    ]
    router = ConditionalRouter(routes)

    pipe = Pipeline()
    pipe.add_component("router", router)
    ...
    pipe.connect("router.enough_streams", "some_component_a.streams")
    pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
    ...
    ```
    """

    def __init__(self, routes: List[Dict]):
        """
        Initializes the ConditionalRouter with a list of routes detailing the conditions for routing.

        :param routes: A list of dictionaries, each defining a route with a boolean condition expression
                       ('condition'), an output value ('output'), the output type ('output_type') and
                       ('output_name') that defines the output name for the variable defined in 'output'.
        :raises ValueError: If a route is not a dictionary, lacks a mandatory field, or contains a
                            'condition' or 'output' template that Jinja cannot parse.
        """
        self._validate_routes(routes)
        self.routes: List[dict] = routes

        # Create a Jinja native environment to inspect variables in the condition templates
        env = NativeEnvironment()

        # Inspect the routes to determine input and output types.
        input_types: Set[str] = set()  # let's just store the name, type will always be Any
        output_types: Dict[str, Any] = {}  # output name -> declared output type

        for route in routes:
            # extract inputs: every undeclared variable of either template becomes a component input
            route_input_names = self._extract_variables(env, [route["output"], route["condition"]])
            input_types.update(route_input_names)

            # extract outputs
            output_types[route["output_name"]] = route["output_type"]

        component.set_input_types(self, **{var: Any for var in input_types})
        component.set_output_types(self, **output_types)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the router to a dictionary.

        Each route is shallow-copied with its 'output_type' converted to a string, so the routes
        held by this instance are left untouched.

        :return: The dictionary produced by `default_to_dict` for this component.
        """
        serialized_routes = [
            # output_type needs to be serialized to a string
            {**route, "output_type": serialize_type(route["output_type"])}
            for route in self.routes
        ]
        return default_to_dict(self, routes=serialized_routes)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
        """
        Builds a router from a dictionary such as the one returned by `to_dict`.

        The routes are copied with their 'output_type' converted back from a string to a type;
        `data` itself is not modified.

        :param data: The serialized component; its 'init_parameters' must contain 'routes'.
        :return: The deserialized ConditionalRouter.
        :raises DeserializationError: If 'routes' is missing or an 'output_type' cannot be deserialized.
        """
        init_params = data.get("init_parameters", {})
        routes = init_params.get("routes")
        if routes is None:
            raise DeserializationError("Missing 'routes' in the init parameters of ConditionalRouter")

        deserialized_routes = [
            # output_type needs to be deserialized from a string to a type
            {**route, "output_type": deserialize_type(route["output_type"])}
            for route in routes
        ]
        # Hand a rebuilt copy to default_from_dict so the caller's dictionary stays as it was.
        patched_data = {**data, "init_parameters": {**init_params, "routes": deserialized_routes}}
        return default_from_dict(cls, patched_data)

    def run(self, **kwargs):
        """
        Executes the routing logic by evaluating the specified boolean condition expressions
        for each route in the order they are listed. The method directs the flow
        of data to the output specified in the first route, whose expression
        evaluates to True. If no route's expression evaluates to True, an exception
        is raised.

        :param kwargs: A dictionary containing the pipeline variables, which should
                       include all variables used in the "condition" templates.
        :return: A dictionary containing the output and the corresponding result,
                 based on the first route whose expression evaluates to True.
        :raises NoRouteSelectedException: If no route's expression evaluates to True.
        :raises RouteConditionException: If rendering a route's 'condition' or 'output' template fails.
        """
        # Create a Jinja native environment to evaluate the condition templates as Python expressions
        env = NativeEnvironment()

        for route in self.routes:
            try:
                t = env.from_string(route["condition"])
                if t.render(**kwargs):
                    # We now evaluate the `output` expression to determine the route output
                    t_output = env.from_string(route["output"])
                    output = t_output.render(**kwargs)
                    # and return the output as a dictionary under the output_name key
                    return {route["output_name"]: output}
            except Exception as e:
                raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e

        raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")

    def _validate_routes(self, routes: List[Dict]):
        """
        Validates a list of routes.

        Every route must be a dictionary that contains the four mandatory fields and whose
        'condition' and 'output' templates can be parsed by Jinja. Additional fields are not rejected.

        :param routes: A list of routes.
        :type routes: List[Dict]
        :raises ValueError: If any route fails one of the checks above.
        """
        env = NativeEnvironment()
        mandatory_fields = {"condition", "output", "output_type", "output_name"}

        for route in routes:
            try:
                keys = set(route.keys())
            except AttributeError as e:
                raise ValueError(f"Route must be a dictionary, got: {route}") from e

            if not mandatory_fields.issubset(keys):
                raise ValueError(
                    f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
                )

            for field in ["condition", "output"]:
                if not self._validate_template(env, route[field]):
                    raise ValueError(f"Invalid template for field '{field}': {route[field]}")

    def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]:
        """
        Extracts all variables from a list of Jinja template strings.

        :param env: A Jinja environment.
        :type env: Environment
        :param templates: A list of Jinja template strings.
        :type templates: List[str]
        :return: A set of variable names.
        """
        variables = set()
        for template in templates:
            ast = env.parse(template)
            variables.update(meta.find_undeclared_variables(ast))
        return variables

    def _validate_template(self, env: Environment, template_text: str) -> bool:
        """
        Validates a template string by parsing it with Jinja.

        :param env: A Jinja environment.
        :type env: Environment
        :param template_text: A Jinja template string.
        :type template_text: str
        :return: True if the template is valid, False otherwise.
        """
        try:
            env.parse(template_text)
            return True
        except TemplateSyntaxError:
            return False

View File

@ -0,0 +1,6 @@
---
preview:
- |
Add `ConditionalRouter` component to enhance the conditional pipeline routing capabilities.
The `ConditionalRouter` component orchestrates the flow of data by evaluating specified route conditions
to determine the appropriate route among a set of provided route alternatives.

View File

@ -0,0 +1,324 @@
import copy
import typing
from typing import List, Dict
from unittest import mock
import pytest
from haystack.preview.components.routers import ConditionalRouter
from haystack.preview.components.routers.conditional_router import (
NoRouteSelectedException,
serialize_type,
deserialize_type,
)
from haystack.preview.dataclasses import ChatMessage
class TestRouter:
    """Tests for ConditionalRouter and its type (de)serialization helpers."""

    @pytest.fixture
    def routes(self):
        """Two complementary routes: fewer than two streams emits 'query', otherwise 'streams'."""
        return [
            {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
            {
                "condition": "{{streams|length >= 2}}",
                "output": "{{streams}}",
                "output_type": List[int],
                "output_name": "streams",
            },
        ]

    @pytest.fixture
    def router(self, routes):
        """A ConditionalRouter built from the `routes` fixture."""
        return ConditionalRouter(routes)

    def test_missing_mandatory_fields(self):
        """
        Router raises a ValueError if a route does not contain all of the 'condition', 'output',
        'output_type' and 'output_name' keys
        """
        routes = [
            {"condition": "{{streams|length < 2}}", "output": "{{query}}"},
            {"condition": "{{streams|length < 2}}", "output_type": str},
        ]
        with pytest.raises(ValueError):
            ConditionalRouter(routes)

    def test_invalid_condition_field(self):
        """
        ConditionalRouter init raises a ValueError if one of the routes contains invalid condition
        """
        # invalid condition field: the closing "}}" is missing
        routes = [{"condition": "{{streams|length < 2", "output": "query", "output_type": str, "output_name": "test"}]
        with pytest.raises(ValueError, match="Invalid template"):
            ConditionalRouter(routes)

    def test_no_vars_in_output_route_but_with_output_name(self):
        """
        Router accepts a route whose output field uses no variables and emits the constant as is
        """
        routes = [
            {
                "condition": "{{streams|length > 2}}",
                "output": "This is a constant",
                "output_name": "enough_streams",
                "output_type": str,
            }
        ]
        router = ConditionalRouter(routes)
        kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
        result = router.run(**kwargs)
        assert result == {"enough_streams": "This is a constant"}

    def test_mandatory_and_optional_fields_with_extra_fields(self):
        """
        Router raises a ValueError when a route lacks a mandatory field, even if another route in the
        list has all mandatory fields plus an extra one
        """
        routes = [
            {
                "condition": "{{streams|length < 2}}",
                "output": "{{query}}",
                "output_type": str,
                "output_name": "test",
                "bla": "bla",
            },
            # this route has no 'output_name'
            {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str},
        ]

        with pytest.raises(ValueError):
            ConditionalRouter(routes)

    def test_router_initialized(self, routes):
        router = ConditionalRouter(routes)

        assert router.routes == routes
        assert set(router.__canals_input__.keys()) == {"query", "streams"}
        assert set(router.__canals_output__.keys()) == {"query", "streams"}

    def test_router_evaluate_condition_expressions(self, router):
        # three streams: the second route ('streams|length >= 2') should be selected
        kwargs = {"streams": [1, 2, 3], "query": "test"}
        result = router.run(**kwargs)
        assert result == {"streams": [1, 2, 3]}

        # one stream: the first route ('streams|length < 2') should be selected
        kwargs = {"streams": [1], "query": "test"}
        result = router.run(**kwargs)
        assert result == {"query": "test"}

    def test_router_evaluate_condition_expressions_using_output_slot(self):
        routes = [
            {
                "condition": "{{streams|length > 2}}",
                "output": "{{streams}}",
                "output_name": "enough_streams",
                "output_type": List[int],
            },
            {
                "condition": "{{streams|length <= 2}}",
                "output": "{{streams}}",
                "output_name": "insufficient_streams",
                "output_type": List[int],
            },
        ]
        router = ConditionalRouter(routes)
        # enough_streams output slot will be selected with [1, 2, 3] list being outputted
        kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
        result = router.run(**kwargs)
        assert result == {"enough_streams": [1, 2, 3]}

    def test_complex_condition(self):
        routes = [
            {
                "condition": "{{messages[-1].metadata.finish_reason == 'function_call'}}",
                "output": "{{streams}}",
                "output_type": List[int],
                "output_name": "streams",
            },
            {
                "condition": "{{True}}",
                "output": "{{query}}",
                "output_type": str,
                "output_name": "query",
            },  # catch-all condition
        ]
        router = ConditionalRouter(routes)
        message = mock.MagicMock()
        message.metadata.finish_reason = "function_call"
        result = router.run(messages=[message], streams=[1, 2, 3], query="my query")
        assert result == {"streams": [1, 2, 3]}

    def test_router_no_route(self):
        # neither condition holds for three streams, so run() should raise an exception
        router = ConditionalRouter(
            [
                {
                    "condition": "{{streams|length < 2}}",
                    "output": "{{query}}",
                    "output_type": str,
                    "output_name": "query",
                },
                {
                    "condition": "{{streams|length >= 5}}",
                    "output": "{{streams}}",
                    "output_type": List[int],
                    "output_name": "streams",
                },
            ]
        )

        kwargs = {"streams": [1, 2, 3], "query": "test"}
        with pytest.raises(NoRouteSelectedException):
            router.run(**kwargs)

    def test_router_raises_value_error_if_route_not_dictionary(self):
        """
        Router raises a ValueError if each route is not a dictionary
        """
        routes = [
            {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
            ["{{streams|length >= 2}}", "streams", List[int]],
        ]

        with pytest.raises(ValueError):
            ConditionalRouter(routes)

    def test_router_raises_value_error_if_route_missing_keys(self):
        """
        Router raises a ValueError if a route does not contain all of the 'condition', 'output',
        'output_type' and 'output_name' keys
        """
        routes = [
            {"condition": "{{streams|length < 2}}", "output": "{{query}}"},
            {"condition": "{{streams|length < 2}}", "output_type": str},
        ]

        with pytest.raises(ValueError):
            ConditionalRouter(routes)

    def test_output_type_serialization(self):
        assert serialize_type(str) == "str"
        assert serialize_type(List[int]) == "typing.List[int]"
        assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
        assert serialize_type(ChatMessage) == "haystack.preview.dataclasses.chat_message.ChatMessage"
        assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
        assert serialize_type(List[ChatMessage]) == "typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]"
        assert (
            serialize_type(typing.Dict[int, ChatMessage])
            == "typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]"
        )
        assert serialize_type(int) == "int"
        # an instance is serialized as its class
        assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.preview.dataclasses.chat_message.ChatMessage"

    def test_output_type_deserialization(self):
        assert deserialize_type("str") == str
        assert deserialize_type("typing.List[int]") == typing.List[int]
        assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]]
        assert deserialize_type("typing.Dict[str, int]") == Dict[str, int]
        assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]
        assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]]
        assert (
            deserialize_type("typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]")
            == typing.List[ChatMessage]
        )
        assert (
            deserialize_type("typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]")
            == typing.Dict[int, ChatMessage]
        )
        assert deserialize_type("haystack.preview.dataclasses.chat_message.ChatMessage") == ChatMessage
        assert deserialize_type("int") == int

    def test_router_de_serialization(self):
        routes = [
            {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
            {
                "condition": "{{streams|length >= 2}}",
                "output": "{{streams}}",
                "output_type": List[int],
                "output_name": "streams",
            },
        ]
        router = ConditionalRouter(routes)
        router_dict = router.to_dict()

        # assert that the router dict is correct, with all keys and values being strings
        for route in router_dict["init_parameters"]["routes"]:
            for key in route.keys():
                assert isinstance(key, str)
                assert isinstance(route[key], str)

        new_router = ConditionalRouter.from_dict(router_dict)
        assert router.routes == new_router.routes

        # now use both routers with the same input
        kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
        result1 = router.run(**kwargs)
        result2 = new_router.run(**kwargs)

        # check that the result is the same and correct
        assert result1 == result2 and result1 == {"streams": [1, 2, 3]}

    def test_router_de_serialization_user_type(self):
        routes = [
            {
                "condition": "{{streams|length < 2}}",
                "output": "{{message}}",
                "output_type": ChatMessage,
                "output_name": "message",
            },
            {
                "condition": "{{streams|length >= 2}}",
                "output": "{{streams}}",
                "output_type": List[int],
                "output_name": "streams",
            },
        ]
        router = ConditionalRouter(routes)
        router_dict = router.to_dict()

        # assert that the router dict is correct, with all keys and values being strings
        for route in router_dict["init_parameters"]["routes"]:
            for key in route.keys():
                assert isinstance(key, str)
                assert isinstance(route[key], str)

        # check that the output_type is a string and a proper class name
        assert (
            router_dict["init_parameters"]["routes"][0]["output_type"]
            == "haystack.preview.dataclasses.chat_message.ChatMessage"
        )

        # deserialize the router
        new_router = ConditionalRouter.from_dict(router_dict)

        # check that the output_type is the right class
        assert new_router.routes[0]["output_type"] == ChatMessage
        assert router.routes == new_router.routes

        # now use both routers to run the same message
        message = ChatMessage.from_user("ciao")
        kwargs = {"streams": [1], "message": message}
        result1 = router.run(**kwargs)
        result2 = new_router.run(**kwargs)

        # check that the result is the same and correct
        assert result1 == result2 and result1["message"].content == message.content

    def test_router_serialization_idempotence(self):
        routes = [
            {
                "condition": "{{streams|length < 2}}",
                "output": "{{message}}",
                "output_type": ChatMessage,
                "output_name": "message",
            },
            {
                "condition": "{{streams|length >= 2}}",
                "output": "{{streams}}",
                "output_type": List[int],
                "output_name": "streams",
            },
        ]
        router = ConditionalRouter(routes)

        # invoke to_dict twice and check that the result is the same;
        # the first result is deep-copied so a later call cannot alter it
        router_dict_first_invocation = copy.deepcopy(router.to_dict())
        router_dict_second_invocation = router.to_dict()
        assert router_dict_first_invocation == router_dict_second_invocation