mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-30 16:47:19 +00:00
feat: Add FallbackChatGenerator (#9859)
* Add FallbackChatGenerator * Update licence files * Use typing.Optional/Union for Python 3.9 compat * Use the right logger * Lint fix * PR review * Rewrite release note * Add FallbackChatGenerator to docs * Update haystack/components/generators/chat/fallback.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Rename generator -> chat_generators * Lint * Rename generators -> chat_generators in meta, docs, tests * Update haystack/components/generators/chat/fallback.py Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> * Update pydocs * Minor pydocs fix --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
This commit is contained in:
parent
a43c47b635
commit
90edcdacee
@ -12,6 +12,7 @@ loaders:
|
||||
"chat/hugging_face_local",
|
||||
"chat/hugging_face_api",
|
||||
"chat/openai",
|
||||
"chat/fallback",
|
||||
]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
|
||||
@ -12,6 +12,7 @@ loaders:
|
||||
"chat/hugging_face_local",
|
||||
"chat/hugging_face_api",
|
||||
"chat/openai",
|
||||
"chat/fallback",
|
||||
]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
|
||||
@ -12,10 +12,12 @@ _import_structure = {
|
||||
"azure": ["AzureOpenAIChatGenerator"],
|
||||
"hugging_face_local": ["HuggingFaceLocalChatGenerator"],
|
||||
"hugging_face_api": ["HuggingFaceAPIChatGenerator"],
|
||||
"fallback": ["FallbackChatGenerator"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .azure import AzureOpenAIChatGenerator as AzureOpenAIChatGenerator
|
||||
from .fallback import FallbackChatGenerator as FallbackChatGenerator
|
||||
from .hugging_face_api import HuggingFaceAPIChatGenerator as HuggingFaceAPIChatGenerator
|
||||
from .hugging_face_local import HuggingFaceLocalChatGenerator as HuggingFaceLocalChatGenerator
|
||||
from .openai import OpenAIChatGenerator as OpenAIChatGenerator
|
||||
|
||||
223
haystack/components/generators/chat/fallback.py
Normal file
223
haystack/components/generators/chat/fallback.py
Normal file
@ -0,0 +1,223 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.components.generators.chat.types import ChatGenerator
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT
|
||||
from haystack.tools import Tool, Toolset
|
||||
from haystack.utils.deserialization import deserialize_component_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
class FallbackChatGenerator:
    """
    A chat generator wrapper that tries multiple chat generators sequentially.

    It forwards all parameters transparently to the underlying chat generators and returns the first successful result.
    Calls chat generators sequentially until one succeeds. Falls back on any exception raised by a generator.
    If all chat generators fail, it raises a RuntimeError with details, chained to the last underlying exception.

    Timeout enforcement is fully delegated to the underlying chat generators. The fallback mechanism will only
    work correctly if the underlying chat generators implement proper timeout handling and raise exceptions
    when timeouts occur. For predictable latency guarantees, ensure your chat generators:
    - Support a `timeout` parameter in their initialization
    - Implement timeout as total wall-clock time (shared deadline for both streaming and non-streaming)
    - Raise timeout exceptions (e.g., TimeoutError, asyncio.TimeoutError, httpx.TimeoutException) when exceeded

    Note: Most well-implemented chat generators (OpenAI, Anthropic, Cohere, etc.) support timeout parameters
    with consistent semantics. For HTTP-based LLM providers, a single timeout value (e.g., `timeout=30`)
    typically applies to all connection phases: connection setup, read, write, and pool. For streaming
    responses, read timeout is the maximum gap between chunks. For non-streaming, it's the time limit for
    receiving the complete response.

    Failover is automatically triggered when a generator raises any exception, including:
    - Timeout errors (if the generator implements and raises them)
    - Rate limit errors (429)
    - Authentication errors (401)
    - Context length errors (400)
    - Server errors (500+)
    - Any other exception
    """

    def __init__(self, chat_generators: list[ChatGenerator]):
        """
        Creates an instance of FallbackChatGenerator.

        :param chat_generators: A non-empty list of chat generator components to try in order.
        :raises ValueError: If `chat_generators` is empty or None.
        """
        if not chat_generators:
            msg = "'chat_generators' must be a non-empty list"
            raise ValueError(msg)

        # Copy so later mutation of the caller's list cannot change the failover order.
        self.chat_generators = list(chat_generators)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the component, including nested chat generators when they support serialization."""
        # NOTE: generators without a `to_dict` method are silently omitted, so a round-trip
        # may restore fewer nested generators than the original instance held.
        return default_to_dict(
            self, chat_generators=[gen.to_dict() for gen in self.chat_generators if hasattr(gen, "to_dict")]
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FallbackChatGenerator:
        """Rebuild the component from a serialized representation, restoring nested chat generators."""
        # Reconstruct nested chat generators from their serialized dicts
        init_params = data.get("init_parameters", {})
        serialized = init_params.get("chat_generators") or []
        deserialized: list[Any] = []
        for gen_data in serialized:
            # Use the generic component deserializer available in Haystack
            holder = {"component": gen_data}
            deserialize_component_inplace(holder, key="component")
            deserialized.append(holder["component"])
        init_params["chat_generators"] = deserialized
        data["init_parameters"] = init_params
        return default_from_dict(cls, data)

    @staticmethod
    def _success_meta(result_meta: dict[str, Any], idx: int, gen_name: str, failed: list[str]) -> dict[str, Any]:
        """Merge the successful generator's metadata with the fallback bookkeeping fields."""
        meta = dict(result_meta)
        meta.update(
            {
                "successful_chat_generator_index": idx,
                "successful_chat_generator_class": gen_name,
                "total_attempts": idx + 1,
                "failed_chat_generators": failed,
            }
        )
        return meta

    def _run_single_sync(  # pylint: disable=too-many-positional-arguments
        self,
        gen: Any,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None],
        tools: Union[list[Tool], Toolset, None],
        streaming_callback: Union[StreamingCallbackT, None],
    ) -> dict[str, Any]:
        """Invoke one generator synchronously, forwarding all parameters unchanged."""
        return gen.run(
            messages=messages, generation_kwargs=generation_kwargs, tools=tools, streaming_callback=streaming_callback
        )

    async def _run_single_async(  # pylint: disable=too-many-positional-arguments
        self,
        gen: Any,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None],
        tools: Union[list[Tool], Toolset, None],
        streaming_callback: Union[StreamingCallbackT, None],
    ) -> dict[str, Any]:
        """Invoke one generator asynchronously; sync-only generators run in a worker thread."""
        if hasattr(gen, "run_async") and callable(gen.run_async):
            return await gen.run_async(
                messages=messages,
                generation_kwargs=generation_kwargs,
                tools=tools,
                streaming_callback=streaming_callback,
            )
        # The generator only offers a blocking `run`; keep the event loop responsive.
        return await asyncio.to_thread(
            gen.run,
            messages=messages,
            generation_kwargs=generation_kwargs,
            tools=tools,
            streaming_callback=streaming_callback,
        )

    @component.output_types(replies=list[ChatMessage], meta=dict[str, Any])
    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None] = None,
        tools: Union[list[Tool], Toolset, None] = None,
        streaming_callback: Union[StreamingCallbackT, None] = None,
    ) -> dict[str, Any]:
        """
        Execute chat generators sequentially until one succeeds.

        :param messages: The conversation history as a list of ChatMessage instances.
        :param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
        :param tools: Optional Tool instances or Toolset for function calling capabilities.
        :param streaming_callback: Optional callable for handling streaming responses.
        :returns: A dictionary with:
            - "replies": Generated ChatMessage instances from the first successful generator.
            - "meta": Execution metadata including successful_chat_generator_index, successful_chat_generator_class,
              total_attempts, failed_chat_generators, plus any metadata from the successful generator.
        :raises RuntimeError: If all chat generators fail. The last underlying exception is chained as the cause.
        """
        failed: list[str] = []
        last_error: Union[BaseException, None] = None

        for idx, gen in enumerate(self.chat_generators):
            gen_name = gen.__class__.__name__
            try:
                result = self._run_single_sync(gen, messages, generation_kwargs, tools, streaming_callback)
            except Exception as e:  # noqa: BLE001 - fallback logic should handle any exception
                logger.warning(
                    "ChatGenerator {chat_generator} failed with error: {error}", chat_generator=gen_name, error=e
                )
                failed.append(gen_name)
                last_error = e
            else:
                meta = self._success_meta(result.get("meta", {}), idx, gen_name, failed)
                return {"replies": result.get("replies", []), "meta": meta}

        failed_names = ", ".join(failed)
        msg = (
            f"All {len(self.chat_generators)} chat generators failed. "
            f"Last error: {last_error}. Failed chat generators: [{failed_names}]"
        )
        # Chain the last underlying exception so the root cause survives in tracebacks.
        raise RuntimeError(msg) from last_error

    @component.output_types(replies=list[ChatMessage], meta=dict[str, Any])
    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None] = None,
        tools: Union[list[Tool], Toolset, None] = None,
        streaming_callback: Union[StreamingCallbackT, None] = None,
    ) -> dict[str, Any]:
        """
        Asynchronously execute chat generators sequentially until one succeeds.

        :param messages: The conversation history as a list of ChatMessage instances.
        :param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
        :param tools: Optional Tool instances or Toolset for function calling capabilities.
        :param streaming_callback: Optional callable for handling streaming responses.
        :returns: A dictionary with:
            - "replies": Generated ChatMessage instances from the first successful generator.
            - "meta": Execution metadata including successful_chat_generator_index, successful_chat_generator_class,
              total_attempts, failed_chat_generators, plus any metadata from the successful generator.
        :raises RuntimeError: If all chat generators fail. The last underlying exception is chained as the cause.
        """
        failed: list[str] = []
        last_error: Union[BaseException, None] = None

        for idx, gen in enumerate(self.chat_generators):
            gen_name = gen.__class__.__name__
            try:
                result = await self._run_single_async(gen, messages, generation_kwargs, tools, streaming_callback)
            except Exception as e:  # noqa: BLE001 - fallback logic should handle any exception
                logger.warning(
                    "ChatGenerator {chat_generator} failed with error: {error}", chat_generator=gen_name, error=e
                )
                failed.append(gen_name)
                last_error = e
            else:
                meta = self._success_meta(result.get("meta", {}), idx, gen_name, failed)
                return {"replies": result.get("replies", []), "meta": meta}

        failed_names = ", ".join(failed)
        msg = (
            f"All {len(self.chat_generators)} chat generators failed. "
            f"Last error: {last_error}. Failed chat generators: [{failed_names}]"
        )
        # Chain the last underlying exception so the root cause survives in tracebacks.
        raise RuntimeError(msg) from last_error
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
highlights: >
|
||||
Introduced `FallbackChatGenerator` that tries multiple chat providers one by one, improving reliability in production and making sure you get answers even when some provider fails.
|
||||
features:
|
||||
- |
|
||||
Added `FallbackChatGenerator` that automatically retries different chat generators and returns the first successful response, with detailed information about which providers were tried.
|
||||
356
test/components/generators/chat/test_fallback.py
Normal file
356
test/components/generators/chat/test_fallback.py
Normal file
@ -0,0 +1,356 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.error import HTTPError as URLLibHTTPError
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.generators.chat.fallback import FallbackChatGenerator
|
||||
from haystack.dataclasses import ChatMessage, StreamingCallbackT
|
||||
from haystack.tools import Tool, Toolset
|
||||
|
||||
|
||||
@component
class _DummySuccessGen:
    """Test double that always succeeds, optionally after a delay and with one streaming chunk."""

    def __init__(self, text: str = "ok", delay: float = 0.0, streaming_callback: Optional[StreamingCallbackT] = None):
        self.text = text
        self.delay = delay
        self.streaming_callback = streaming_callback

    def to_dict(self) -> dict[str, Any]:
        # The callback is not serializable; always persist it as None.
        return default_to_dict(self, text=self.text, delay=self.delay, streaming_callback=None)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "_DummySuccessGen":
        return default_from_dict(cls, data)

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        if self.delay:
            time.sleep(self.delay)
        if streaming_callback:
            # Emit a single fake chunk so callers can observe the callback firing.
            streaming_callback({"dummy": True})  # type: ignore[arg-type]
        reply = ChatMessage.from_assistant(self.text)
        return {"replies": [reply], "meta": {"dummy_meta": True}}

    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        if self.delay:
            await asyncio.sleep(self.delay)
        if streaming_callback:
            # Yield control once before invoking the callback, as the sync path cannot.
            await asyncio.sleep(0)
            streaming_callback({"dummy": True})  # type: ignore[arg-type]
        reply = ChatMessage.from_assistant(self.text)
        return {"replies": [reply], "meta": {"dummy_meta": True}}
|
||||
|
||||
|
||||
@component
class _DummyFailGen:
    """Test double that always raises a configured exception, optionally after a delay."""

    def __init__(self, exc: Optional[Exception] = None, delay: float = 0.0):
        self.exc = exc or RuntimeError("boom")
        self.delay = delay

    def to_dict(self) -> dict[str, Any]:
        # Exceptions are not serializable; keep only the message text.
        return default_to_dict(self, exc={"message": str(self.exc)}, delay=self.delay)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "_DummyFailGen":
        params = data.get("init_parameters", {})
        serialized_exc = params.get("exc")
        message = serialized_exc.get("message") if isinstance(serialized_exc, dict) else None
        return cls(exc=RuntimeError(message or "boom"), delay=params.get("delay", 0.0))

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        if self.delay:
            time.sleep(self.delay)
        raise self.exc

    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        if self.delay:
            await asyncio.sleep(self.delay)
        raise self.exc
|
||||
|
||||
|
||||
def test_init_validation():
    """An empty generator list is rejected; a one-element list is accepted."""
    with pytest.raises(ValueError):
        FallbackChatGenerator(chat_generators=[])

    single = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    assert len(single.chat_generators) == 1
|
||||
|
||||
|
||||
def test_sequential_first_success():
    """When the first generator succeeds, metadata reports index 0 and one attempt."""
    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    result = fallback.run([ChatMessage.from_user("hi")])
    assert result["replies"][0].text == "A"
    assert result["meta"]["successful_chat_generator_index"] == 0
    assert result["meta"]["total_attempts"] == 1
|
||||
|
||||
|
||||
def test_sequential_second_success_after_failure():
    """A failing first generator triggers failover to the second, which answers."""
    fallback = FallbackChatGenerator(chat_generators=[_DummyFailGen(), _DummySuccessGen(text="B")])
    result = fallback.run([ChatMessage.from_user("hi")])
    assert result["replies"][0].text == "B"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"]
|
||||
|
||||
|
||||
def test_all_fail_raises():
    """If every generator raises, the fallback surfaces a RuntimeError."""
    fallback = FallbackChatGenerator(chat_generators=[_DummyFailGen(), _DummyFailGen()])
    with pytest.raises(RuntimeError):
        fallback.run([ChatMessage.from_user("hi")])
|
||||
|
||||
|
||||
def test_timeout_handling_sync():
    """A slow-but-successful generator still wins: latency alone never triggers failover."""
    generators = [_DummySuccessGen(text="slow", delay=0.01), _DummySuccessGen(text="fast", delay=0.0)]
    fallback = FallbackChatGenerator(chat_generators=generators)
    result = fallback.run([ChatMessage.from_user("hi")])
    assert result["replies"][0].text == "slow"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_timeout_handling_async():
    """Async path: a slow-but-successful generator still wins; no failover on latency."""
    generators = [_DummySuccessGen(text="slow", delay=0.01), _DummySuccessGen(text="fast", delay=0.0)]
    fallback = FallbackChatGenerator(chat_generators=generators)
    result = await fallback.run_async([ChatMessage.from_user("hi")])
    assert result["replies"][0].text == "slow"
|
||||
|
||||
|
||||
def test_streaming_callback_forwarding_sync():
    """The streaming callback is forwarded to the underlying generator and invoked."""
    received: list[Any] = []

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    _ = fallback.run([ChatMessage.from_user("hi")], streaming_callback=received.append)
    assert received
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_callback_forwarding_async():
    """Async path: the streaming callback is forwarded and invoked."""
    received: list[Any] = []

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="A")])
    _ = await fallback.run_async([ChatMessage.from_user("hi")], streaming_callback=received.append)
    assert received
|
||||
|
||||
|
||||
def test_serialization_roundtrip():
    """to_dict/from_dict round-trips single- and multi-generator configurations."""
    for texts in (["hello"], ["hello", "world"]):
        original = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text=t) for t in texts])
        restored = FallbackChatGenerator.from_dict(original.to_dict())
        assert isinstance(restored, FallbackChatGenerator)
        assert len(restored.chat_generators) == len(texts)
        # The first (successful) generator's reply is always returned.
        result = restored.run([ChatMessage.from_user("hi")])
        assert result["replies"][0].text == "hello"
|
||||
|
||||
|
||||
def test_automatic_completion_mode_without_streaming():
    """Without a streaming callback, a plain completion is returned from generator 0."""
    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="completion")])
    result = fallback.run([ChatMessage.from_user("hi")])
    assert result["replies"][0].text == "completion"
    assert result["meta"]["successful_chat_generator_index"] == 0
|
||||
|
||||
|
||||
def test_automatic_ttft_mode_with_streaming():
    """With a streaming callback, chunks are delivered and the reply is still returned."""
    received: list[Any] = []

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="streaming")])
    result = fallback.run([ChatMessage.from_user("hi")], streaming_callback=received.append)
    assert result["replies"][0].text == "streaming"
    assert received
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_automatic_ttft_mode_with_streaming_async():
    """Async path: streaming chunks are delivered and the reply is still returned."""
    received: list[Any] = []

    fallback = FallbackChatGenerator(chat_generators=[_DummySuccessGen(text="streaming_async")])
    result = await fallback.run_async([ChatMessage.from_user("hi")], streaming_callback=received.append)
    assert result["replies"][0].text == "streaming_async"
    assert received
|
||||
|
||||
|
||||
def create_http_error(status_code: int, message: str) -> URLLibHTTPError:
    """Build a urllib HTTPError carrying the given status code and reason message."""
    return URLLibHTTPError(url="", code=status_code, msg=message, hdrs={}, fp=None)
|
||||
|
||||
|
||||
@component
class _DummyHTTPErrorGen:
    """Test double that raises a pre-configured HTTP error, or returns a fixed reply."""

    def __init__(self, text: str = "success", error: Optional[Exception] = None):
        self.text = text
        self.error = error

    def to_dict(self) -> dict[str, Any]:
        serialized_error = str(self.error) if self.error else None
        return default_to_dict(self, text=self.text, error=serialized_error)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "_DummyHTTPErrorGen":
        params = data.get("init_parameters", {})
        message = params.get("error")
        restored_error = RuntimeError(message) if message else None
        return cls(text=params.get("text", "success"), error=restored_error)

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        if self.error:
            raise self.error
        # NOTE(review): self.error is None on this path, so error_type is always None here.
        return {
            "replies": [ChatMessage.from_assistant(self.text)],
            "meta": {"error_type": type(self.error).__name__ if self.error else None},
        }
|
||||
|
||||
|
||||
def test_failover_trigger_429_rate_limit():
    """A 429 from the first generator triggers failover to the healthy one."""
    failing = _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded"))
    healthy = _DummySuccessGen(text="success_after_rate_limit")

    result = FallbackChatGenerator(chat_generators=[failing, healthy]).run([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_rate_limit"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
|
||||
|
||||
|
||||
def test_failover_trigger_401_authentication():
    """A 401 from the first generator triggers failover to the healthy one."""
    failing = _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed"))
    healthy = _DummySuccessGen(text="success_after_auth")

    result = FallbackChatGenerator(chat_generators=[failing, healthy]).run([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_auth"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
|
||||
|
||||
|
||||
def test_failover_trigger_400_bad_request():
    """A 400 (e.g. context length exceeded) triggers failover to the healthy generator."""
    failing = _DummyHTTPErrorGen(text="bad_request", error=create_http_error(400, "Context length exceeded"))
    healthy = _DummySuccessGen(text="success_after_bad_request")

    result = FallbackChatGenerator(chat_generators=[failing, healthy]).run([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_bad_request"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
|
||||
|
||||
|
||||
def test_failover_trigger_500_server_error():
    """A 500 from the first generator triggers failover to the healthy one."""
    failing = _DummyHTTPErrorGen(text="server_error", error=create_http_error(500, "Internal server error"))
    healthy = _DummySuccessGen(text="success_after_server_error")

    result = FallbackChatGenerator(chat_generators=[failing, healthy]).run([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_server_error"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
|
||||
|
||||
|
||||
def test_failover_trigger_multiple_errors():
    """Several failing providers in a row are all skipped before the working one answers."""
    failing = [
        _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded")),
        _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed")),
        _DummyHTTPErrorGen(text="server_error", error=create_http_error(500, "Internal server error")),
    ]
    healthy = _DummySuccessGen(text="success_after_all_errors")

    result = FallbackChatGenerator(chat_generators=[*failing, healthy]).run([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_all_errors"
    assert result["meta"]["successful_chat_generator_index"] == 3
    assert len(result["meta"]["failed_chat_generators"]) == 3
|
||||
|
||||
|
||||
def test_failover_trigger_all_generators_fail():
    """When every generator raises, the RuntimeError message summarizes all failures."""
    failing = [
        _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded")),
        _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed")),
        _DummyHTTPErrorGen(text="server_error", error=create_http_error(500, "Internal server error")),
    ]
    fallback = FallbackChatGenerator(chat_generators=failing)

    with pytest.raises(RuntimeError) as exc_info:
        fallback.run([ChatMessage.from_user("test")])

    message = str(exc_info.value)
    assert "All 3 chat generators failed" in message
    assert "Failed chat generators: [_DummyHTTPErrorGen, _DummyHTTPErrorGen, _DummyHTTPErrorGen]" in message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_failover_trigger_429_rate_limit_async():
    """Async path: a 429 from the first generator triggers failover to the healthy one."""
    failing = _DummyHTTPErrorGen(text="rate_limited", error=create_http_error(429, "Rate limit exceeded"))
    healthy = _DummySuccessGen(text="success_after_rate_limit")

    fallback = FallbackChatGenerator(chat_generators=[failing, healthy])
    result = await fallback.run_async([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_rate_limit"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_failover_trigger_401_authentication_async():
    """Async path: a 401 from the first generator triggers failover to the healthy one."""
    failing = _DummyHTTPErrorGen(text="auth_failed", error=create_http_error(401, "Authentication failed"))
    healthy = _DummySuccessGen(text="success_after_auth")

    fallback = FallbackChatGenerator(chat_generators=[failing, healthy])
    result = await fallback.run_async([ChatMessage.from_user("test")])

    assert result["replies"][0].text == "success_after_auth"
    assert result["meta"]["successful_chat_generator_index"] == 1
    assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
|
||||
Loading…
x
Reference in New Issue
Block a user