From fe60c765d9a3c069f1ab9dfed236aa8e64b8c9b1 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:30:26 +0200 Subject: [PATCH] refactor: Refactor `_save_pipeline_snapshot` and `_create_pipeline_snapshot` to handle more exceptions (#9871) * Refactor saving pipeline snapshot to handle the try-except inside and to cover more cases (e.g. try-excepts around our serialization logic) * Add reno * Fix * Adding tests * More tests * small change * fix test * update docstrings --- haystack/core/pipeline/breakpoint.py | 148 ++++++++++-------- haystack/core/pipeline/pipeline.py | 37 ++--- haystack/dataclasses/breakpoints.py | 4 +- ...ng-pipeline-snapshot-6083f7b85d4a927d.yaml | 5 + test/core/pipeline/test_breakpoint.py | 132 ++++++++++++++-- 5 files changed, 222 insertions(+), 104 deletions(-) create mode 100644 releasenotes/notes/refactor-saving-pipeline-snapshot-6083f7b85d4a927d.yaml diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py index be4317d16..54b8b81b8 100644 --- a/haystack/core/pipeline/breakpoint.py +++ b/haystack/core/pipeline/breakpoint.py @@ -144,117 +144,131 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot: return pipeline_snapshot -def _save_pipeline_snapshot_to_file( - *, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime -) -> None: +def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot, raise_on_failure: bool = True) -> None: """ Save the pipeline snapshot dictionary to a JSON file. + - The filename is generated based on the component name, visit count, and timestamp. + - The component name is taken from the break point's `component_name`. + - The visit count is taken from the pipeline state's `component_visits` for the component name. + - The timestamp is taken from the pipeline snapshot's `timestamp` or the current time if not available. + - The file path is taken from the break point's `snapshot_file_path`. + - If the `snapshot_file_path` is None, the function will return without saving. + :param pipeline_snapshot: The pipeline snapshot to save. - :param snapshot_file_path: The path where to save the file. - :param dt: The datetime object for timestamping. + :param raise_on_failure: If True, raises an exception if saving fails. If False, logs the error and returns. + :raises: - ValueError: If the snapshot_file_path is not a string or a Path object. Exception: If saving the JSON snapshot fails. """ - snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path - if not isinstance(snapshot_file_path, Path): - raise ValueError("Debug path must be a string or a Path object.") + break_point = pipeline_snapshot.break_point + snapshot_file_path = ( + break_point.break_point.snapshot_file_path + if isinstance(break_point, AgentBreakpoint) + else break_point.snapshot_file_path + ) - snapshot_file_path.mkdir(exist_ok=True) + if snapshot_file_path is None: + return + + dt = pipeline_snapshot.timestamp or datetime.now() + snapshot_dir = Path(snapshot_file_path) # Generate filename # We check if the agent_name is provided to differentiate between agent and non-agent breakpoints - if isinstance(pipeline_snapshot.break_point, AgentBreakpoint): - agent_name = pipeline_snapshot.break_point.agent_name - component_name = pipeline_snapshot.break_point.break_point.component_name - visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0) - file_name = f"{agent_name}_{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + if isinstance(break_point, AgentBreakpoint): + agent_name = break_point.agent_name + component_name = break_point.break_point.component_name else: - component_name = pipeline_snapshot.break_point.component_name - visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0) - file_name = f"{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + component_name = break_point.component_name + agent_name = None + + visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0) + timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S") + file_name = f"{agent_name + '_' if agent_name else ''}{component_name}_{visit_nr}_{timestamp}.json" + full_path = snapshot_dir / file_name try: - with open(snapshot_file_path / file_name, "w") as f_out: + snapshot_dir.mkdir(parents=True, exist_ok=True) + with open(full_path, "w") as f_out: json.dump(pipeline_snapshot.to_dict(), f_out, indent=2) - logger.info(f"Pipeline snapshot saved at: {file_name}") - except Exception as e: - logger.error(f"Failed to save pipeline snapshot: {str(e)}") - raise + logger.info( + "Pipeline snapshot saved to '{full_path}'. You can use this file to debug or resume the pipeline.", + full_path=full_path, + ) + except Exception as error: + logger.error("Failed to save pipeline snapshot to '{full_path}'. Error: {e}", full_path=full_path, e=error) + if raise_on_failure: + raise def _create_pipeline_snapshot( *, inputs: dict[str, Any], + component_inputs: dict[str, Any], break_point: Union[AgentBreakpoint, Breakpoint], component_visits: dict[str, int], - original_input_data: Optional[dict[str, Any]] = None, - ordered_component_names: Optional[list[str]] = None, - include_outputs_from: Optional[set[str]] = None, - pipeline_outputs: Optional[dict[str, Any]] = None, + original_input_data: dict[str, Any], + ordered_component_names: list[str], + include_outputs_from: set[str], + pipeline_outputs: dict[str, Any], ) -> PipelineSnapshot: """ Create a snapshot of the pipeline at the point where the breakpoint was triggered. :param inputs: The current pipeline snapshot inputs. + :param component_inputs: The inputs to the component that triggered the breakpoint. :param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint. :param component_visits: The visit count of the component that triggered the breakpoint. :param original_input_data: The original input data. :param ordered_component_names: The ordered component names. :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results. + :param pipeline_outputs: The current outputs of the pipeline. + :returns: + A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint. """ - dt = datetime.now() + if isinstance(break_point, AgentBreakpoint): + component_name = break_point.agent_name + else: + component_name = break_point.component_name transformed_original_input_data = _transform_json_structure(original_input_data) - transformed_inputs = _transform_json_structure(inputs) + transformed_inputs = _transform_json_structure({**inputs, component_name: component_inputs}) + + try: + serialized_inputs = _serialize_value_with_schema(transformed_inputs) + except Exception as error: + logger.warning( + "Failed to serialize the inputs of the current pipeline state. " + "The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}", + e=error, + ) + serialized_inputs = {} + + try: + serialized_original_input_data = _serialize_value_with_schema(transformed_original_input_data) + except Exception as error: + logger.warning( + "Failed to serialize original input data for `pipeline.run`. " + "This likely occurred due to non-serializable object types. " + "The snapshot will store an empty dictionary instead. Error: {e}", + e=error, + ) + serialized_original_input_data = {} pipeline_snapshot = PipelineSnapshot( pipeline_state=PipelineState( - inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs - component_visits=component_visits, - pipeline_outputs=pipeline_outputs or {}, + inputs=serialized_inputs, component_visits=component_visits, pipeline_outputs=pipeline_outputs ), - timestamp=dt, + timestamp=datetime.now(), break_point=break_point, - original_input_data=_serialize_value_with_schema(transformed_original_input_data), - ordered_component_names=ordered_component_names or [], - include_outputs_from=include_outputs_from or set(), + original_input_data=serialized_original_input_data, + ordered_component_names=ordered_component_names, + include_outputs_from=include_outputs_from, ) return pipeline_snapshot -def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot: - """ - Save the pipeline snapshot to a file. - - :param pipeline_snapshot: The pipeline snapshot to save. - - :returns: - The dictionary containing the snapshot of the pipeline containing the following keys: - - input_data: The original input data passed to the pipeline. - - timestamp: The timestamp of the breakpoint. - - pipeline_breakpoint: The component name and visit count that triggered the breakpoint. - - pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys: - - inputs: The current state of inputs for pipeline components. - - component_visits: The visit count of the components when the breakpoint was triggered. - - ordered_component_names: The order of components in the pipeline. - """ - break_point = pipeline_snapshot.break_point - if isinstance(break_point, AgentBreakpoint): - snapshot_file_path = break_point.break_point.snapshot_file_path - else: - snapshot_file_path = break_point.snapshot_file_path - - if snapshot_file_path is not None: - dt = pipeline_snapshot.timestamp or datetime.now() - _save_pipeline_snapshot_to_file( - pipeline_snapshot=pipeline_snapshot, snapshot_file_path=snapshot_file_path, dt=dt - ) - - return pipeline_snapshot - - def _transform_json_structure(data: Union[dict[str, Any], list[Any], Any]) -> Any: """ Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index cfa822038..655f1d2aa 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -360,10 +360,9 @@ class Pipeline(PipelineBase): and component_name == break_point.agent_name ) if break_point and (component_break_point_triggered or agent_break_point_triggered): - pipeline_snapshot_inputs_serialised = deepcopy(inputs) - pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs) new_pipeline_snapshot = _create_pipeline_snapshot( - inputs=pipeline_snapshot_inputs_serialised, + inputs=deepcopy(inputs), + component_inputs=deepcopy(component_inputs), break_point=break_point, component_visits=component_visits, original_input_data=data, @@ -378,7 +377,7 @@ class Pipeline(PipelineBase): component_inputs["break_point"] = break_point component_inputs["parent_snapshot"] = new_pipeline_snapshot - # trigger the breakpoint if needed + # trigger the break point if needed if component_break_point_triggered: _trigger_break_point(pipeline_snapshot=new_pipeline_snapshot) @@ -400,11 +399,10 @@ class Pipeline(PipelineBase): snapshot_file_path=out_dir, ) - # Create a snapshot of the last good state of the pipeline before the error occurred. - pipeline_snapshot_inputs_serialised = deepcopy(inputs) - pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs) - last_good_state_snapshot = _create_pipeline_snapshot( - inputs=pipeline_snapshot_inputs_serialised, + # Create a snapshot of the state of the pipeline before the error occurred. + pipeline_snapshot = _create_pipeline_snapshot( + inputs=deepcopy(inputs), + component_inputs=deepcopy(component_inputs), break_point=break_point, component_visits=component_visits, original_input_data=data, @@ -417,23 +415,12 @@ class Pipeline(PipelineBase): # We take the agent snapshot and attach it to the pipeline snapshot we create here. # We also update the break_point to be an AgentBreakpoint. if error.pipeline_snapshot and error.pipeline_snapshot.agent_snapshot: - last_good_state_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot - last_good_state_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point + pipeline_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot + pipeline_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point - # Attach the last good state snapshot to the error before re-raising it and saving to disk - error.pipeline_snapshot = last_good_state_snapshot - - try: - _save_pipeline_snapshot(pipeline_snapshot=last_good_state_snapshot) - logger.info( - "Saved a snapshot of the pipeline's last valid state to '{out_path}'. " - "Review this snapshot to debug the error and resume the pipeline from here.", - out_path=out_dir, - ) - except Exception as save_error: - logger.error( - "Failed to save a snapshot of the pipeline's last valid state with error: {e}", e=save_error - ) + # Attach the pipeline snapshot to the error before re-raising + error.pipeline_snapshot = pipeline_snapshot + _save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot, raise_on_failure=False) raise error # Updates global input state with component outputs and returns outputs that should go to diff --git a/haystack/dataclasses/breakpoints.py b/haystack/dataclasses/breakpoints.py index f76c7e4b4..923f35382 100644 --- a/haystack/dataclasses/breakpoints.py +++ b/haystack/dataclasses/breakpoints.py @@ -189,12 +189,12 @@ class PipelineSnapshot: """ A dataclass to hold a snapshot of the pipeline at a specific point in time. + :param original_input_data: The original input data provided to the pipeline. + :param ordered_component_names: A list of component names in the order they were visited. :param pipeline_state: The state of the pipeline at the time of the snapshot. :param break_point: The breakpoint that triggered the snapshot. :param agent_snapshot: Optional agent snapshot if the breakpoint is an agent breakpoint. :param timestamp: A timestamp indicating when the snapshot was taken. - :param original_input_data: The original input data provided to the pipeline. - :param ordered_component_names: A list of component names in the order they were visited. :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results. """ diff --git a/releasenotes/notes/refactor-saving-pipeline-snapshot-6083f7b85d4a927d.yaml b/releasenotes/notes/refactor-saving-pipeline-snapshot-6083f7b85d4a927d.yaml new file mode 100644 index 000000000..f1dbc1219 --- /dev/null +++ b/releasenotes/notes/refactor-saving-pipeline-snapshot-6083f7b85d4a927d.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Refactored `_save_pipeline_snapshot` to consolidate try-except logic and added a `raise_on_failure` option to control whether save failures raise an exception or are logged. + `_create_pipeline_snapshot` now wraps `_serialize_value_with_schema` in try-except blocks to prevent failures from non-serializable pipeline inputs. diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py index 099e8124a..54eb80a4e 100644 --- a/test/core/pipeline/test_breakpoint.py +++ b/test/core/pipeline/test_breakpoint.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import logging import pytest @@ -10,12 +11,14 @@ from haystack import component from haystack.core.errors import BreakpointException from haystack.core.pipeline import Pipeline from haystack.core.pipeline.breakpoint import ( + _create_pipeline_snapshot, + _save_pipeline_snapshot, _transform_json_structure, _trigger_chat_generator_breakpoint, _trigger_tool_invoker_breakpoint, load_pipeline_snapshot, ) -from haystack.dataclasses import ChatMessage, ToolCall +from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall from haystack.dataclasses.breakpoints import ( AgentBreakpoint, AgentSnapshot, @@ -27,7 +30,7 @@ from haystack.dataclasses.breakpoints import ( @pytest.fixture -def make_pipeline_snapshot(): +def make_pipeline_snapshot_with_agent_snapshot(): def _make(break_point: AgentBreakpoint) -> PipelineSnapshot: return PipelineSnapshot( break_point=break_point, @@ -144,8 +147,8 @@ def test_breakpoint_saves_intermediate_outputs(tmp_path): assert loaded_snapshot.break_point.visit_count == 0 -def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot): - pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot( +def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot_with_agent_snapshot): + pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot( break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker")) ) with pytest.raises(BreakpointException): @@ -155,8 +158,8 @@ def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot): ) -def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot): - pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot( +def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot_with_agent_snapshot): + pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot( break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker", tool_name="tool2")) ) # This should not raise since the tool call is for "tool1", not "tool2" @@ -166,12 +169,12 @@ def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot): ) -def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot): +def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot_with_agent_snapshot): """ This is to test if a specific tool is set in the ToolBreakpoint, the BreakpointException is raised even when there are multiple tool calls in the message. """ - pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot( + pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot( break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker", tool_name="tool2")) ) with pytest.raises(BreakpointException): @@ -185,9 +188,118 @@ def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot): ) -def test_trigger_chat_generator_breakpoint(make_pipeline_snapshot): - pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot( +def test_trigger_chat_generator_breakpoint(make_pipeline_snapshot_with_agent_snapshot): + pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot( break_point=AgentBreakpoint("agent", Breakpoint(component_name="chat_generator")) ) with pytest.raises(BreakpointException): _trigger_chat_generator_breakpoint(pipeline_snapshot=pipeline_snapshot_with_agent_breakpoint) + + +class TestCreatePipelineSnapshot: + def test_create_pipeline_snapshot_all_fields(self): + break_point = Breakpoint(component_name="comp2") + ordered_component_names = ["comp1", "comp2"] + include_outputs_from = {"comp1"} + + snapshot = _create_pipeline_snapshot( + inputs={"comp1": {"input_value": [{"sender": None, "value": "test"}]}, "comp2": {}}, + component_inputs={"input_value": "processed_test"}, + break_point=break_point, + component_visits={"comp1": 1, "comp2": 0}, + original_input_data={"comp1": {"input_value": "test"}}, + ordered_component_names=ordered_component_names, + include_outputs_from=include_outputs_from, + pipeline_outputs={"comp1": {"result": "processed_test"}}, + ) + + assert snapshot.original_input_data == { + "serialization_schema": { + "type": "object", + "properties": {"comp1": {"type": "object", "properties": {"input_value": {"type": "string"}}}}, + }, + "serialized_data": {"comp1": {"input_value": "test"}}, + } + assert snapshot.ordered_component_names == ordered_component_names + assert snapshot.break_point == break_point + assert snapshot.agent_snapshot is None + assert snapshot.include_outputs_from == include_outputs_from + assert snapshot.pipeline_state == PipelineState( + inputs={ + "serialization_schema": { + "type": "object", + "properties": { + "comp1": {"type": "object", "properties": {"input_value": {"type": "string"}}}, + "comp2": {"type": "object", "properties": {"input_value": {"type": "string"}}}, + }, + }, + "serialized_data": {"comp1": {"input_value": "test"}, "comp2": {"input_value": "processed_test"}}, + }, + component_visits={"comp1": 1, "comp2": 0}, + pipeline_outputs={"comp1": {"result": "processed_test"}}, + ) + + def test_create_pipeline_snapshot_with_dataclasses_in_pipeline_outputs(self): + snapshot = _create_pipeline_snapshot( + inputs={}, + component_inputs={}, + break_point=Breakpoint(component_name="comp2"), + component_visits={"comp1": 1, "comp2": 0}, + original_input_data={}, + ordered_component_names=["comp1", "comp2"], + include_outputs_from={"comp1"}, + pipeline_outputs={"comp1": {"result": ChatMessage.from_user("hello")}}, + ) + + assert snapshot.pipeline_state == PipelineState( + inputs={ + "serialization_schema": { + "type": "object", + "properties": {"comp2": {"type": "object", "properties": {}}}, + }, + "serialized_data": {"comp2": {}}, + }, + component_visits={"comp1": 1, "comp2": 0}, + pipeline_outputs={"comp1": {"result": ChatMessage.from_user("hello")}}, + ) + + def test_create_pipeline_snapshot_non_serializable_inputs(self, caplog): + class NonSerializable: + def to_dict(self): + raise TypeError("Cannot serialize") + + with caplog.at_level(logging.WARNING): + _create_pipeline_snapshot( + inputs={"comp1": {"input_value": [{"sender": None, "value": NonSerializable()}]}, "comp2": {}}, + component_inputs={}, + break_point=Breakpoint(component_name="comp2"), + component_visits={"comp1": 1, "comp2": 0}, + original_input_data={"comp1": {"input_value": NonSerializable()}}, + ordered_component_names=["comp1", "comp2"], + include_outputs_from={"comp1"}, + pipeline_outputs={}, + ) + + assert any("Failed to serialize the inputs of the current pipeline state" in msg for msg in caplog.messages) + assert any("Failed to serialize original input data for `pipeline.run`." in msg for msg in caplog.messages) + + +def test_save_pipeline_snapshot_raises_on_failure(tmp_path, caplog): + snapshot = _create_pipeline_snapshot( + inputs={}, + component_inputs={}, + break_point=Breakpoint(component_name="comp2", snapshot_file_path=str(tmp_path)), + component_visits={"comp1": 1, "comp2": 0}, + original_input_data={}, + ordered_component_names=["comp1", "comp2"], + include_outputs_from={"comp1"}, + # We use a non-serializable type (bytes) directly in pipeline outputs to trigger the error + pipeline_outputs={"comp1": {"result": b"test"}}, + ) + + with pytest.raises(TypeError): + _save_pipeline_snapshot(snapshot) + + with caplog.at_level(logging.ERROR): + _save_pipeline_snapshot(snapshot, raise_on_failure=False) + assert any("Failed to save pipeline snapshot to" in msg for msg in caplog.messages)