mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-29 08:26:19 +00:00
feat: Add support for multiple outputs in ConditionalRouter (#9271)
* feat: Add support for multiple outputs in ConditionalRouter * Update haystack/components/routers/conditional_router.py Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com> * add additional route --------- Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
parent
4a908d075e
commit
f97472329f
@ -203,11 +203,19 @@ class ConditionalRouter:
|
||||
|
||||
for route in routes:
|
||||
# extract inputs
|
||||
route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
|
||||
route_input_names = self._extract_variables(
|
||||
self._env,
|
||||
[route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]),
|
||||
)
|
||||
input_types.update(route_input_names)
|
||||
|
||||
# extract outputs
|
||||
output_types.update({route["output_name"]: route["output_type"]})
|
||||
output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
|
||||
output_types_list = (
|
||||
route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
|
||||
)
|
||||
|
||||
output_types.update(dict(zip(output_names, output_types_list)))
|
||||
|
||||
# remove optional variables from mandatory input types
|
||||
mandatory_input_types = input_types - set(self.optional_variables)
|
||||
@ -306,27 +314,45 @@ class ConditionalRouter:
|
||||
rendered = ast.literal_eval(rendered)
|
||||
if not rendered:
|
||||
continue
|
||||
# We now evaluate the `output` expression to determine the route output
|
||||
t_output = self._env.from_string(route["output"])
|
||||
output = t_output.render(**kwargs)
|
||||
|
||||
# Handle multiple outputs
|
||||
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
|
||||
output_types = (
|
||||
route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
|
||||
)
|
||||
output_names = (
|
||||
route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
|
||||
)
|
||||
|
||||
result = {}
|
||||
for output, output_type, output_name in zip(outputs, output_types, output_names):
|
||||
# Evaluate output template
|
||||
t_output = self._env.from_string(output)
|
||||
output_value = t_output.render(**kwargs)
|
||||
|
||||
# We suppress the exception in case the output is already a string, otherwise
|
||||
# we try to evaluate it and would fail.
|
||||
# This must be done cause the output could be different literal structures.
|
||||
# This doesn't support any user types.
|
||||
with contextlib.suppress(Exception):
|
||||
if not self._unsafe:
|
||||
output = ast.literal_eval(output)
|
||||
output_value = ast.literal_eval(output_value)
|
||||
|
||||
# Validate output type if needed
|
||||
if self._validate_output_type and not self._output_matches_type(output_value, output_type):
|
||||
raise ValueError(f"Route '{output_name}' type doesn't match expected type")
|
||||
|
||||
result[output_name] = output_value
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# If this was a type‐validation failure, let it propagate as a ValueError
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
msg = f"Error evaluating condition for route '{route}': {e}"
|
||||
raise RouteConditionException(msg) from e
|
||||
|
||||
if self._validate_output_type and not self._output_matches_type(output, route["output_type"]):
|
||||
msg = f"""Route '{route["output_name"]}' type doesn't match expected type"""
|
||||
raise ValueError(msg)
|
||||
|
||||
# and return the output as a dictionary under the output_name key
|
||||
return {route["output_name"]: output}
|
||||
|
||||
raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
|
||||
|
||||
def _validate_routes(self, routes: List[Dict]):
|
||||
@ -347,9 +373,23 @@ class ConditionalRouter:
|
||||
raise ValueError(
|
||||
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
|
||||
)
|
||||
for field in ["condition", "output"]:
|
||||
if not self._validate_template(self._env, route[field]):
|
||||
raise ValueError(f"Invalid template for field '{field}': {route[field]}")
|
||||
|
||||
# Validate outputs are consistent
|
||||
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
|
||||
output_types = route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
|
||||
output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
|
||||
|
||||
# Check lengths match
|
||||
if not len(outputs) == len(output_types) == len(output_names):
|
||||
raise ValueError(f"Route output, output_type and output_name must have same length: {route}")
|
||||
|
||||
# Validate templates
|
||||
if not self._validate_template(self._env, route["condition"]):
|
||||
raise ValueError(f"Invalid template for condition: {route['condition']}")
|
||||
|
||||
for output in outputs:
|
||||
if not self._validate_template(self._env, output):
|
||||
raise ValueError(f"Invalid template for output: {output}")
|
||||
|
||||
def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]:
|
||||
"""
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add support for multiple outputs in ConditionalRouter
|
||||
@ -574,3 +574,44 @@ class TestRouter:
|
||||
assert new_router.routes == router.routes
|
||||
assert new_router.routes[0]["output_type"] is str
|
||||
assert new_router.routes[0]["output_type"] is original_output_type
|
||||
|
||||
def test_multiple_outputs_per_route(self):
|
||||
"""Test that router handles multiple outputs per route correctly"""
|
||||
routes = [
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": ["{{streams}}", "{{query}}"],
|
||||
"output_type": [List[int], str],
|
||||
"output_name": ["streams", "query"],
|
||||
},
|
||||
{
|
||||
"condition": "{{streams|length < 2}}",
|
||||
"output": ["{{streams}}", "{{custom_error_message}}"],
|
||||
"output_type": [List[int], str],
|
||||
"output_name": ["streams", "custom_error_message"],
|
||||
},
|
||||
]
|
||||
router = ConditionalRouter(routes)
|
||||
|
||||
# Test with sufficient input streams
|
||||
result = router.run(streams=[1, 2, 3], query="test_1", custom_error_message="Not enough streams")
|
||||
assert result == {"streams": [1, 2, 3], "query": "test_1"}
|
||||
|
||||
# Test with insufficient input streams
|
||||
result = router.run(streams=[1], query="test_2", custom_error_message="Not enough streams")
|
||||
assert result == {"streams": [1], "custom_error_message": "Not enough streams"}
|
||||
|
||||
def test_multiple_outputs_validation(self):
|
||||
"""Test validation of routes with multiple outputs"""
|
||||
# Test mismatched lengths
|
||||
with pytest.raises(ValueError, match="must have same length"):
|
||||
ConditionalRouter(
|
||||
[
|
||||
{
|
||||
"condition": "{{streams|length >= 2}}",
|
||||
"output": ["{{streams}}", "{{query}}"],
|
||||
"output_type": [List[int]],
|
||||
"output_name": ["streams"],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user