refactor: Refactor _save_pipeline_snapshot and _create_pipeline_snapshot to handle more exceptions (#9871)

* Refactor saving of the pipeline snapshot so that the try-except lives inside the save function and covers more failure cases (e.g. errors raised by our serialization logic)

* Add reno

* Fix

* Adding tests

* More tests

* small change

* fix test

* update docstrings
This commit is contained in:
Sebastian Husch Lee 2025-10-15 11:30:26 +02:00 committed by GitHub
parent 0cd297adc8
commit fe60c765d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 222 additions and 104 deletions

View File

@ -144,117 +144,131 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot:
return pipeline_snapshot
def _save_pipeline_snapshot_to_file(
*, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime
) -> None:
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot, raise_on_failure: bool = True) -> None:
"""
Save the pipeline snapshot dictionary to a JSON file.
- The filename is generated based on the component name, visit count, and timestamp.
- The component name is taken from the break point's `component_name`.
- The visit count is taken from the pipeline state's `component_visits` for the component name.
- The timestamp is taken from the pipeline snapshot's `timestamp` or the current time if not available.
- The file path is taken from the break point's `snapshot_file_path`.
- If the `snapshot_file_path` is None, the function will return without saving.
:param pipeline_snapshot: The pipeline snapshot to save.
:param snapshot_file_path: The path where to save the file.
:param dt: The datetime object for timestamping.
:param raise_on_failure: If True, raises an exception if saving fails. If False, logs the error and returns.
:raises:
ValueError: If the snapshot_file_path is not a string or a Path object.
Exception: If saving the JSON snapshot fails.
"""
snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path
if not isinstance(snapshot_file_path, Path):
raise ValueError("Debug path must be a string or a Path object.")
break_point = pipeline_snapshot.break_point
snapshot_file_path = (
break_point.break_point.snapshot_file_path
if isinstance(break_point, AgentBreakpoint)
else break_point.snapshot_file_path
)
snapshot_file_path.mkdir(exist_ok=True)
if snapshot_file_path is None:
return
dt = pipeline_snapshot.timestamp or datetime.now()
snapshot_dir = Path(snapshot_file_path)
# Generate filename
# We check if the agent_name is provided to differentiate between agent and non-agent breakpoints
if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
agent_name = pipeline_snapshot.break_point.agent_name
component_name = pipeline_snapshot.break_point.break_point.component_name
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
file_name = f"{agent_name}_{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
if isinstance(break_point, AgentBreakpoint):
agent_name = break_point.agent_name
component_name = break_point.break_point.component_name
else:
component_name = pipeline_snapshot.break_point.component_name
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
file_name = f"{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
component_name = break_point.component_name
agent_name = None
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S")
file_name = f"{agent_name + '_' if agent_name else ''}{component_name}_{visit_nr}_{timestamp}.json"
full_path = snapshot_dir / file_name
try:
with open(snapshot_file_path / file_name, "w") as f_out:
snapshot_dir.mkdir(parents=True, exist_ok=True)
with open(full_path, "w") as f_out:
json.dump(pipeline_snapshot.to_dict(), f_out, indent=2)
logger.info(f"Pipeline snapshot saved at: {file_name}")
except Exception as e:
logger.error(f"Failed to save pipeline snapshot: {str(e)}")
raise
logger.info(
"Pipeline snapshot saved to '{full_path}'. You can use this file to debug or resume the pipeline.",
full_path=full_path,
)
except Exception as error:
logger.error("Failed to save pipeline snapshot to '{full_path}'. Error: {e}", full_path=full_path, e=error)
if raise_on_failure:
raise
def _create_pipeline_snapshot(
*,
inputs: dict[str, Any],
component_inputs: dict[str, Any],
break_point: Union[AgentBreakpoint, Breakpoint],
component_visits: dict[str, int],
original_input_data: Optional[dict[str, Any]] = None,
ordered_component_names: Optional[list[str]] = None,
include_outputs_from: Optional[set[str]] = None,
pipeline_outputs: Optional[dict[str, Any]] = None,
original_input_data: dict[str, Any],
ordered_component_names: list[str],
include_outputs_from: set[str],
pipeline_outputs: dict[str, Any],
) -> PipelineSnapshot:
"""
Create a snapshot of the pipeline at the point where the breakpoint was triggered.
:param inputs: The current pipeline snapshot inputs.
:param component_inputs: The inputs to the component that triggered the breakpoint.
:param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint.
:param component_visits: The visit count of the component that triggered the breakpoint.
:param original_input_data: The original input data.
:param ordered_component_names: The ordered component names.
:param include_outputs_from: Set of component names whose outputs should be included in the pipeline results.
:param pipeline_outputs: The current outputs of the pipeline.
:returns:
A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint.
"""
dt = datetime.now()
if isinstance(break_point, AgentBreakpoint):
component_name = break_point.agent_name
else:
component_name = break_point.component_name
transformed_original_input_data = _transform_json_structure(original_input_data)
transformed_inputs = _transform_json_structure(inputs)
transformed_inputs = _transform_json_structure({**inputs, component_name: component_inputs})
try:
serialized_inputs = _serialize_value_with_schema(transformed_inputs)
except Exception as error:
logger.warning(
"Failed to serialize the inputs of the current pipeline state. "
"The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
e=error,
)
serialized_inputs = {}
try:
serialized_original_input_data = _serialize_value_with_schema(transformed_original_input_data)
except Exception as error:
logger.warning(
"Failed to serialize original input data for `pipeline.run`. "
"This likely occurred due to non-serializable object types. "
"The snapshot will store an empty dictionary instead. Error: {e}",
e=error,
)
serialized_original_input_data = {}
pipeline_snapshot = PipelineSnapshot(
pipeline_state=PipelineState(
inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs
component_visits=component_visits,
pipeline_outputs=pipeline_outputs or {},
inputs=serialized_inputs, component_visits=component_visits, pipeline_outputs=pipeline_outputs
),
timestamp=dt,
timestamp=datetime.now(),
break_point=break_point,
original_input_data=_serialize_value_with_schema(transformed_original_input_data),
ordered_component_names=ordered_component_names or [],
include_outputs_from=include_outputs_from or set(),
original_input_data=serialized_original_input_data,
ordered_component_names=ordered_component_names,
include_outputs_from=include_outputs_from,
)
return pipeline_snapshot
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot:
"""
Save the pipeline snapshot to a file.
:param pipeline_snapshot: The pipeline snapshot to save.
:returns:
The dictionary containing the snapshot of the pipeline containing the following keys:
- input_data: The original input data passed to the pipeline.
- timestamp: The timestamp of the breakpoint.
- pipeline_breakpoint: The component name and visit count that triggered the breakpoint.
- pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys:
- inputs: The current state of inputs for pipeline components.
- component_visits: The visit count of the components when the breakpoint was triggered.
- ordered_component_names: The order of components in the pipeline.
"""
break_point = pipeline_snapshot.break_point
if isinstance(break_point, AgentBreakpoint):
snapshot_file_path = break_point.break_point.snapshot_file_path
else:
snapshot_file_path = break_point.snapshot_file_path
if snapshot_file_path is not None:
dt = pipeline_snapshot.timestamp or datetime.now()
_save_pipeline_snapshot_to_file(
pipeline_snapshot=pipeline_snapshot, snapshot_file_path=snapshot_file_path, dt=dt
)
return pipeline_snapshot
def _transform_json_structure(data: Union[dict[str, Any], list[Any], Any]) -> Any:
"""
Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level.

View File

@ -360,10 +360,9 @@ class Pipeline(PipelineBase):
and component_name == break_point.agent_name
)
if break_point and (component_break_point_triggered or agent_break_point_triggered):
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
new_pipeline_snapshot = _create_pipeline_snapshot(
inputs=pipeline_snapshot_inputs_serialised,
inputs=deepcopy(inputs),
component_inputs=deepcopy(component_inputs),
break_point=break_point,
component_visits=component_visits,
original_input_data=data,
@ -378,7 +377,7 @@ class Pipeline(PipelineBase):
component_inputs["break_point"] = break_point
component_inputs["parent_snapshot"] = new_pipeline_snapshot
# trigger the breakpoint if needed
# trigger the break point if needed
if component_break_point_triggered:
_trigger_break_point(pipeline_snapshot=new_pipeline_snapshot)
@ -400,11 +399,10 @@ class Pipeline(PipelineBase):
snapshot_file_path=out_dir,
)
# Create a snapshot of the last good state of the pipeline before the error occurred.
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
last_good_state_snapshot = _create_pipeline_snapshot(
inputs=pipeline_snapshot_inputs_serialised,
# Create a snapshot of the state of the pipeline before the error occurred.
pipeline_snapshot = _create_pipeline_snapshot(
inputs=deepcopy(inputs),
component_inputs=deepcopy(component_inputs),
break_point=break_point,
component_visits=component_visits,
original_input_data=data,
@ -417,23 +415,12 @@ class Pipeline(PipelineBase):
# We take the agent snapshot and attach it to the pipeline snapshot we create here.
# We also update the break_point to be an AgentBreakpoint.
if error.pipeline_snapshot and error.pipeline_snapshot.agent_snapshot:
last_good_state_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
last_good_state_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
pipeline_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
pipeline_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
# Attach the last good state snapshot to the error before re-raising it and saving to disk
error.pipeline_snapshot = last_good_state_snapshot
try:
_save_pipeline_snapshot(pipeline_snapshot=last_good_state_snapshot)
logger.info(
"Saved a snapshot of the pipeline's last valid state to '{out_path}'. "
"Review this snapshot to debug the error and resume the pipeline from here.",
out_path=out_dir,
)
except Exception as save_error:
logger.error(
"Failed to save a snapshot of the pipeline's last valid state with error: {e}", e=save_error
)
# Attach the pipeline snapshot to the error before re-raising
error.pipeline_snapshot = pipeline_snapshot
_save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot, raise_on_failure=False)
raise error
# Updates global input state with component outputs and returns outputs that should go to

View File

@ -189,12 +189,12 @@ class PipelineSnapshot:
"""
A dataclass to hold a snapshot of the pipeline at a specific point in time.
:param original_input_data: The original input data provided to the pipeline.
:param ordered_component_names: A list of component names in the order they were visited.
:param pipeline_state: The state of the pipeline at the time of the snapshot.
:param break_point: The breakpoint that triggered the snapshot.
:param agent_snapshot: Optional agent snapshot if the breakpoint is an agent breakpoint.
:param timestamp: A timestamp indicating when the snapshot was taken.
:param original_input_data: The original input data provided to the pipeline.
:param ordered_component_names: A list of component names in the order they were visited.
:param include_outputs_from: Set of component names whose outputs should be included in the pipeline results.
"""

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Refactored `_save_pipeline_snapshot` to consolidate try-except logic and added a `raise_on_failure` option to control whether save failures raise an exception or are logged.
`_create_pipeline_snapshot` now wraps `_serialize_value_with_schema` in try-except blocks to prevent failures from non-serializable pipeline inputs.

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import pytest
@ -10,12 +11,14 @@ from haystack import component
from haystack.core.errors import BreakpointException
from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.breakpoint import (
_create_pipeline_snapshot,
_save_pipeline_snapshot,
_transform_json_structure,
_trigger_chat_generator_breakpoint,
_trigger_tool_invoker_breakpoint,
load_pipeline_snapshot,
)
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
from haystack.dataclasses.breakpoints import (
AgentBreakpoint,
AgentSnapshot,
@ -27,7 +30,7 @@ from haystack.dataclasses.breakpoints import (
@pytest.fixture
def make_pipeline_snapshot():
def make_pipeline_snapshot_with_agent_snapshot():
def _make(break_point: AgentBreakpoint) -> PipelineSnapshot:
return PipelineSnapshot(
break_point=break_point,
@ -144,8 +147,8 @@ def test_breakpoint_saves_intermediate_outputs(tmp_path):
assert loaded_snapshot.break_point.visit_count == 0
def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot):
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot_with_agent_snapshot):
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker"))
)
with pytest.raises(BreakpointException):
@ -155,8 +158,8 @@ def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot):
)
def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot):
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot_with_agent_snapshot):
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker", tool_name="tool2"))
)
# This should not raise since the tool call is for "tool1", not "tool2"
@ -166,12 +169,12 @@ def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot):
)
def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot):
def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot_with_agent_snapshot):
"""
This is to test if a specific tool is set in the ToolBreakpoint, the BreakpointException is raised even when
there are multiple tool calls in the message.
"""
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker", tool_name="tool2"))
)
with pytest.raises(BreakpointException):
@ -185,9 +188,118 @@ def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot):
)
def test_trigger_chat_generator_breakpoint(make_pipeline_snapshot):
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
def test_trigger_chat_generator_breakpoint(make_pipeline_snapshot_with_agent_snapshot):
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
break_point=AgentBreakpoint("agent", Breakpoint(component_name="chat_generator"))
)
with pytest.raises(BreakpointException):
_trigger_chat_generator_breakpoint(pipeline_snapshot=pipeline_snapshot_with_agent_breakpoint)
class TestCreatePipelineSnapshot:
    """Tests for `_create_pipeline_snapshot`: field propagation, output pass-through and serialization fallbacks."""

    def test_create_pipeline_snapshot_all_fields(self):
        """All arguments are reflected in the snapshot; inputs and original input data are stored serialized."""
        break_point = Breakpoint(component_name="comp2")
        ordered_component_names = ["comp1", "comp2"]
        include_outputs_from = {"comp1"}

        snapshot = _create_pipeline_snapshot(
            inputs={"comp1": {"input_value": [{"sender": None, "value": "test"}]}, "comp2": {}},
            component_inputs={"input_value": "processed_test"},
            break_point=break_point,
            component_visits={"comp1": 1, "comp2": 0},
            original_input_data={"comp1": {"input_value": "test"}},
            ordered_component_names=ordered_component_names,
            include_outputs_from=include_outputs_from,
            pipeline_outputs={"comp1": {"result": "processed_test"}},
        )

        assert snapshot.original_input_data == {
            "serialization_schema": {
                "type": "object",
                "properties": {"comp1": {"type": "object", "properties": {"input_value": {"type": "string"}}}},
            },
            "serialized_data": {"comp1": {"input_value": "test"}},
        }
        assert snapshot.ordered_component_names == ordered_component_names
        assert snapshot.break_point == break_point
        assert snapshot.agent_snapshot is None
        assert snapshot.include_outputs_from == include_outputs_from
        # `component_inputs` is expected under the break point's component name ("comp2"),
        # replacing the empty entry that was passed in `inputs`.
        assert snapshot.pipeline_state == PipelineState(
            inputs={
                "serialization_schema": {
                    "type": "object",
                    "properties": {
                        "comp1": {"type": "object", "properties": {"input_value": {"type": "string"}}},
                        "comp2": {"type": "object", "properties": {"input_value": {"type": "string"}}},
                    },
                },
                "serialized_data": {"comp1": {"input_value": "test"}, "comp2": {"input_value": "processed_test"}},
            },
            component_visits={"comp1": 1, "comp2": 0},
            pipeline_outputs={"comp1": {"result": "processed_test"}},
        )

    def test_create_pipeline_snapshot_with_dataclasses_in_pipeline_outputs(self):
        """Dataclass instances in `pipeline_outputs` are kept as objects in the pipeline state."""
        snapshot = _create_pipeline_snapshot(
            inputs={},
            component_inputs={},
            break_point=Breakpoint(component_name="comp2"),
            component_visits={"comp1": 1, "comp2": 0},
            original_input_data={},
            ordered_component_names=["comp1", "comp2"],
            include_outputs_from={"comp1"},
            pipeline_outputs={"comp1": {"result": ChatMessage.from_user("hello")}},
        )

        assert snapshot.pipeline_state == PipelineState(
            inputs={
                "serialization_schema": {
                    "type": "object",
                    "properties": {"comp2": {"type": "object", "properties": {}}},
                },
                "serialized_data": {"comp2": {}},
            },
            component_visits={"comp1": 1, "comp2": 0},
            pipeline_outputs={"comp1": {"result": ChatMessage.from_user("hello")}},
        )

    def test_create_pipeline_snapshot_non_serializable_inputs(self, caplog):
        """Serialization failures are logged as warnings and the affected fields fall back to empty dicts."""

        class NonSerializable:
            def to_dict(self):
                raise TypeError("Cannot serialize")

        with caplog.at_level(logging.WARNING):
            snapshot = _create_pipeline_snapshot(
                inputs={"comp1": {"input_value": [{"sender": None, "value": NonSerializable()}]}, "comp2": {}},
                component_inputs={},
                break_point=Breakpoint(component_name="comp2"),
                component_visits={"comp1": 1, "comp2": 0},
                original_input_data={"comp1": {"input_value": NonSerializable()}},
                ordered_component_names=["comp1", "comp2"],
                include_outputs_from={"comp1"},
                pipeline_outputs={},
            )

        assert any("Failed to serialize the inputs of the current pipeline state" in msg for msg in caplog.messages)
        assert any("Failed to serialize original input data for `pipeline.run`." in msg for msg in caplog.messages)
        # Both warnings promise an empty-dictionary replacement; check the snapshot actually carries it.
        assert snapshot.pipeline_state.inputs == {}
        assert snapshot.original_input_data == {}
def test_save_pipeline_snapshot_raises_on_failure(tmp_path, caplog):
    """
    `_save_pipeline_snapshot` re-raises a save failure by default and only logs it when `raise_on_failure=False`.
    """
    snapshot = _create_pipeline_snapshot(
        inputs={},
        component_inputs={},
        break_point=Breakpoint(component_name="comp2", snapshot_file_path=str(tmp_path)),
        component_visits={"comp1": 1, "comp2": 0},
        original_input_data={},
        ordered_component_names=["comp1", "comp2"],
        include_outputs_from={"comp1"},
        # We use a non-serializable type (bytes) directly in pipeline outputs to trigger the error
        pipeline_outputs={"comp1": {"result": b"test"}},
    )

    with pytest.raises(TypeError):
        _save_pipeline_snapshot(snapshot)

    # The raising call logs the same error message before re-raising. Drop those records so the
    # assertion below can only be satisfied by the non-raising call.
    caplog.clear()

    with caplog.at_level(logging.ERROR):
        _save_pipeline_snapshot(snapshot, raise_on_failure=False)

    assert any("Failed to save pipeline snapshot to" in msg for msg in caplog.messages)