diff --git a/docs/pydoc/config/generators_api.yml b/docs/pydoc/config/generators_api.yml index fa1f638e9..a08743a2d 100644 --- a/docs/pydoc/config/generators_api.yml +++ b/docs/pydoc/config/generators_api.yml @@ -12,6 +12,7 @@ loaders: "chat/hugging_face_local", "chat/hugging_face_api", "chat/openai", + "chat/fallback", ] ignore_when_discovered: ["__init__"] processors: diff --git a/docs/pydoc/config_docusaurus/generators_api.yml b/docs/pydoc/config_docusaurus/generators_api.yml index 43e91c770..d7e51ae88 100644 --- a/docs/pydoc/config_docusaurus/generators_api.yml +++ b/docs/pydoc/config_docusaurus/generators_api.yml @@ -12,6 +12,7 @@ loaders: "chat/hugging_face_local", "chat/hugging_face_api", "chat/openai", + "chat/fallback", ] ignore_when_discovered: ["__init__"] processors: diff --git a/haystack/components/generators/chat/__init__.py b/haystack/components/generators/chat/__init__.py index 0f31584c5..c98bc263c 100644 --- a/haystack/components/generators/chat/__init__.py +++ b/haystack/components/generators/chat/__init__.py @@ -12,10 +12,12 @@ _import_structure = { "azure": ["AzureOpenAIChatGenerator"], "hugging_face_local": ["HuggingFaceLocalChatGenerator"], "hugging_face_api": ["HuggingFaceAPIChatGenerator"], + "fallback": ["FallbackChatGenerator"], } if TYPE_CHECKING: from .azure import AzureOpenAIChatGenerator as AzureOpenAIChatGenerator + from .fallback import FallbackChatGenerator as FallbackChatGenerator from .hugging_face_api import HuggingFaceAPIChatGenerator as HuggingFaceAPIChatGenerator from .hugging_face_local import HuggingFaceLocalChatGenerator as HuggingFaceLocalChatGenerator from .openai import OpenAIChatGenerator as OpenAIChatGenerator diff --git a/haystack/components/generators/chat/fallback.py b/haystack/components/generators/chat/fallback.py new file mode 100644 index 000000000..71ef9bef9 --- /dev/null +++ b/haystack/components/generators/chat/fallback.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +from typing import Any, Union + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.components.generators.chat.types import ChatGenerator +from haystack.dataclasses import ChatMessage, StreamingCallbackT +from haystack.tools import Tool, Toolset +from haystack.utils.deserialization import deserialize_component_inplace + +logger = logging.getLogger(__name__) + + +@component +class FallbackChatGenerator: + """ + A chat generator wrapper that tries multiple chat generators sequentially. + + It forwards all parameters transparently to the underlying chat generators and returns the first successful result. + Calls chat generators sequentially until one succeeds. Falls back on any exception raised by a generator. + If all chat generators fail, it raises a RuntimeError with details. + + Timeout enforcement is fully delegated to the underlying chat generators. The fallback mechanism will only + work correctly if the underlying chat generators implement proper timeout handling and raise exceptions + when timeouts occur. For predictable latency guarantees, ensure your chat generators: + - Support a `timeout` parameter in their initialization + - Implement timeout as total wall-clock time (shared deadline for both streaming and non-streaming) + - Raise timeout exceptions (e.g., TimeoutError, asyncio.TimeoutError, httpx.TimeoutException) when exceeded + + Note: Most well-implemented chat generators (OpenAI, Anthropic, Cohere, etc.) support timeout parameters + with consistent semantics. For HTTP-based LLM providers, a single timeout value (e.g., `timeout=30`) + typically applies to all connection phases: connection setup, read, write, and pool. For streaming + responses, read timeout is the maximum gap between chunks. For non-streaming, it's the time limit for + receiving the complete response. + + Failover is automatically triggered when a generator raises any exception, including: + - Timeout errors (if the generator implements and raises them) + - Rate limit errors (429) + - Authentication errors (401) + - Context length errors (400) + - Server errors (500+) + - Any other exception + """ + + def __init__(self, chat_generators: list[ChatGenerator]): + """ + Creates an instance of FallbackChatGenerator. + + :param chat_generators: A non-empty list of chat generator components to try in order. + """ + if not chat_generators: + msg = "'chat_generators' must be a non-empty list" + raise ValueError(msg) + + self.chat_generators = list(chat_generators) + + def to_dict(self) -> dict[str, Any]: + """Serialize the component, including nested chat generators when they support serialization.""" + return default_to_dict( + self, chat_generators=[gen.to_dict() for gen in self.chat_generators if hasattr(gen, "to_dict")] + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> FallbackChatGenerator: + """Rebuild the component from a serialized representation, restoring nested chat generators.""" + # Reconstruct nested chat generators from their serialized dicts + init_params = data.get("init_parameters", {}) + serialized = init_params.get("chat_generators") or [] + deserialized: list[Any] = [] + for g in serialized: + # Use the generic component deserializer available in Haystack + holder = {"component": g} + deserialize_component_inplace(holder, key="component") + deserialized.append(holder["component"]) + init_params["chat_generators"] = deserialized + data["init_parameters"] = init_params + return default_from_dict(cls, data) + + def _run_single_sync( # pylint: disable=too-many-positional-arguments + self, + gen: Any, + messages: list[ChatMessage], + generation_kwargs: Union[dict[str, Any], None], + tools: Union[list[Tool], Toolset, None], + streaming_callback: Union[StreamingCallbackT, None], + ) -> dict[str, Any]: + return gen.run( + messages=messages, generation_kwargs=generation_kwargs, tools=tools, streaming_callback=streaming_callback + ) + + async def _run_single_async( # pylint: disable=too-many-positional-arguments + self, + gen: Any, + messages: list[ChatMessage], + generation_kwargs: Union[dict[str, Any], None], + tools: Union[list[Tool], Toolset, None], + streaming_callback: Union[StreamingCallbackT, None], + ) -> dict[str, Any]: + if hasattr(gen, "run_async") and callable(gen.run_async): + return await gen.run_async( + messages=messages, + generation_kwargs=generation_kwargs, + tools=tools, + streaming_callback=streaming_callback, + ) + return await asyncio.to_thread( + gen.run, + messages=messages, + generation_kwargs=generation_kwargs, + tools=tools, + streaming_callback=streaming_callback, + ) + + @component.output_types(replies=list[ChatMessage], meta=dict[str, Any]) + def run( + self, + messages: list[ChatMessage], + generation_kwargs: Union[dict[str, Any], None] = None, + tools: Union[list[Tool], Toolset, None] = None, + streaming_callback: Union[StreamingCallbackT, None] = None, + ) -> dict[str, Any]: + """ + Execute chat generators sequentially until one succeeds. + + :param messages: The conversation history as a list of ChatMessage instances. + :param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens). + :param tools: Optional Tool instances or Toolset for function calling capabilities. + :param streaming_callback: Optional callable for handling streaming responses. + :returns: A dictionary with: + - "replies": Generated ChatMessage instances from the first successful generator. + - "meta": Execution metadata including successful_chat_generator_index, successful_chat_generator_class, + total_attempts, failed_chat_generators, plus any metadata from the successful generator. + :raises RuntimeError: If all chat generators fail. + """ + failed: list[str] = [] + last_error: Union[BaseException, None] = None + + for idx, gen in enumerate(self.chat_generators): + gen_name = gen.__class__.__name__ + try: + result = self._run_single_sync(gen, messages, generation_kwargs, tools, streaming_callback) + replies = result.get("replies", []) + meta = dict(result.get("meta", {})) + meta.update( + { + "successful_chat_generator_index": idx, + "successful_chat_generator_class": gen_name, + "total_attempts": idx + 1, + "failed_chat_generators": failed, + } + ) + return {"replies": replies, "meta": meta} + except Exception as e: # noqa: BLE001 - fallback logic should handle any exception + logger.warning( + "ChatGenerator {chat_generator} failed with error: {error}", chat_generator=gen_name, error=e + ) + failed.append(gen_name) + last_error = e + + failed_names = ", ".join(failed) + msg = ( + f"All {len(self.chat_generators)} chat generators failed. " + f"Last error: {last_error}. Failed chat generators: [{failed_names}]" + ) + raise RuntimeError(msg) + + @component.output_types(replies=list[ChatMessage], meta=dict[str, Any]) + async def run_async( + self, + messages: list[ChatMessage], + generation_kwargs: Union[dict[str, Any], None] = None, + tools: Union[list[Tool], Toolset, None] = None, + streaming_callback: Union[StreamingCallbackT, None] = None, + ) -> dict[str, Any]: + """ + Asynchronously execute chat generators sequentially until one succeeds. + + :param messages: The conversation history as a list of ChatMessage instances. + :param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens). + :param tools: Optional Tool instances or Toolset for function calling capabilities. + :param streaming_callback: Optional callable for handling streaming responses. + :returns: A dictionary with: + - "replies": Generated ChatMessage instances from the first successful generator. + - "meta": Execution metadata including successful_chat_generator_index, successful_chat_generator_class, + total_attempts, failed_chat_generators, plus any metadata from the successful generator. + :raises RuntimeError: If all chat generators fail. + """ + failed: list[str] = [] + last_error: Union[BaseException, None] = None + + for idx, gen in enumerate(self.chat_generators): + gen_name = gen.__class__.__name__ + try: + result = await self._run_single_async(gen, messages, generation_kwargs, tools, streaming_callback) + replies = result.get("replies", []) + meta = dict(result.get("meta", {})) + meta.update( + { + "successful_chat_generator_index": idx, + "successful_chat_generator_class": gen_name, + "total_attempts": idx + 1, + "failed_chat_generators": failed, + } + ) + return {"replies": replies, "meta": meta} + except Exception as e: # noqa: BLE001 - fallback logic should handle any exception + logger.warning( + "ChatGenerator {chat_generator} failed with error: {error}", chat_generator=gen_name, error=e + ) + failed.append(gen_name) + last_error = e + + failed_names = ", ".join(failed) + msg = ( + f"All {len(self.chat_generators)} chat generators failed. " + f"Last error: {last_error}. Failed chat generators: [{failed_names}]" + ) + raise RuntimeError(msg) diff --git a/releasenotes/notes/add-fallback-chat-generator-ffe557ca01fcdaca.yaml b/releasenotes/notes/add-fallback-chat-generator-ffe557ca01fcdaca.yaml new file mode 100644 index 000000000..6a83ce556 --- /dev/null +++ b/releasenotes/notes/add-fallback-chat-generator-ffe557ca01fcdaca.yaml @@ -0,0 +1,6 @@ +--- +highlights: > + Introduced `FallbackChatGenerator` that tries multiple chat providers one by one, improving reliability in production and making sure you get answers even when some provider fails. +features: + - | + Added `FallbackChatGenerator` that automatically retries different chat generators and returns first successful response with detailed information about which providers were tried. diff --git a/test/components/generators/chat/test_fallback.py b/test/components/generators/chat/test_fallback.py new file mode 100644 index 000000000..6e12b9d0e --- /dev/null +++ b/test/components/generators/chat/test_fallback.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import time +from typing import Any, Optional, Union +from urllib.error import HTTPError as URLLibHTTPError + +import pytest + +from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.chat.fallback import FallbackChatGenerator +from haystack.dataclasses import ChatMessage, StreamingCallbackT +from haystack.tools import Tool, Toolset + + +@component +class _DummySuccessGen: + def __init__(self, text: str = "ok", delay: float = 0.0, streaming_callback: Optional[StreamingCallbackT] = None): + self.text = text + self.delay = delay + self.streaming_callback = streaming_callback + + def to_dict(self) -> dict[str, Any]: + return default_to_dict(self, text=self.text, delay=self.delay, streaming_callback=None) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_DummySuccessGen": + return default_from_dict(cls, data) + + def run( + self, + messages: list[ChatMessage], + generation_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[list[Tool], Toolset]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, + ) -> dict[str, Any]: + if self.delay: + time.sleep(self.delay) + if streaming_callback: + streaming_callback({"dummy": True}) # type: ignore[arg-type] + return {"replies": [ChatMessage.from_assistant(self.text)], "meta": {"dummy_meta": True}} + + async def run_async( + self, + messages: list[ChatMessage], + generation_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[list[Tool], Toolset]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, + ) -> dict[str, Any]: + if self.delay: + await asyncio.sleep(self.delay) + if streaming_callback: + await asyncio.sleep(0) + streaming_callback({"dummy": True}) # type: ignore[arg-type] + return {"replies": [ChatMessage.from_assistant(self.text)], "meta": {"dummy_meta": True}} + + +@component +class _DummyFailGen: + def __init__(self, exc: Optional[Exception] = None, delay: float = 0.0): + self.exc = exc or RuntimeError("boom") + self.delay = delay + + def to_dict(self) -> dict[str, Any]: + return default_to_dict(self, exc={"message": str(self.exc)}, delay=self.delay) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_DummyFailGen": + init = data.get("init_parameters", {}) + msg = None + if isinstance(init.get("exc"), dict): + msg = init.get("exc", {}).get("message") + return cls(exc=RuntimeError(msg or "boom"), delay=init.get("delay", 0.0)) + + def run( + self, + messages: list[ChatMessage], + generation_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[list[Tool], Toolset]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, + ) -> dict[str, Any]: + if self.delay: + time.sleep(self.delay) + raise self.exc + + async def run_async( + self, + messages: list[ChatMessage], + generation_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[list[Tool], Toolset]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, + ) -> dict[str, Any]: + if self.delay: + await asyncio.sleep(self.delay) + raise self.exc + + +def test_init_validation(): + with pytest.raises(ValueError): + FallbackChatGenerator(chat_generators=[]) + + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")]) + assert len(gen.chat_generators) == 1 + + +def test_sequential_first_success(): + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")]) + res = gen.run([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "A" + assert res["meta"]["successful_chat_generator_index"] == 0 + assert res["meta"]["total_attempts"] == 1 + + +def test_sequential_second_success_after_failure(): + gen = FallbackChatGenerator(chat_generators=[_DummyFailGen(), _DummySuccessGen(text="B")]) + res = gen.run([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "B" + assert res["meta"]["successful_chat_generator_index"] == 1 + assert res["meta"]["failed_chat_generators"] + + +def test_all_fail_raises(): + gen = FallbackChatGenerator(chat_generators=[_DummyFailGen(), _DummyFailGen()]) + with pytest.raises(RuntimeError): + gen.run([ChatMessage.from_user("hi")]) + + +def test_timeout_handling_sync(): + slow = _DummySuccessGen(text="slow", delay=0.01) + fast = _DummySuccessGen(text="fast", delay=0.0) + gen = FallbackChatGenerator(chat_generators=[slow, fast]) + res = gen.run([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "slow" + + +@pytest.mark.asyncio +async def test_timeout_handling_async(): + slow = _DummySuccessGen(text="slow", delay=0.01) + fast = _DummySuccessGen(text="fast", delay=0.0) + gen = FallbackChatGenerator(chat_generators=[slow, fast]) + res = await gen.run_async([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "slow" + + +def test_streaming_callback_forwarding_sync(): + calls: list[Any] = [] + + def cb(x: Any) -> None: + calls.append(x) + + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")]) + _ = gen.run([ChatMessage.from_user("hi")], streaming_callback=cb) + assert calls + + +@pytest.mark.asyncio +async def test_streaming_callback_forwarding_async(): + calls: list[Any] = [] + + def cb(x: Any) -> None: + calls.append(x) + + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")]) + _ = await gen.run_async([ChatMessage.from_user("hi")], streaming_callback=cb) + assert calls + + +def test_serialization_roundtrip(): + original = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="hello")]) + data = original.to_dict() + restored = FallbackChatGenerator.from_dict(data) + assert isinstance(restored, FallbackChatGenerator) + assert len(restored.chat_generators) == 1 + res = restored.run([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "hello" + + original = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="hello"), _DummySuccessGen(text="world")]) + data = original.to_dict() + restored = FallbackChatGenerator.from_dict(data) + assert isinstance(restored, FallbackChatGenerator) + assert len(restored.chat_generators) == 2 + res = restored.run([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "hello" + + +def test_automatic_completion_mode_without_streaming(): + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="completion")]) + res = gen.run([ChatMessage.from_user("hi")]) + assert res["replies"][0].text == "completion" + assert res["meta"]["successful_chat_generator_index"] == 0 + + +def test_automatic_ttft_mode_with_streaming(): + calls: list[Any] = [] + + def cb(x: Any) -> None: + calls.append(x) + + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="streaming")]) + res = gen.run([ChatMessage.from_user("hi")], streaming_callback=cb) + assert res["replies"][0].text == "streaming" + assert calls + + +@pytest.mark.asyncio +async def test_automatic_ttft_mode_with_streaming_async(): + calls: list[Any] = [] + + def cb(x: Any) -> None: + calls.append(x) + + gen = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="streaming_async")]) + res = await gen.run_async([ChatMessage.from_user("hi")], streaming_callback=cb) + assert res["replies"][0].text == "streaming_async" + assert calls + + +def create_http_error(status_code: int, message: str) -> URLLibHTTPError: + return URLLibHTTPError("", status_code, message, {}, None) + + +@component +class _DummyHTTPErrorGen: + def __init__(self, text: str = "success", error: Optional[Exception] = None): + self.text = text + self.error = error + + def to_dict(self) -> dict[str, Any]: + return default_to_dict(self, text=self.text, error=str(self.error) if self.error else None) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_DummyHTTPErrorGen": + init = data.get("init_parameters", {}) + error = None + if init.get("error"): + error = RuntimeError(init["error"]) + return cls(text=init.get("text", "success"), error=error) + + def run( + self, + messages: list[ChatMessage], + generation_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[list[Tool], Toolset]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, + ) -> dict[str, Any]: + if self.error: + raise self.error + return { + "replies": [ChatMessage.from_assistant(self.text)], + "meta": {"error_type": type(self.error).__name__ if self.error else None}, + } + + +def test_failover_trigger_429_rate_limit(): + rate_limit_gen = _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded")) + success_gen = _DummySuccessGen(text="success_after_rate_limit") + + fallback = FallbackChatGenerator(chat_generators=[rate_limit_gen, success_gen]) + result = fallback.run([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_rate_limit" + assert result["meta"]["successful_chat_generator_index"] == 1 + assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"] + + +def test_failover_trigger_401_authentication(): + auth_error_gen = _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed")) + success_gen = _DummySuccessGen(text="success_after_auth") + + fallback = FallbackChatGenerator(chat_generators=[auth_error_gen, success_gen]) + result = fallback.run([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_auth" + assert result["meta"]["successful_chat_generator_index"] == 1 + assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"] + + +def test_failover_trigger_400_bad_request(): + bad_request_gen = _DummyHTTPErrorGen(text="bad_request", error=create_http_error(400, "Context length exceeded")) + success_gen = _DummySuccessGen(text="success_after_bad_request") + + fallback = FallbackChatGenerator(chat_generators=[bad_request_gen, success_gen]) + result = fallback.run([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_bad_request" + assert result["meta"]["successful_chat_generator_index"] == 1 + assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"] + + +def test_failover_trigger_500_server_error(): + server_error_gen = _DummyHTTPErrorGen(text="server_error", error=create_http_error(500, "Internal server error")) + success_gen = _DummySuccessGen(text="success_after_server_error") + + fallback = FallbackChatGenerator(chat_generators=[server_error_gen, success_gen]) + result = fallback.run([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_server_error" + assert result["meta"]["successful_chat_generator_index"] == 1 + assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"] + + +def test_failover_trigger_multiple_errors(): + rate_limit_gen = _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded")) + auth_error_gen = _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed")) + server_error_gen = _DummyHTTPErrorGen(text="server_error", error=create_http_error(500, "Internal server error")) + success_gen = _DummySuccessGen(text="success_after_all_errors") + + fallback = FallbackChatGenerator(chat_generators=[rate_limit_gen, auth_error_gen, server_error_gen, success_gen]) + result = fallback.run([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_all_errors" + assert result["meta"]["successful_chat_generator_index"] == 3 + assert len(result["meta"]["failed_chat_generators"]) == 3 + + +def test_failover_trigger_all_generators_fail(): + rate_limit_gen = _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded")) + auth_error_gen = _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed")) + server_error_gen = _DummyHTTPErrorGen(text="server_error", error=create_http_error(500, "Internal server error")) + + fallback = FallbackChatGenerator(chat_generators=[rate_limit_gen, auth_error_gen, server_error_gen]) + + with pytest.raises(RuntimeError) as exc_info: + fallback.run([ChatMessage.from_user("test")]) + + error_msg = str(exc_info.value) + assert "All 3 chat generators failed" in error_msg + assert "Failed chat generators: [_DummyHTTPErrorGen, _DummyHTTPErrorGen, _DummyHTTPErrorGen]" in error_msg + + +@pytest.mark.asyncio +async def test_failover_trigger_429_rate_limit_async(): + rate_limit_gen = _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded")) + success_gen = _DummySuccessGen(text="success_after_rate_limit") + + fallback = FallbackChatGenerator(chat_generators=[rate_limit_gen, success_gen]) + result = await fallback.run_async([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_rate_limit" + assert result["meta"]["successful_chat_generator_index"] == 1 + assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"] + + +@pytest.mark.asyncio +async def test_failover_trigger_401_authentication_async(): + auth_error_gen = _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed")) + success_gen = _DummySuccessGen(text="success_after_auth") + + fallback = FallbackChatGenerator(chat_generators=[auth_error_gen, success_gen]) + result = await fallback.run_async([ChatMessage.from_user("test")]) + + assert result["replies"][0].text == "success_after_auth" + assert result["meta"]["successful_chat_generator_index"] == 1 + assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]