Vladimir Blagojevic 90edcdacee
feat: Add FallbackChatGenerator (#9859)
* Add FallbackChatGenerator

* Update licence files

* Use typing.Optional/Union for Python 3.9 compat

* Use the right logger

* Lint fix

* PR review

* Rewrite release note

* Add FallbackChatGenerator to docs

* Update haystack/components/generators/chat/fallback.py

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>

* Rename generator -> chat_generators

* Lint

* Rename generators -> chat_generators in meta, docs, tests

* Update haystack/components/generators/chat/fallback.py

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

* Update pydocs

* Minor pydocs fix

---------

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
2025-10-17 15:46:31 +02:00

357 lines
14 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
from typing import Any, Optional, Union
from urllib.error import HTTPError as URLLibHTTPError
import pytest
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.chat.fallback import FallbackChatGenerator
from haystack.dataclasses import ChatMessage, StreamingCallbackT
from haystack.tools import Tool, Toolset
@component
class _DummySuccessGen:
    """
    Chat generator stub that always answers with a fixed assistant message.

    If ``delay`` is non-zero it is waited out before replying. A streaming callback passed to
    ``run`` / ``run_async`` is invoked exactly once with a placeholder chunk.
    """

    def __init__(self, text: str = "ok", delay: float = 0.0, streaming_callback: Optional[StreamingCallbackT] = None):
        self.text = text
        self.delay = delay
        # Stored on the instance only; run/run_async act on the callback passed per call.
        self.streaming_callback = streaming_callback

    def to_dict(self) -> dict[str, Any]:
        """Serialize the stub. The streaming callback is always written out as ``None``."""
        return default_to_dict(self, text=self.text, delay=self.delay, streaming_callback=None)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "_DummySuccessGen":
        """Rebuild the stub from the dictionary produced by ``to_dict``."""
        return default_from_dict(cls, data)

    def _build_output(self) -> dict[str, Any]:
        """Assemble the result dictionary shared by ``run`` and ``run_async``."""
        reply = ChatMessage.from_assistant(self.text)
        return {"replies": [reply], "meta": {"dummy_meta": True}}

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        """Block for ``delay`` seconds (if set), emit one dummy chunk, and return the fixed reply."""
        if self.delay:
            time.sleep(self.delay)
        if streaming_callback:
            dummy_chunk = {"dummy": True}
            streaming_callback(dummy_chunk)  # type: ignore[arg-type]
        return self._build_output()

    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        """Async twin of ``run``: waits with ``asyncio.sleep`` instead of blocking the thread."""
        if self.delay:
            await asyncio.sleep(self.delay)
        if streaming_callback:
            # Yield to the event loop once before emitting the chunk.
            await asyncio.sleep(0)
            dummy_chunk = {"dummy": True}
            streaming_callback(dummy_chunk)  # type: ignore[arg-type]
        return self._build_output()
@component
class _DummyFailGen:
    """
    Chat generator stub that always raises.

    The raised exception is the one given at construction time, or ``RuntimeError("boom")`` when
    none is supplied. If ``delay`` is non-zero it is waited out before raising.
    """

    def __init__(self, exc: Optional[Exception] = None, delay: float = 0.0):
        self.exc = exc if exc else RuntimeError("boom")
        self.delay = delay

    def to_dict(self) -> dict[str, Any]:
        """Serialize the stub; the exception is reduced to a ``{"message": ...}`` dictionary."""
        serialized_exc = {"message": str(self.exc)}
        return default_to_dict(self, exc=serialized_exc, delay=self.delay)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "_DummyFailGen":
        """Rebuild the stub; the exception always comes back as a ``RuntimeError``."""
        params = data.get("init_parameters", {})
        exc_payload = params.get("exc")
        # Only a dict payload can carry a message; anything else falls back to "boom".
        message = exc_payload.get("message") if isinstance(exc_payload, dict) else None
        return cls(exc=RuntimeError(message or "boom"), delay=params.get("delay", 0.0))

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        """Block for ``delay`` seconds (if set), then raise the configured exception."""
        if self.delay:
            time.sleep(self.delay)
        raise self.exc

    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        """Async twin of ``run``: waits with ``asyncio.sleep`` before raising."""
        if self.delay:
            await asyncio.sleep(self.delay)
        raise self.exc
def test_init_validation():
    """An empty generator list raises ``ValueError``; a one-element list is accepted."""
    with pytest.raises(ValueError):
        FallbackChatGenerator(chat_generators=[])

    only_generator = _DummySuccessGen(text="A")
    fallback = FallbackChatGenerator(chat_generators=[only_generator])
    assert len(fallback.chat_generators) == 1
def test_sequential_first_success():
    """A single succeeding generator is reported as index 0 with exactly one attempt."""
    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    result = fallback.run([ChatMessage.from_user("hi")])

    meta = result["meta"]
    assert result["replies"][0].text == "A"
    assert meta["successful_chat_generator_index"] == 0
    assert meta["total_attempts"] == 1
def test_sequential_second_success_after_failure():
    """When the first generator raises, the second one's reply is returned and the failure is recorded."""
    failing = _DummyFailGen()
    succeeding = _DummySuccessGen(text="B")
    fallback = FallbackChatGenerator(chat_generators=[failing, succeeding])

    result = fallback.run([ChatMessage.from_user("hi")])

    meta = result["meta"]
    assert result["replies"][0].text == "B"
    assert meta["successful_chat_generator_index"] == 1
    # Only non-emptiness is checked here, not the exact contents.
    assert meta["failed_chat_generators"]
def test_all_fail_raises():
    """If every generator raises, ``run`` surfaces a ``RuntimeError``."""
    always_failing = [_DummyFailGen(), _DummyFailGen()]
    fallback = FallbackChatGenerator(chat_generators=always_failing)
    user_messages = [ChatMessage.from_user("hi")]
    with pytest.raises(RuntimeError):
        fallback.run(user_messages)
def test_timeout_handling_sync():
    """A slightly delayed first generator still wins: its reply is returned, not the second one's."""
    delayed_first = _DummySuccessGen(text="slow", delay=0.01)
    instant_second = _DummySuccessGen(text="fast", delay=0.0)
    fallback = FallbackChatGenerator(chat_generators=[delayed_first, instant_second])

    result = fallback.run([ChatMessage.from_user("hi")])

    first_reply = result["replies"][0]
    assert first_reply.text == "slow"
@pytest.mark.asyncio
async def test_timeout_handling_async():
    """Async variant: a slightly delayed first generator still provides the returned reply."""
    delayed_first = _DummySuccessGen(text="slow", delay=0.01)
    instant_second = _DummySuccessGen(text="fast", delay=0.0)
    fallback = FallbackChatGenerator(chat_generators=[delayed_first, instant_second])

    result = await fallback.run_async([ChatMessage.from_user("hi")])

    first_reply = result["replies"][0]
    assert first_reply.text == "slow"
def test_streaming_callback_forwarding_sync():
    """A streaming callback passed to ``run`` reaches the wrapped generator and gets invoked."""
    received_chunks: list[Any] = []

    def record_chunk(chunk: Any) -> None:
        received_chunks.append(chunk)

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    fallback.run([ChatMessage.from_user("hi")], streaming_callback=record_chunk)

    assert received_chunks
@pytest.mark.asyncio
async def test_streaming_callback_forwarding_async():
    """A streaming callback passed to ``run_async`` reaches the wrapped generator and gets invoked."""
    received_chunks: list[Any] = []

    def record_chunk(chunk: Any) -> None:
        received_chunks.append(chunk)

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    await fallback.run_async([ChatMessage.from_user("hi")], streaming_callback=record_chunk)

    assert received_chunks
def test_serialization_roundtrip():
    """``to_dict`` / ``from_dict`` round-trips keep the generator count and still produce the first reply."""
    # One scenario with a single wrapped generator, one with two; both must answer "hello".
    scenarios = (["hello"], ["hello", "world"])

    for reply_texts in scenarios:
        original = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text=text) for text in reply_texts])

        serialized = original.to_dict()
        restored = FallbackChatGenerator.from_dict(serialized)

        assert isinstance(restored, FallbackChatGenerator)
        assert len(restored.chat_generators) == len(reply_texts)

        result = restored.run([ChatMessage.from_user("hi")])
        assert result["replies"][0].text == "hello"
def test_automatic_completion_mode_without_streaming():
    """Without a streaming callback, the single generator's reply is returned and reported as index 0."""
    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="completion")])

    result = fallback.run([ChatMessage.from_user("hi")])

    first_reply = result["replies"][0]
    assert first_reply.text == "completion"
    assert result["meta"]["successful_chat_generator_index"] == 0
def test_automatic_ttft_mode_with_streaming():
    """With a streaming callback, the reply is returned and the callback is invoked at least once."""
    received_chunks: list[Any] = []

    def record_chunk(chunk: Any) -> None:
        received_chunks.append(chunk)

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="streaming")])
    result = fallback.run([ChatMessage.from_user("hi")], streaming_callback=record_chunk)

    first_reply = result["replies"][0]
    assert first_reply.text == "streaming"
    assert received_chunks
@pytest.mark.asyncio
async def test_automatic_ttft_mode_with_streaming_async():
    """Async variant: the reply is returned and the streaming callback is invoked at least once."""
    received_chunks: list[Any] = []

    def record_chunk(chunk: Any) -> None:
        received_chunks.append(chunk)

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="streaming_async")])
    result = await fallback.run_async([ChatMessage.from_user("hi")], streaming_callback=record_chunk)

    first_reply = result["replies"][0]
    assert first_reply.text == "streaming_async"
    assert received_chunks
def create_http_error(status_code: int, message: str) -> URLLibHTTPError:
    """
    Build a urllib ``HTTPError`` carrying the given status code and message.

    The URL is an empty string, the headers are an empty dict, and no response body is attached.
    """
    return URLLibHTTPError(url="", code=status_code, msg=message, hdrs={}, fp=None)
@component
class _DummyHTTPErrorGen:
    """
    Chat generator stub that raises a preconfigured error, or replies with ``text`` when none is set.

    Only a synchronous ``run`` is defined on this stub.
    """

    def __init__(self, text: str = "success", error: Optional[Exception] = None):
        self.text = text
        self.error = error

    def to_dict(self) -> dict[str, Any]:
        """Serialize the stub; a configured error is stored as its string form, otherwise ``None``."""
        serialized_error = None if self.error is None else str(self.error)
        return default_to_dict(self, text=self.text, error=serialized_error)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "_DummyHTTPErrorGen":
        """Rebuild the stub; a non-empty serialized error string comes back as a ``RuntimeError``."""
        init = data.get("init_parameters", {})
        error_text = init.get("error")
        error = RuntimeError(error_text) if error_text else None
        return cls(text=init.get("text", "success"), error=error)

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        """Raise the configured error if there is one; otherwise return ``text`` as the assistant reply."""
        # Explicit None check: do not rely on the truthiness of an exception object.
        if self.error is not None:
            raise self.error
        # Past the guard no error is configured, so "error_type" can only ever be None.
        return {"replies": [ChatMessage.from_assistant(self.text)], "meta": {"error_type": None}}
def test_failover_trigger_429_rate_limit():
    """An HTTP 429 error from the first generator makes the fallback use the second one."""
    http_429 = create_http_error(429, "Rate limit exceeded")
    failing_gen = _DummyHTTPErrorGen(text="rate_limited", error=http_429)
    backup_gen = _DummySuccessGen(text="success_after_rate_limit")
    fallback = FallbackChatGenerator(chat_generators=[failing_gen, backup_gen])

    result = fallback.run([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_rate_limit"
    assert meta["successful_chat_generator_index"] == 1
    assert meta["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
def test_failover_trigger_401_authentication():
    """An HTTP 401 error from the first generator makes the fallback use the second one."""
    http_401 = create_http_error(401, "Authentication failed")
    failing_gen = _DummyHTTPErrorGen(text="auth_failed", error=http_401)
    backup_gen = _DummySuccessGen(text="success_after_auth")
    fallback = FallbackChatGenerator(chat_generators=[failing_gen, backup_gen])

    result = fallback.run([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_auth"
    assert meta["successful_chat_generator_index"] == 1
    assert meta["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
def test_failover_trigger_400_bad_request():
    """An HTTP 400 error from the first generator makes the fallback use the second one."""
    http_400 = create_http_error(400, "Context length exceeded")
    failing_gen = _DummyHTTPErrorGen(text="bad_request", error=http_400)
    backup_gen = _DummySuccessGen(text="success_after_bad_request")
    fallback = FallbackChatGenerator(chat_generators=[failing_gen, backup_gen])

    result = fallback.run([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_bad_request"
    assert meta["successful_chat_generator_index"] == 1
    assert meta["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
def test_failover_trigger_500_server_error():
    """An HTTP 500 error from the first generator makes the fallback use the second one."""
    http_500 = create_http_error(500, "Internal server error")
    failing_gen = _DummyHTTPErrorGen(text="server_error", error=http_500)
    backup_gen = _DummySuccessGen(text="success_after_server_error")
    fallback = FallbackChatGenerator(chat_generators=[failing_gen, backup_gen])

    result = fallback.run([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_server_error"
    assert meta["successful_chat_generator_index"] == 1
    assert meta["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
def test_failover_trigger_multiple_errors():
    """Three consecutive HTTP failures (429, 401, 500) are skipped until the fourth generator succeeds."""
    # (reply text, HTTP status, error message) for each failing generator, in fallback order.
    error_specs = [
        ("rate_limited", 429, "Rate limit exceeded"),
        ("auth_failed", 401, "Authentication failed"),
        ("server_error", 500, "Internal server error"),
    ]
    failing_gens = [
        _DummyHTTPErrorGen(text=reply_text, error=create_http_error(status, reason))
        for reply_text, status, reason in error_specs
    ]
    backup_gen = _DummySuccessGen(text="success_after_all_errors")
    fallback = FallbackChatGenerator(chat_generators=[*failing_gens, backup_gen])

    result = fallback.run([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_all_errors"
    assert meta["successful_chat_generator_index"] == 3
    assert len(meta["failed_chat_generators"]) == 3
def test_failover_trigger_all_generators_fail():
    """When all three generators raise HTTP errors, a ``RuntimeError`` naming every failed generator is raised."""
    # (reply text, HTTP status, error message) for each failing generator, in fallback order.
    error_specs = [
        ("rate_limited", 429, "Rate limit exceeded"),
        ("auth_failed", 401, "Authentication failed"),
        ("server_error", 500, "Internal server error"),
    ]
    failing_gens = [
        _DummyHTTPErrorGen(text=reply_text, error=create_http_error(status, reason))
        for reply_text, status, reason in error_specs
    ]
    fallback = FallbackChatGenerator(chat_generators=failing_gens)

    with pytest.raises(RuntimeError) as exc_info:
        fallback.run([ChatMessage.from_user("test")])

    # Inspect the message only after the context manager has captured the exception.
    raised_text = str(exc_info.value)
    assert "All 3 chat generators failed" in raised_text
    assert "Failed chat generators: [_DummyHTTPErrorGen, _DummyHTTPErrorGen, _DummyHTTPErrorGen]" in raised_text
@pytest.mark.asyncio
async def test_failover_trigger_429_rate_limit_async():
    """Async variant: an HTTP 429 error from the first generator makes the fallback use the second one."""
    http_429 = create_http_error(429, "Rate limit exceeded")
    failing_gen = _DummyHTTPErrorGen(text="rate_limited", error=http_429)
    backup_gen = _DummySuccessGen(text="success_after_rate_limit")
    fallback = FallbackChatGenerator(chat_generators=[failing_gen, backup_gen])

    result = await fallback.run_async([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_rate_limit"
    assert meta["successful_chat_generator_index"] == 1
    assert meta["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
@pytest.mark.asyncio
async def test_failover_trigger_401_authentication_async():
    """Async variant: an HTTP 401 error from the first generator makes the fallback use the second one."""
    http_401 = create_http_error(401, "Authentication failed")
    failing_gen = _DummyHTTPErrorGen(text="auth_failed", error=http_401)
    backup_gen = _DummySuccessGen(text="success_after_auth")
    fallback = FallbackChatGenerator(chat_generators=[failing_gen, backup_gen])

    result = await fallback.run_async([ChatMessage.from_user("test")])

    meta = result["meta"]
    assert result["replies"][0].text == "success_after_auth"
    assert meta["successful_chat_generator_index"] == 1
    assert meta["failed_chat_generators"] == ["_DummyHTTPErrorGen"]