mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-04 10:58:45 +00:00
refactor: Refactor _save_pipeline_snapshot and _create_pipeline_snapshot to handle more exceptions (#9871)
* Refactor saving pipeline snapshot to handle the try-except inside and to cover more cases (e.g. try-excepts around our serialization logic) * Add reno * Fix * Adding tests * More tests * small change * fix test * update docstrings
This commit is contained in:
parent
0cd297adc8
commit
fe60c765d9
@ -144,117 +144,131 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot:
|
||||
return pipeline_snapshot
|
||||
|
||||
|
||||
def _save_pipeline_snapshot_to_file(
|
||||
*, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime
|
||||
) -> None:
|
||||
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot, raise_on_failure: bool = True) -> None:
|
||||
"""
|
||||
Save the pipeline snapshot dictionary to a JSON file.
|
||||
|
||||
- The filename is generated based on the component name, visit count, and timestamp.
|
||||
- The component name is taken from the break point's `component_name`.
|
||||
- The visit count is taken from the pipeline state's `component_visits` for the component name.
|
||||
- The timestamp is taken from the pipeline snapshot's `timestamp` or the current time if not available.
|
||||
- The file path is taken from the break point's `snapshot_file_path`.
|
||||
- If the `snapshot_file_path` is None, the function will return without saving.
|
||||
|
||||
:param pipeline_snapshot: The pipeline snapshot to save.
|
||||
:param snapshot_file_path: The path where to save the file.
|
||||
:param dt: The datetime object for timestamping.
|
||||
:param raise_on_failure: If True, raises an exception if saving fails. If False, logs the error and returns.
|
||||
|
||||
:raises:
|
||||
ValueError: If the snapshot_file_path is not a string or a Path object.
|
||||
Exception: If saving the JSON snapshot fails.
|
||||
"""
|
||||
snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path
|
||||
if not isinstance(snapshot_file_path, Path):
|
||||
raise ValueError("Debug path must be a string or a Path object.")
|
||||
break_point = pipeline_snapshot.break_point
|
||||
snapshot_file_path = (
|
||||
break_point.break_point.snapshot_file_path
|
||||
if isinstance(break_point, AgentBreakpoint)
|
||||
else break_point.snapshot_file_path
|
||||
)
|
||||
|
||||
snapshot_file_path.mkdir(exist_ok=True)
|
||||
if snapshot_file_path is None:
|
||||
return
|
||||
|
||||
dt = pipeline_snapshot.timestamp or datetime.now()
|
||||
snapshot_dir = Path(snapshot_file_path)
|
||||
|
||||
# Generate filename
|
||||
# We check if the agent_name is provided to differentiate between agent and non-agent breakpoints
|
||||
if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
|
||||
agent_name = pipeline_snapshot.break_point.agent_name
|
||||
component_name = pipeline_snapshot.break_point.break_point.component_name
|
||||
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
|
||||
file_name = f"{agent_name}_{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
|
||||
if isinstance(break_point, AgentBreakpoint):
|
||||
agent_name = break_point.agent_name
|
||||
component_name = break_point.break_point.component_name
|
||||
else:
|
||||
component_name = pipeline_snapshot.break_point.component_name
|
||||
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
|
||||
file_name = f"{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
|
||||
component_name = break_point.component_name
|
||||
agent_name = None
|
||||
|
||||
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
|
||||
timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
file_name = f"{agent_name + '_' if agent_name else ''}{component_name}_{visit_nr}_{timestamp}.json"
|
||||
full_path = snapshot_dir / file_name
|
||||
|
||||
try:
|
||||
with open(snapshot_file_path / file_name, "w") as f_out:
|
||||
snapshot_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(full_path, "w") as f_out:
|
||||
json.dump(pipeline_snapshot.to_dict(), f_out, indent=2)
|
||||
logger.info(f"Pipeline snapshot saved at: {file_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save pipeline snapshot: {str(e)}")
|
||||
raise
|
||||
logger.info(
|
||||
"Pipeline snapshot saved to '{full_path}'. You can use this file to debug or resume the pipeline.",
|
||||
full_path=full_path,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error("Failed to save pipeline snapshot to '{full_path}'. Error: {e}", full_path=full_path, e=error)
|
||||
if raise_on_failure:
|
||||
raise
|
||||
|
||||
|
||||
def _create_pipeline_snapshot(
|
||||
*,
|
||||
inputs: dict[str, Any],
|
||||
component_inputs: dict[str, Any],
|
||||
break_point: Union[AgentBreakpoint, Breakpoint],
|
||||
component_visits: dict[str, int],
|
||||
original_input_data: Optional[dict[str, Any]] = None,
|
||||
ordered_component_names: Optional[list[str]] = None,
|
||||
include_outputs_from: Optional[set[str]] = None,
|
||||
pipeline_outputs: Optional[dict[str, Any]] = None,
|
||||
original_input_data: dict[str, Any],
|
||||
ordered_component_names: list[str],
|
||||
include_outputs_from: set[str],
|
||||
pipeline_outputs: dict[str, Any],
|
||||
) -> PipelineSnapshot:
|
||||
"""
|
||||
Create a snapshot of the pipeline at the point where the breakpoint was triggered.
|
||||
|
||||
:param inputs: The current pipeline snapshot inputs.
|
||||
:param component_inputs: The inputs to the component that triggered the breakpoint.
|
||||
:param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint.
|
||||
:param component_visits: The visit count of the component that triggered the breakpoint.
|
||||
:param original_input_data: The original input data.
|
||||
:param ordered_component_names: The ordered component names.
|
||||
:param include_outputs_from: Set of component names whose outputs should be included in the pipeline results.
|
||||
:param pipeline_outputs: The current outputs of the pipeline.
|
||||
:returns:
|
||||
A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint.
|
||||
"""
|
||||
dt = datetime.now()
|
||||
if isinstance(break_point, AgentBreakpoint):
|
||||
component_name = break_point.agent_name
|
||||
else:
|
||||
component_name = break_point.component_name
|
||||
|
||||
transformed_original_input_data = _transform_json_structure(original_input_data)
|
||||
transformed_inputs = _transform_json_structure(inputs)
|
||||
transformed_inputs = _transform_json_structure({**inputs, component_name: component_inputs})
|
||||
|
||||
try:
|
||||
serialized_inputs = _serialize_value_with_schema(transformed_inputs)
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
"Failed to serialize the inputs of the current pipeline state. "
|
||||
"The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
|
||||
e=error,
|
||||
)
|
||||
serialized_inputs = {}
|
||||
|
||||
try:
|
||||
serialized_original_input_data = _serialize_value_with_schema(transformed_original_input_data)
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
"Failed to serialize original input data for `pipeline.run`. "
|
||||
"This likely occurred due to non-serializable object types. "
|
||||
"The snapshot will store an empty dictionary instead. Error: {e}",
|
||||
e=error,
|
||||
)
|
||||
serialized_original_input_data = {}
|
||||
|
||||
pipeline_snapshot = PipelineSnapshot(
|
||||
pipeline_state=PipelineState(
|
||||
inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs
|
||||
component_visits=component_visits,
|
||||
pipeline_outputs=pipeline_outputs or {},
|
||||
inputs=serialized_inputs, component_visits=component_visits, pipeline_outputs=pipeline_outputs
|
||||
),
|
||||
timestamp=dt,
|
||||
timestamp=datetime.now(),
|
||||
break_point=break_point,
|
||||
original_input_data=_serialize_value_with_schema(transformed_original_input_data),
|
||||
ordered_component_names=ordered_component_names or [],
|
||||
include_outputs_from=include_outputs_from or set(),
|
||||
original_input_data=serialized_original_input_data,
|
||||
ordered_component_names=ordered_component_names,
|
||||
include_outputs_from=include_outputs_from,
|
||||
)
|
||||
return pipeline_snapshot
|
||||
|
||||
|
||||
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot:
|
||||
"""
|
||||
Save the pipeline snapshot to a file.
|
||||
|
||||
:param pipeline_snapshot: The pipeline snapshot to save.
|
||||
|
||||
:returns:
|
||||
The dictionary containing the snapshot of the pipeline containing the following keys:
|
||||
- input_data: The original input data passed to the pipeline.
|
||||
- timestamp: The timestamp of the breakpoint.
|
||||
- pipeline_breakpoint: The component name and visit count that triggered the breakpoint.
|
||||
- pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys:
|
||||
- inputs: The current state of inputs for pipeline components.
|
||||
- component_visits: The visit count of the components when the breakpoint was triggered.
|
||||
- ordered_component_names: The order of components in the pipeline.
|
||||
"""
|
||||
break_point = pipeline_snapshot.break_point
|
||||
if isinstance(break_point, AgentBreakpoint):
|
||||
snapshot_file_path = break_point.break_point.snapshot_file_path
|
||||
else:
|
||||
snapshot_file_path = break_point.snapshot_file_path
|
||||
|
||||
if snapshot_file_path is not None:
|
||||
dt = pipeline_snapshot.timestamp or datetime.now()
|
||||
_save_pipeline_snapshot_to_file(
|
||||
pipeline_snapshot=pipeline_snapshot, snapshot_file_path=snapshot_file_path, dt=dt
|
||||
)
|
||||
|
||||
return pipeline_snapshot
|
||||
|
||||
|
||||
def _transform_json_structure(data: Union[dict[str, Any], list[Any], Any]) -> Any:
|
||||
"""
|
||||
Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level.
|
||||
|
||||
@ -360,10 +360,9 @@ class Pipeline(PipelineBase):
|
||||
and component_name == break_point.agent_name
|
||||
)
|
||||
if break_point and (component_break_point_triggered or agent_break_point_triggered):
|
||||
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
|
||||
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
|
||||
new_pipeline_snapshot = _create_pipeline_snapshot(
|
||||
inputs=pipeline_snapshot_inputs_serialised,
|
||||
inputs=deepcopy(inputs),
|
||||
component_inputs=deepcopy(component_inputs),
|
||||
break_point=break_point,
|
||||
component_visits=component_visits,
|
||||
original_input_data=data,
|
||||
@ -378,7 +377,7 @@ class Pipeline(PipelineBase):
|
||||
component_inputs["break_point"] = break_point
|
||||
component_inputs["parent_snapshot"] = new_pipeline_snapshot
|
||||
|
||||
# trigger the breakpoint if needed
|
||||
# trigger the break point if needed
|
||||
if component_break_point_triggered:
|
||||
_trigger_break_point(pipeline_snapshot=new_pipeline_snapshot)
|
||||
|
||||
@ -400,11 +399,10 @@ class Pipeline(PipelineBase):
|
||||
snapshot_file_path=out_dir,
|
||||
)
|
||||
|
||||
# Create a snapshot of the last good state of the pipeline before the error occurred.
|
||||
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
|
||||
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
|
||||
last_good_state_snapshot = _create_pipeline_snapshot(
|
||||
inputs=pipeline_snapshot_inputs_serialised,
|
||||
# Create a snapshot of the state of the pipeline before the error occurred.
|
||||
pipeline_snapshot = _create_pipeline_snapshot(
|
||||
inputs=deepcopy(inputs),
|
||||
component_inputs=deepcopy(component_inputs),
|
||||
break_point=break_point,
|
||||
component_visits=component_visits,
|
||||
original_input_data=data,
|
||||
@ -417,23 +415,12 @@ class Pipeline(PipelineBase):
|
||||
# We take the agent snapshot and attach it to the pipeline snapshot we create here.
|
||||
# We also update the break_point to be an AgentBreakpoint.
|
||||
if error.pipeline_snapshot and error.pipeline_snapshot.agent_snapshot:
|
||||
last_good_state_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
|
||||
last_good_state_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
|
||||
pipeline_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
|
||||
pipeline_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
|
||||
|
||||
# Attach the last good state snapshot to the error before re-raising it and saving to disk
|
||||
error.pipeline_snapshot = last_good_state_snapshot
|
||||
|
||||
try:
|
||||
_save_pipeline_snapshot(pipeline_snapshot=last_good_state_snapshot)
|
||||
logger.info(
|
||||
"Saved a snapshot of the pipeline's last valid state to '{out_path}'. "
|
||||
"Review this snapshot to debug the error and resume the pipeline from here.",
|
||||
out_path=out_dir,
|
||||
)
|
||||
except Exception as save_error:
|
||||
logger.error(
|
||||
"Failed to save a snapshot of the pipeline's last valid state with error: {e}", e=save_error
|
||||
)
|
||||
# Attach the pipeline snapshot to the error before re-raising
|
||||
error.pipeline_snapshot = pipeline_snapshot
|
||||
_save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot, raise_on_failure=False)
|
||||
raise error
|
||||
|
||||
# Updates global input state with component outputs and returns outputs that should go to
|
||||
|
||||
@ -189,12 +189,12 @@ class PipelineSnapshot:
|
||||
"""
|
||||
A dataclass to hold a snapshot of the pipeline at a specific point in time.
|
||||
|
||||
:param original_input_data: The original input data provided to the pipeline.
|
||||
:param ordered_component_names: A list of component names in the order they were visited.
|
||||
:param pipeline_state: The state of the pipeline at the time of the snapshot.
|
||||
:param break_point: The breakpoint that triggered the snapshot.
|
||||
:param agent_snapshot: Optional agent snapshot if the breakpoint is an agent breakpoint.
|
||||
:param timestamp: A timestamp indicating when the snapshot was taken.
|
||||
:param original_input_data: The original input data provided to the pipeline.
|
||||
:param ordered_component_names: A list of component names in the order they were visited.
|
||||
:param include_outputs_from: Set of component names whose outputs should be included in the pipeline results.
|
||||
"""
|
||||
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Refactored `_save_pipeline_snapshot` to consolidate try-except logic and added a `raise_on_failure` option to control whether save failures raise an exception or are logged.
|
||||
`_create_pipeline_snapshot` now wraps `_serialize_value_with_schema` in try-except blocks to prevent failures from non-serializable pipeline inputs.
|
||||
@ -3,6 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
@ -10,12 +11,14 @@ from haystack import component
|
||||
from haystack.core.errors import BreakpointException
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.pipeline.breakpoint import (
|
||||
_create_pipeline_snapshot,
|
||||
_save_pipeline_snapshot,
|
||||
_transform_json_structure,
|
||||
_trigger_chat_generator_breakpoint,
|
||||
_trigger_tool_invoker_breakpoint,
|
||||
load_pipeline_snapshot,
|
||||
)
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
|
||||
from haystack.dataclasses.breakpoints import (
|
||||
AgentBreakpoint,
|
||||
AgentSnapshot,
|
||||
@ -27,7 +30,7 @@ from haystack.dataclasses.breakpoints import (
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_pipeline_snapshot():
|
||||
def make_pipeline_snapshot_with_agent_snapshot():
|
||||
def _make(break_point: AgentBreakpoint) -> PipelineSnapshot:
|
||||
return PipelineSnapshot(
|
||||
break_point=break_point,
|
||||
@ -144,8 +147,8 @@ def test_breakpoint_saves_intermediate_outputs(tmp_path):
|
||||
assert loaded_snapshot.break_point.visit_count == 0
|
||||
|
||||
|
||||
def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot):
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
|
||||
def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot_with_agent_snapshot):
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
|
||||
break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker"))
|
||||
)
|
||||
with pytest.raises(BreakpointException):
|
||||
@ -155,8 +158,8 @@ def test_trigger_tool_invoker_breakpoint(make_pipeline_snapshot):
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot):
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
|
||||
def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot_with_agent_snapshot):
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
|
||||
break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker", tool_name="tool2"))
|
||||
)
|
||||
# This should not raise since the tool call is for "tool1", not "tool2"
|
||||
@ -166,12 +169,12 @@ def test_trigger_tool_invoker_breakpoint_no_raise(make_pipeline_snapshot):
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot):
|
||||
def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot_with_agent_snapshot):
|
||||
"""
|
||||
This is to test if a specific tool is set in the ToolBreakpoint, the BreakpointException is raised even when
|
||||
there are multiple tool calls in the message.
|
||||
"""
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
|
||||
break_point=AgentBreakpoint("agent", ToolBreakpoint(component_name="tool_invoker", tool_name="tool2"))
|
||||
)
|
||||
with pytest.raises(BreakpointException):
|
||||
@ -185,9 +188,118 @@ def test_trigger_tool_invoker_breakpoint_specific_tool(make_pipeline_snapshot):
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_chat_generator_breakpoint(make_pipeline_snapshot):
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot(
|
||||
def test_trigger_chat_generator_breakpoint(make_pipeline_snapshot_with_agent_snapshot):
|
||||
pipeline_snapshot_with_agent_breakpoint = make_pipeline_snapshot_with_agent_snapshot(
|
||||
break_point=AgentBreakpoint("agent", Breakpoint(component_name="chat_generator"))
|
||||
)
|
||||
with pytest.raises(BreakpointException):
|
||||
_trigger_chat_generator_breakpoint(pipeline_snapshot=pipeline_snapshot_with_agent_breakpoint)
|
||||
|
||||
|
||||
class TestCreatePipelineSnapshot:
|
||||
def test_create_pipeline_snapshot_all_fields(self):
|
||||
break_point = Breakpoint(component_name="comp2")
|
||||
ordered_component_names = ["comp1", "comp2"]
|
||||
include_outputs_from = {"comp1"}
|
||||
|
||||
snapshot = _create_pipeline_snapshot(
|
||||
inputs={"comp1": {"input_value": [{"sender": None, "value": "test"}]}, "comp2": {}},
|
||||
component_inputs={"input_value": "processed_test"},
|
||||
break_point=break_point,
|
||||
component_visits={"comp1": 1, "comp2": 0},
|
||||
original_input_data={"comp1": {"input_value": "test"}},
|
||||
ordered_component_names=ordered_component_names,
|
||||
include_outputs_from=include_outputs_from,
|
||||
pipeline_outputs={"comp1": {"result": "processed_test"}},
|
||||
)
|
||||
|
||||
assert snapshot.original_input_data == {
|
||||
"serialization_schema": {
|
||||
"type": "object",
|
||||
"properties": {"comp1": {"type": "object", "properties": {"input_value": {"type": "string"}}}},
|
||||
},
|
||||
"serialized_data": {"comp1": {"input_value": "test"}},
|
||||
}
|
||||
assert snapshot.ordered_component_names == ordered_component_names
|
||||
assert snapshot.break_point == break_point
|
||||
assert snapshot.agent_snapshot is None
|
||||
assert snapshot.include_outputs_from == include_outputs_from
|
||||
assert snapshot.pipeline_state == PipelineState(
|
||||
inputs={
|
||||
"serialization_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"comp1": {"type": "object", "properties": {"input_value": {"type": "string"}}},
|
||||
"comp2": {"type": "object", "properties": {"input_value": {"type": "string"}}},
|
||||
},
|
||||
},
|
||||
"serialized_data": {"comp1": {"input_value": "test"}, "comp2": {"input_value": "processed_test"}},
|
||||
},
|
||||
component_visits={"comp1": 1, "comp2": 0},
|
||||
pipeline_outputs={"comp1": {"result": "processed_test"}},
|
||||
)
|
||||
|
||||
def test_create_pipeline_snapshot_with_dataclasses_in_pipeline_outputs(self):
|
||||
snapshot = _create_pipeline_snapshot(
|
||||
inputs={},
|
||||
component_inputs={},
|
||||
break_point=Breakpoint(component_name="comp2"),
|
||||
component_visits={"comp1": 1, "comp2": 0},
|
||||
original_input_data={},
|
||||
ordered_component_names=["comp1", "comp2"],
|
||||
include_outputs_from={"comp1"},
|
||||
pipeline_outputs={"comp1": {"result": ChatMessage.from_user("hello")}},
|
||||
)
|
||||
|
||||
assert snapshot.pipeline_state == PipelineState(
|
||||
inputs={
|
||||
"serialization_schema": {
|
||||
"type": "object",
|
||||
"properties": {"comp2": {"type": "object", "properties": {}}},
|
||||
},
|
||||
"serialized_data": {"comp2": {}},
|
||||
},
|
||||
component_visits={"comp1": 1, "comp2": 0},
|
||||
pipeline_outputs={"comp1": {"result": ChatMessage.from_user("hello")}},
|
||||
)
|
||||
|
||||
def test_create_pipeline_snapshot_non_serializable_inputs(self, caplog):
|
||||
class NonSerializable:
|
||||
def to_dict(self):
|
||||
raise TypeError("Cannot serialize")
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
_create_pipeline_snapshot(
|
||||
inputs={"comp1": {"input_value": [{"sender": None, "value": NonSerializable()}]}, "comp2": {}},
|
||||
component_inputs={},
|
||||
break_point=Breakpoint(component_name="comp2"),
|
||||
component_visits={"comp1": 1, "comp2": 0},
|
||||
original_input_data={"comp1": {"input_value": NonSerializable()}},
|
||||
ordered_component_names=["comp1", "comp2"],
|
||||
include_outputs_from={"comp1"},
|
||||
pipeline_outputs={},
|
||||
)
|
||||
|
||||
assert any("Failed to serialize the inputs of the current pipeline state" in msg for msg in caplog.messages)
|
||||
assert any("Failed to serialize original input data for `pipeline.run`." in msg for msg in caplog.messages)
|
||||
|
||||
|
||||
def test_save_pipeline_snapshot_raises_on_failure(tmp_path, caplog):
|
||||
snapshot = _create_pipeline_snapshot(
|
||||
inputs={},
|
||||
component_inputs={},
|
||||
break_point=Breakpoint(component_name="comp2", snapshot_file_path=str(tmp_path)),
|
||||
component_visits={"comp1": 1, "comp2": 0},
|
||||
original_input_data={},
|
||||
ordered_component_names=["comp1", "comp2"],
|
||||
include_outputs_from={"comp1"},
|
||||
# We use a non-serializable type (bytes) directly in pipeline outputs to trigger the error
|
||||
pipeline_outputs={"comp1": {"result": b"test"}},
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
_save_pipeline_snapshot(snapshot)
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
_save_pipeline_snapshot(snapshot, raise_on_failure=False)
|
||||
assert any("Failed to save pipeline snapshot to" in msg for msg in caplog.messages)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user