haystack/test/components/routers/test_llm_messages_router.py
Stefano Fiorucci 0d0a66b4f5
feat: add LLMMessagesRouter, a component to route Chat Messages using LLMs (#9540)
* llmmessagesrouter - draft

* serde methods

* refinements, tests and release note

* Apply suggestions from code review

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
2025-06-24 14:54:20 +02:00

214 lines
8.0 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import re
import pytest
from typing import Any, Dict, List
from unittest.mock import Mock
from haystack.components.routers.llm_messages_router import LLMMessagesRouter
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.core.serialization import default_to_dict, default_from_dict
import os
class MockChatGenerator:
def __init__(self, return_text: str = "safe"):
self.return_text = return_text
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
return {"replies": [ChatMessage.from_assistant(self.return_text)]}
def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, return_text=self.return_text)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MockChatGenerator":
return default_from_dict(cls, data)
class TestLLMMessagesRouter:
def test_init(self):
system_prompt = "Classify the messages as safe or unsafe."
chat_generator = MockChatGenerator()
router = LLMMessagesRouter(
chat_generator=chat_generator,
system_prompt=system_prompt,
output_names=["safe", "unsafe"],
output_patterns=["safe", "unsafe"],
)
assert router._chat_generator is chat_generator
assert router._system_prompt == system_prompt
assert router._output_names == ["safe", "unsafe"]
assert router._output_patterns == ["safe", "unsafe"]
assert router._compiled_patterns == [re.compile(pattern) for pattern in ["safe", "unsafe"]]
assert router._is_warmed_up is False
def test_init_errors(self):
chat_generator = MockChatGenerator()
with pytest.raises(ValueError):
LLMMessagesRouter(chat_generator=chat_generator, output_names=[], output_patterns=["pattern1", "pattern2"])
with pytest.raises(ValueError):
LLMMessagesRouter(chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=[])
with pytest.raises(ValueError):
LLMMessagesRouter(
chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=["pattern1"]
)
def test_warm_up_with_unwarmable_chat_generator(self):
chat_generator = MockChatGenerator()
router = LLMMessagesRouter(
chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
)
router.warm_up()
assert router._is_warmed_up is True
def test_warm_up_with_warmable_chat_generator(self):
chat_generator = Mock()
router = LLMMessagesRouter(
chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
)
router.warm_up()
assert router._is_warmed_up is True
assert router._chat_generator.warm_up.call_count == 1
def test_run_input_errors(self):
router = LLMMessagesRouter(
chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
)
with pytest.raises(ValueError):
router.run([])
with pytest.raises(ValueError):
router.run([ChatMessage.from_system("You are a helpful assistant.")])
def test_run_no_warm_up_with_unwarmable_chat_generator(self):
router = LLMMessagesRouter(
chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
)
router.run([ChatMessage.from_user("Hello")])
def test_run_no_warm_up_with_warmable_chat_generator(self):
chat_generator = Mock()
router = LLMMessagesRouter(
chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
)
with pytest.raises(RuntimeError):
router.run([ChatMessage.from_user("Hello")])
def test_run(self):
router = LLMMessagesRouter(
chat_generator=MockChatGenerator(return_text="safe"),
output_names=["safe", "unsafe"],
output_patterns=["safe", "unsafe"],
)
messages = [ChatMessage.from_user("Hello")]
result = router.run(messages)
assert result["chat_generator_text"] == "safe"
assert result["safe"] == messages
assert "unsafe" not in result
assert "unmatched" not in result
def test_run_with_system_prompt(self):
chat_generator = Mock()
chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("safe")]}
system_prompt = "Classify the messages as safe or unsafe."
router = LLMMessagesRouter(
chat_generator=chat_generator,
output_names=["safe", "unsafe"],
output_patterns=["safe", "unsafe"],
system_prompt=system_prompt,
)
router.warm_up()
messages = [ChatMessage.from_user("Hello")]
router.run(messages)
chat_generator.run.assert_called_once_with(messages=[ChatMessage.from_system(system_prompt)] + messages)
def test_run_unmatched_output(self):
router = LLMMessagesRouter(
chat_generator=MockChatGenerator(return_text="irrelevant"),
output_names=["safe", "unsafe"],
output_patterns=["safe", "unsafe"],
)
messages = [ChatMessage.from_user("Hello")]
result = router.run(messages)
assert result["chat_generator_text"] == "irrelevant"
assert result["unmatched"] == messages
assert "safe" not in result
assert "unsafe" not in result
def test_to_dict(self):
chat_generator = MockChatGenerator(return_text="safe")
router = LLMMessagesRouter(
chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
)
result = router.to_dict()
assert result["type"] == "haystack.components.routers.llm_messages_router.LLMMessagesRouter"
assert result["init_parameters"]["chat_generator"] == chat_generator.to_dict()
assert result["init_parameters"]["output_names"] == ["safe", "unsafe"]
assert result["init_parameters"]["output_patterns"] == ["safe", "unsafe"]
assert result["init_parameters"]["system_prompt"] is None
def test_from_dict(self):
chat_generator = MockChatGenerator(return_text="safe")
data = {
"type": "haystack.components.routers.llm_messages_router.LLMMessagesRouter",
"init_parameters": {
"chat_generator": chat_generator.to_dict(),
"output_names": ["safe", "unsafe"],
"output_patterns": ["safe", "unsafe"],
"system_prompt": None,
},
}
router = LLMMessagesRouter.from_dict(data)
assert router._chat_generator.to_dict() == chat_generator.to_dict()
assert router._output_names == ["safe", "unsafe"]
assert router._output_patterns == ["safe", "unsafe"]
assert router._system_prompt is None
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_live_run(self):
system_prompt = "Classify the messages into safe or unsafe. Respond with the label only, no other text."
router = LLMMessagesRouter(
chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
system_prompt=system_prompt,
output_names=["safe", "unsafe"],
output_patterns=[r"(?i)safe", r"(?i)unsafe"],
)
messages = [ChatMessage.from_user("Hello")]
result = router.run(messages)
print(result)
assert result["safe"] == messages
assert result["chat_generator_text"].lower() == "safe"
assert "unsafe" not in result
assert "unmatched" not in result