refactor(workflow): Rename workflow node execution models (#20458)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-30 04:56:37 +08:00 committed by GitHub
parent 32e779eef3
commit f7fb10635f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 88 additions and 87 deletions

View File

@ -45,7 +45,7 @@ from core.app.entities.task_entities import (
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import NodeExecution, WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from models import (
@ -143,7 +143,7 @@ class WorkflowResponseConverter:
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: NodeExecution,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
@ -193,7 +193,7 @@ class WorkflowResponseConverter:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: NodeExecution,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
@ -236,7 +236,7 @@ class WorkflowResponseConverter:
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: NodeExecution,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None

View File

@ -13,7 +13,7 @@ from sqlalchemy.orm import sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.workflow_node_execution import (
NodeExecution,
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -23,7 +23,7 @@ from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecution,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
)
@ -86,9 +86,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
"""
Convert a database model to a domain model.
@ -107,7 +107,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Convert status to domain enum
status = WorkflowNodeExecutionStatus(db_model.status)
return NodeExecution(
return WorkflowNodeExecution(
id=db_model.id,
node_execution_id=db_model.node_execution_id,
workflow_id=db_model.workflow_id,
@ -128,7 +128,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
finished_at=db_model.finished_at,
)
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
"""
Convert a domain model to a database model.
@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
db_model = WorkflowNodeExecution()
db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
@ -175,7 +175,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.finished_at = domain_model.finished_at
return db_model
def save(self, execution: NodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to the database.
@ -207,7 +207,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
@ -230,13 +230,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.node_execution_id == node_execution_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
@ -252,7 +252,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
@ -270,20 +270,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of WorkflowNodeExecution database models
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
# Apply ordering if provided
if order_config and order_config.order_by:
order_columns: list[UnaryExpression] = []
for field in order_config.order_by:
column = getattr(WorkflowNodeExecution, field, None)
column = getattr(WorkflowNodeExecutionModel, field, None)
if not column:
continue
if order_config.order_direction == "desc":
@ -307,7 +307,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[NodeExecution]:
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
@ -334,7 +334,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
return domain_models
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
@ -348,15 +348,15 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of running NodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_models = session.scalars(stmt).all()
domain_models = []
@ -381,10 +381,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
It also clears the in-memory cache.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
result = session.execute(stmt)
session.commit()

View File

@ -53,7 +53,7 @@ class WorkflowNodeExecutionStatus(StrEnum):
RETRY = "retry"
class NodeExecution(BaseModel):
class WorkflowNodeExecution(BaseModel):
"""
Domain model for workflow node execution.

View File

@ -2,7 +2,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from core.workflow.entities.workflow_node_execution import NodeExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
@dataclass
@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: NodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution instance.
@ -39,7 +39,7 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
@ -55,7 +55,7 @@ class WorkflowNodeExecutionRepository(Protocol):
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[NodeExecution]:
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
@ -70,7 +70,7 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.

View File

@ -19,7 +19,7 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_node_execution import (
NodeExecution,
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -204,7 +204,7 @@ class WorkflowCycleManager:
*,
workflow_execution_id: str,
event: QueueNodeStartedEvent,
) -> NodeExecution:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
# Create a domain model
@ -215,7 +215,7 @@ class WorkflowCycleManager:
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = NodeExecution(
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
@ -235,7 +235,7 @@ class WorkflowCycleManager:
return domain_execution
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
@ -275,7 +275,7 @@ class WorkflowCycleManager:
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> NodeExecution:
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@ -320,7 +320,7 @@ class WorkflowCycleManager:
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> NodeExecution:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
@ -344,7 +344,7 @@ class WorkflowCycleManager:
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
# Create a domain model
domain_execution = NodeExecution(
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,

View File

@ -84,7 +84,7 @@ from .workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowType,
@ -169,7 +169,7 @@ __all__ = [
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowNodeExecution",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionTriggeredFrom",
"WorkflowRun",
"WorkflowRunTriggeredFrom",

View File

@ -541,7 +541,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):
WORKFLOW_RUN = "workflow-run"
class WorkflowNodeExecution(Base):
class WorkflowNodeExecutionModel(Base):
"""
Workflow Node Execution

View File

@ -14,7 +14,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
from models.workflow import WorkflowNodeExecution, WorkflowRun
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
@ -108,10 +108,11 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with Session(db.engine).no_autoflush as session:
workflow_node_executions = (
session.query(WorkflowNodeExecution)
session.query(WorkflowNodeExecutionModel)
.filter(
WorkflowNodeExecution.tenant_id == tenant_id,
WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.created_at
< datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
@ -135,8 +136,8 @@ class ClearFreePlanTenantExpiredLogs:
]
# delete workflow node executions
session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id.in_(workflow_node_execution_ids),
session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
).delete(synchronize_session=False)
session.commit()

View File

@ -11,7 +11,7 @@ from models import (
Account,
App,
EndUser,
WorkflowNodeExecution,
WorkflowNodeExecutionModel,
WorkflowRun,
WorkflowRunTriggeredFrom,
)
@ -125,7 +125,7 @@ class WorkflowRunService:
app_model: App,
run_id: str,
user: Account | EndUser,
) -> Sequence[WorkflowNodeExecution]:
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get workflow run node execution list
"""

View File

@ -13,7 +13,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import NodeExecution, WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@ -30,7 +30,7 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
@ -254,7 +254,7 @@ class WorkflowService:
def run_draft_workflow_node(
self, app_model: App, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
) -> WorkflowNodeExecutionModel:
"""
Run draft workflow node
"""
@ -296,7 +296,7 @@ class WorkflowService:
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> NodeExecution:
) -> WorkflowNodeExecution:
"""
Run draft workflow node
"""
@ -322,7 +322,7 @@ class WorkflowService:
invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
start_at: float,
node_id: str,
) -> NodeExecution:
) -> WorkflowNodeExecution:
try:
node_instance, generator = invoke_node_fn()
@ -374,7 +374,7 @@ class WorkflowService:
error = e.error
# Create a NodeExecution domain model
node_execution = NodeExecution(
node_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,

View File

@ -30,7 +30,7 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@ -188,9 +188,9 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def del_workflow_node_execution(workflow_node_execution_id: str):
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
synchronize_session=False
)
db.session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id == workflow_node_execution_id
).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",

View File

@ -14,7 +14,7 @@ from core.app.entities.queue_entities import (
)
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_node_execution import (
NodeExecution,
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -373,7 +373,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
# Create a real node execution
node_execution = NodeExecution(
node_execution = WorkflowNodeExecution(
id="test-node-execution-record-id",
node_execution_id="test-node-execution-id",
workflow_id="test-workflow-id",
@ -451,7 +451,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
# Create a real node execution
node_execution = NodeExecution(
node_execution = WorkflowNodeExecution(
id="test-node-execution-record-id",
node_execution_id="test-node-execution-id",
workflow_id="test-workflow-id",

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecution
from models.workflow import Workflow, WorkflowNodeExecutionModel
def test_environment_variables():
@ -156,7 +156,7 @@ def test_to_dict():
class TestWorkflowNodeExecution:
def test_execution_metadata_dict(self):
node_exec = WorkflowNodeExecution()
node_exec = WorkflowNodeExecutionModel()
node_exec.execution_metadata = None
assert node_exec.execution_metadata_dict == {}

View File

@ -14,14 +14,14 @@ from sqlalchemy.orm import Session, sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
NodeExecution,
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models.account import Account, Tenant
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
def configure_mock_execution(mock_execution):
@ -85,7 +85,7 @@ def test_save(repository, session):
"""Test save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.tenant_id = None
execution.app_id = None
execution.inputs = None
@ -111,7 +111,7 @@ def test_save_with_existing_tenant_id(repository, session):
"""Test save method with existing tenant_id."""
session_obj, _ = session
# Create a mock execution with existing tenant_id
execution = MagicMock(spec=WorkflowNodeExecution)
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.tenant_id = "existing-tenant"
execution.app_id = None
execution.inputs = None
@ -120,7 +120,7 @@ def test_save_with_existing_tenant_id(repository, session):
execution.metadata = None
# Create a modified execution that will be returned by _to_db_model
modified_execution = MagicMock(spec=WorkflowNodeExecution)
modified_execution = MagicMock(spec=WorkflowNodeExecutionModel)
modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change
modified_execution.app_id = repository._app_id # App ID should be set
@ -147,7 +147,7 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
mock_stmt.where.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalar.return_value = mock_execution
@ -179,7 +179,7 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
mock_stmt.order_by.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalars.return_value.all.return_value = [mock_execution]
@ -212,7 +212,7 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
mock_stmt.where.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalars.return_value.all.return_value = [mock_execution]
@ -238,7 +238,7 @@ def test_update_via_save(repository, session):
"""Test updating an existing record via save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.tenant_id = None
execution.app_id = None
execution.inputs = None
@ -278,7 +278,7 @@ def test_clear(repository, session, mocker: MockerFixture):
repository.clear()
# Assert delete was called with correct parameters
mock_delete.assert_called_once_with(WorkflowNodeExecution)
mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
mock_stmt.where.assert_called()
session_obj.execute.assert_called_once_with(mock_stmt)
session_obj.commit.assert_called_once()
@ -287,7 +287,7 @@ def test_clear(repository, session, mocker: MockerFixture):
def test_to_db_model(repository):
"""Test to_db_model method."""
# Create a domain model
domain_model = NodeExecution(
domain_model = WorkflowNodeExecution(
id="test-id",
workflow_id="test-workflow-id",
node_execution_id="test-node-execution-id",
@ -315,7 +315,7 @@ def test_to_db_model(repository):
db_model = repository.to_db_model(domain_model)
# Assert DB model has correct values
assert isinstance(db_model, WorkflowNodeExecution)
assert isinstance(db_model, WorkflowNodeExecutionModel)
assert db_model.id == domain_model.id
assert db_model.tenant_id == repository._tenant_id
assert db_model.app_id == repository._app_id
@ -352,7 +352,7 @@ def test_to_domain_model(repository):
metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}
# Create a DB model using our custom subclass
db_model = WorkflowNodeExecution()
db_model = WorkflowNodeExecutionModel()
db_model.id = "test-id"
db_model.tenant_id = "test-tenant-id"
db_model.app_id = "test-app-id"
@ -381,7 +381,7 @@ def test_to_domain_model(repository):
domain_model = repository._to_domain_model(db_model)
# Assert domain model has correct values
assert isinstance(domain_model, NodeExecution)
assert isinstance(domain_model, WorkflowNodeExecution)
assert domain_model.id == db_model.id
assert domain_model.workflow_id == db_model.workflow_id
assert domain_model.workflow_execution_id == db_model.workflow_run_id