mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
fix: Use deepcopy with exceptions instead of deepcopy (#10029)
* Fix deepcopy issue * Update test * Fix spelling
This commit is contained in:
parent
b88493ac79
commit
21502f4970
@ -3,7 +3,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -13,6 +12,7 @@ from networkx import MultiDiGraph
|
||||
|
||||
from haystack import logging
|
||||
from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError
|
||||
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.dataclasses.breakpoints import (
|
||||
AgentBreakpoint,
|
||||
@ -338,8 +338,10 @@ def _create_agent_snapshot(
|
||||
"""
|
||||
return AgentSnapshot(
|
||||
component_inputs={
|
||||
"chat_generator": _serialize_value_with_schema(deepcopy(component_inputs["chat_generator"])),
|
||||
"tool_invoker": _serialize_value_with_schema(deepcopy(component_inputs["tool_invoker"])),
|
||||
"chat_generator": _serialize_value_with_schema(
|
||||
_deepcopy_with_exceptions(component_inputs["chat_generator"])
|
||||
),
|
||||
"tool_invoker": _serialize_value_with_schema(_deepcopy_with_exceptions(component_inputs["tool_invoker"])),
|
||||
},
|
||||
component_visits=component_visits,
|
||||
break_point=agent_breakpoint,
|
||||
|
||||
@ -361,8 +361,8 @@ class Pipeline(PipelineBase):
|
||||
)
|
||||
if break_point and (component_break_point_triggered or agent_break_point_triggered):
|
||||
new_pipeline_snapshot = _create_pipeline_snapshot(
|
||||
inputs=deepcopy(inputs),
|
||||
component_inputs=deepcopy(component_inputs),
|
||||
inputs=_deepcopy_with_exceptions(inputs),
|
||||
component_inputs=_deepcopy_with_exceptions(component_inputs),
|
||||
break_point=break_point,
|
||||
component_visits=component_visits,
|
||||
original_input_data=data,
|
||||
@ -399,8 +399,8 @@ class Pipeline(PipelineBase):
|
||||
|
||||
# Create a snapshot of the state of the pipeline before the error occurred.
|
||||
pipeline_snapshot = _create_pipeline_snapshot(
|
||||
inputs=deepcopy(inputs),
|
||||
component_inputs=deepcopy(component_inputs),
|
||||
inputs=_deepcopy_with_exceptions(inputs),
|
||||
component_inputs=_deepcopy_with_exceptions(component_inputs),
|
||||
break_point=break_point,
|
||||
component_visits=component_visits,
|
||||
original_input_data=data,
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
When creating a pipeline snapshot we make sure to use _deepcopy_with_exceptions when copying component inputs to avoid deep copies of items like components and tools since they often contain attributes that are not deep-copyable. For example, the LinkContentFetcher has httpx.Client as an attribute which throws an error if we try to deep copy it.
|
||||
@ -8,13 +8,14 @@ import pytest
|
||||
|
||||
from haystack import Document, Pipeline, component
|
||||
from haystack.components.agents import Agent
|
||||
from haystack.components.fetchers import LinkContentFetcher
|
||||
from haystack.components.tools import ToolInvoker
|
||||
from haystack.components.writers import DocumentWriter
|
||||
from haystack.core.errors import PipelineRuntimeError
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import DuplicatePolicy
|
||||
from haystack.tools import Tool, Toolset, create_tool_from_function
|
||||
from haystack.tools import ComponentTool, Tool, Toolset, create_tool_from_function
|
||||
|
||||
|
||||
def calculate(expression: str) -> dict:
|
||||
@ -39,6 +40,8 @@ calculator_tool = create_tool_from_function(
|
||||
function=calculate, name="calculator", outputs_to_state={"calc_result": {"source": "result"}}
|
||||
)
|
||||
|
||||
link_fetcher_tool = ComponentTool(component=LinkContentFetcher())
|
||||
|
||||
|
||||
@component
|
||||
class MockChatGenerator:
|
||||
@ -85,7 +88,9 @@ def test_pipeline_with_chat_generator_crash():
|
||||
"""Test pipeline crash handling when chat generator fails."""
|
||||
pipe = build_pipeline(
|
||||
agent=Agent(
|
||||
chat_generator=MockChatGenerator(True), tools=[calculator_tool], state_schema={"calc_result": {"type": int}}
|
||||
chat_generator=MockChatGenerator(True),
|
||||
tools=[calculator_tool, link_fetcher_tool],
|
||||
state_schema={"calc_result": {"type": int}},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user