# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
import re
from typing import Any, Dict, List
from unittest.mock import Mock

import pytest

from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.routers.llm_messages_router import LLMMessagesRouter
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage


class MockChatGenerator:
    """Minimal Chat Generator stand-in: returns a fixed assistant reply and supports serialization."""

    def __init__(self, return_text: str = "safe"):
        self.return_text = return_text

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant(self.return_text)]}

    def to_dict(self) -> Dict[str, Any]:
        return default_to_dict(self, return_text=self.return_text)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGenerator":
        return default_from_dict(cls, data)


class TestLLMMessagesRouter:
    def test_init(self):
        system_prompt = "Classify the messages as safe or unsafe."
        chat_generator = MockChatGenerator()

        router = LLMMessagesRouter(
            chat_generator=chat_generator,
            system_prompt=system_prompt,
            output_names=["safe", "unsafe"],
            output_patterns=["safe", "unsafe"],
        )

        assert router._chat_generator is chat_generator
        assert router._system_prompt == system_prompt
        assert router._output_names == ["safe", "unsafe"]
        assert router._output_patterns == ["safe", "unsafe"]
        assert router._compiled_patterns == [re.compile(pattern) for pattern in ["safe", "unsafe"]]
        assert router._is_warmed_up is False

    def test_init_errors(self):
        chat_generator = MockChatGenerator()

        with pytest.raises(ValueError):
            LLMMessagesRouter(chat_generator=chat_generator, output_names=[], output_patterns=["pattern1", "pattern2"])

        with pytest.raises(ValueError):
            LLMMessagesRouter(chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=[])

        with pytest.raises(ValueError):
            LLMMessagesRouter(
                chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=["pattern1"]
            )

    def test_warm_up_with_unwarmable_chat_generator(self):
        # MockChatGenerator has no warm_up method; the router should still mark itself as warmed up
        chat_generator = MockChatGenerator()
        router = LLMMessagesRouter(
            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
        )
        router.warm_up()
        assert router._is_warmed_up is True

    def test_warm_up_with_warmable_chat_generator(self):
        # Mock() exposes a warm_up attribute, so the router is expected to delegate to it exactly once
        chat_generator = Mock()
        router = LLMMessagesRouter(
            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
        )
        router.warm_up()
        assert router._is_warmed_up is True
        assert router._chat_generator.warm_up.call_count == 1

    def test_run_input_errors(self):
        router = LLMMessagesRouter(
            chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
        )

        with pytest.raises(ValueError):
            router.run([])

        with pytest.raises(ValueError):
            router.run([ChatMessage.from_system("You are a helpful assistant.")])

    def test_run_no_warm_up_with_unwarmable_chat_generator(self):
        router = LLMMessagesRouter(
            chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
        )

        # the chat generator has no warm_up method, so running without warm_up should not raise
        router.run([ChatMessage.from_user("Hello")])

    def test_run_no_warm_up_with_warmable_chat_generator(self):
        chat_generator = Mock()
        router = LLMMessagesRouter(
            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
        )

        # the chat generator exposes warm_up, so running before warm_up should raise
        with pytest.raises(RuntimeError):
            router.run([ChatMessage.from_user("Hello")])

    def test_run(self):
        router = LLMMessagesRouter(
            chat_generator=MockChatGenerator(return_text="safe"),
            output_names=["safe", "unsafe"],
            output_patterns=["safe", "unsafe"],
        )

        messages = [ChatMessage.from_user("Hello")]
        result = router.run(messages)

        assert result["chat_generator_text"] == "safe"
        assert result["safe"] == messages
        assert "unsafe" not in result
        assert "unmatched" not in result

    def test_run_with_system_prompt(self):
        chat_generator = Mock()
        chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("safe")]}

        system_prompt = "Classify the messages as safe or unsafe."

        router = LLMMessagesRouter(
            chat_generator=chat_generator,
            output_names=["safe", "unsafe"],
            output_patterns=["safe", "unsafe"],
            system_prompt=system_prompt,
        )
        router.warm_up()

        messages = [ChatMessage.from_user("Hello")]
        router.run(messages)

        chat_generator.run.assert_called_once_with(messages=[ChatMessage.from_system(system_prompt)] + messages)

    def test_run_unmatched_output(self):
        router = LLMMessagesRouter(
            chat_generator=MockChatGenerator(return_text="irrelevant"),
            output_names=["safe", "unsafe"],
            output_patterns=["safe", "unsafe"],
        )

        messages = [ChatMessage.from_user("Hello")]
        result = router.run(messages)

        assert result["chat_generator_text"] == "irrelevant"
        assert result["unmatched"] == messages
        assert "safe" not in result
        assert "unsafe" not in result

    def test_to_dict(self):
        chat_generator = MockChatGenerator(return_text="safe")

        router = LLMMessagesRouter(
            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
        )

        result = router.to_dict()

        assert result["type"] == "haystack.components.routers.llm_messages_router.LLMMessagesRouter"
        assert result["init_parameters"]["chat_generator"] == chat_generator.to_dict()
        assert result["init_parameters"]["output_names"] == ["safe", "unsafe"]
        assert result["init_parameters"]["output_patterns"] == ["safe", "unsafe"]
        assert result["init_parameters"]["system_prompt"] is None

    def test_from_dict(self):
        chat_generator = MockChatGenerator(return_text="safe")

        data = {
            "type": "haystack.components.routers.llm_messages_router.LLMMessagesRouter",
            "init_parameters": {
                "chat_generator": chat_generator.to_dict(),
                "output_names": ["safe", "unsafe"],
                "output_patterns": ["safe", "unsafe"],
                "system_prompt": None,
            },
        }

        router = LLMMessagesRouter.from_dict(data)

        assert router._chat_generator.to_dict() == chat_generator.to_dict()
        assert router._output_names == ["safe", "unsafe"]
        assert router._output_patterns == ["safe", "unsafe"]
        assert router._system_prompt is None

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_live_run(self):
        system_prompt = "Classify the messages into safe or unsafe. Respond with the label only, no other text."
        router = LLMMessagesRouter(
            chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
            system_prompt=system_prompt,
            output_names=["safe", "unsafe"],
            output_patterns=[r"(?i)safe", r"(?i)unsafe"],
        )

        messages = [ChatMessage.from_user("Hello")]
        result = router.run(messages)
        print(result)

        assert result["safe"] == messages
        assert result["chat_generator_text"].lower() == "safe"
        assert "unsafe" not in result
        assert "unmatched" not in result