feat: Add support for multiple outputs in ConditionalRouter (#9271)

* feat: Add support for multiple outputs in ConditionalRouter

* Update haystack/components/routers/conditional_router.py

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>

* add additional route

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
Mohammed Abdul Razak Wahab 2025-04-24 19:47:06 +05:30 committed by GitHub
parent 4a908d075e
commit f97472329f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 22 deletions

View File

@ -203,11 +203,19 @@ class ConditionalRouter:
for route in routes:
# extract inputs
route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
route_input_names = self._extract_variables(
self._env,
[route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]),
)
input_types.update(route_input_names)
# extract outputs
output_types.update({route["output_name"]: route["output_type"]})
output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
output_types_list = (
route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
)
output_types.update(dict(zip(output_names, output_types_list)))
# remove optional variables from mandatory input types
mandatory_input_types = input_types - set(self.optional_variables)
@ -306,27 +314,45 @@ class ConditionalRouter:
rendered = ast.literal_eval(rendered)
if not rendered:
continue
# We now evaluate the `output` expression to determine the route output
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output = ast.literal_eval(output)
# Handle multiple outputs
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
output_types = (
route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
)
output_names = (
route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
)
result = {}
for output, output_type, output_name in zip(outputs, output_types, output_names):
# Evaluate output template
t_output = self._env.from_string(output)
output_value = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output_value = ast.literal_eval(output_value)
# Validate output type if needed
if self._validate_output_type and not self._output_matches_type(output_value, output_type):
raise ValueError(f"Route '{output_name}' type doesn't match expected type")
result[output_name] = output_value
return result
except Exception as e:
# If this was a typevalidation failure, let it propagate as a ValueError
if isinstance(e, ValueError):
raise
msg = f"Error evaluating condition for route '{route}': {e}"
raise RouteConditionException(msg) from e
if self._validate_output_type and not self._output_matches_type(output, route["output_type"]):
msg = f"""Route '{route["output_name"]}' type doesn't match expected type"""
raise ValueError(msg)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
def _validate_routes(self, routes: List[Dict]):
@ -347,9 +373,23 @@ class ConditionalRouter:
raise ValueError(
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
)
for field in ["condition", "output"]:
if not self._validate_template(self._env, route[field]):
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
# Validate outputs are consistent
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
output_types = route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
# Check lengths match
if not len(outputs) == len(output_types) == len(output_names):
raise ValueError(f"Route output, output_type and output_name must have same length: {route}")
# Validate templates
if not self._validate_template(self._env, route["condition"]):
raise ValueError(f"Invalid template for condition: {route['condition']}")
for output in outputs:
if not self._validate_template(self._env, output):
raise ValueError(f"Invalid template for output: {output}")
def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]:
"""

View File

@ -0,0 +1,4 @@
---
features:
- |
Add support for multiple outputs in ConditionalRouter

View File

@ -574,3 +574,44 @@ class TestRouter:
assert new_router.routes == router.routes
assert new_router.routes[0]["output_type"] is str
assert new_router.routes[0]["output_type"] is original_output_type
def test_multiple_outputs_per_route(self):
"""Test that router handles multiple outputs per route correctly"""
routes = [
{
"condition": "{{streams|length >= 2}}",
"output": ["{{streams}}", "{{query}}"],
"output_type": [List[int], str],
"output_name": ["streams", "query"],
},
{
"condition": "{{streams|length < 2}}",
"output": ["{{streams}}", "{{custom_error_message}}"],
"output_type": [List[int], str],
"output_name": ["streams", "custom_error_message"],
},
]
router = ConditionalRouter(routes)
# Test with sufficient input streams
result = router.run(streams=[1, 2, 3], query="test_1", custom_error_message="Not enough streams")
assert result == {"streams": [1, 2, 3], "query": "test_1"}
# Test with insufficient input streams
result = router.run(streams=[1], query="test_2", custom_error_message="Not enough streams")
assert result == {"streams": [1], "custom_error_message": "Not enough streams"}
def test_multiple_outputs_validation(self):
"""Test validation of routes with multiple outputs"""
# Test mismatched lengths
with pytest.raises(ValueError, match="must have same length"):
ConditionalRouter(
[
{
"condition": "{{streams|length >= 2}}",
"output": ["{{streams}}", "{{query}}"],
"output_type": [List[int]],
"output_name": ["streams"],
}
]
)