mirror of
https://github.com/langgenius/dify.git
synced 2025-11-28 20:07:53 +00:00
fix: ensure advanced-chat workflows stop correctly (#27803)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
f76a3f545c
commit
a486c47b1e
@ -17,7 +17,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
@ -32,6 +31,7 @@ from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -121,7 +121,13 @@ class CompletionMessageStopApi(Resource):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -220,6 +226,12 @@ class ChatMessageStopApi(Resource):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -15,7 +15,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
@ -31,6 +30,7 @@ from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
@ -102,12 +102,18 @@ class CompletionApi(InstalledAppResource):
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -184,6 +190,12 @@ class ChatStopApi(InstalledAppResource):
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -17,7 +17,6 @@ from controllers.service_api.app.error import (
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
@ -30,6 +29,7 @@ from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
@ -88,7 +88,7 @@ class CompletionApi(Resource):
|
||||
This endpoint generates a completion based on the provided inputs and query.
|
||||
Supports both blocking and streaming response modes.
|
||||
"""
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
args = completion_parser.parse_args()
|
||||
@ -147,10 +147,15 @@ class CompletionStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running completion task."""
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
user_id=end_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -244,6 +249,11 @@ class ChatStopApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
user_id=end_user.id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -17,7 +17,6 @@ from controllers.web.error import (
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
@ -29,6 +28,7 @@ from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
@ -125,10 +125,15 @@ class CompletionStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id=end_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -234,6 +239,11 @@ class ChatStopApi(WebApiResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id=end_user.id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
45
api/services/app_task_service.py
Normal file
45
api/services/app_task_service.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Service for managing application task operations.
|
||||
|
||||
This service provides centralized logic for task control operations
|
||||
like stopping tasks, handling both legacy Redis flag mechanism and
|
||||
new GraphEngine command channel mechanism.
|
||||
"""
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class AppTaskService:
|
||||
"""Service for managing application task operations."""
|
||||
|
||||
@staticmethod
|
||||
def stop_task(
|
||||
task_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str,
|
||||
app_mode: AppMode,
|
||||
) -> None:
|
||||
"""Stop a running task.
|
||||
|
||||
This method handles stopping tasks using both mechanisms:
|
||||
1. Legacy Redis flag mechanism (for backward compatibility)
|
||||
2. New GraphEngine command channel (for workflow-based apps)
|
||||
|
||||
Args:
|
||||
task_id: The task ID to stop
|
||||
invoke_from: The source of the invoke (e.g., DEBUGGER, WEB_APP, SERVICE_API)
|
||||
user_id: The user ID requesting the stop
|
||||
app_mode: The application mode (CHAT, AGENT_CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Legacy mechanism: Set stop flag in Redis
|
||||
AppQueueManager.set_stop_flag(task_id, invoke_from, user_id)
|
||||
|
||||
# New mechanism: Send stop command via GraphEngine for workflow-based apps
|
||||
# This ensures proper workflow status recording in the persistence layer
|
||||
if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
106
api/tests/unit_tests/services/test_app_task_service.py
Normal file
106
api/tests/unit_tests/services/test_app_task_service.py
Normal file
@ -0,0 +1,106 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import AppMode
|
||||
from services.app_task_service import AppTaskService
|
||||
|
||||
|
||||
class TestAppTaskService:
|
||||
"""Test suite for AppTaskService.stop_task method."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("app_mode", "should_call_graph_engine"),
|
||||
[
|
||||
(AppMode.CHAT, False),
|
||||
(AppMode.COMPLETION, False),
|
||||
(AppMode.AGENT_CHAT, False),
|
||||
(AppMode.CHANNEL, False),
|
||||
(AppMode.RAG_PIPELINE, False),
|
||||
(AppMode.ADVANCED_CHAT, True),
|
||||
(AppMode.WORKFLOW, True),
|
||||
],
|
||||
)
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
def test_stop_task_with_different_app_modes(
|
||||
self, mock_graph_engine_manager, mock_app_queue_manager, app_mode, should_call_graph_engine
|
||||
):
|
||||
"""Test stop_task behavior with different app modes.
|
||||
|
||||
Verifies that:
|
||||
- Legacy Redis flag is always set via AppQueueManager
|
||||
- GraphEngine stop command is only sent for ADVANCED_CHAT and WORKFLOW modes
|
||||
"""
|
||||
# Arrange
|
||||
task_id = "task-123"
|
||||
invoke_from = InvokeFrom.WEB_APP
|
||||
user_id = "user-456"
|
||||
|
||||
# Act
|
||||
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
|
||||
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
if should_call_graph_engine:
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
else:
|
||||
mock_graph_engine_manager.send_stop_command.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_from",
|
||||
[
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.SERVICE_API,
|
||||
InvokeFrom.DEBUGGER,
|
||||
InvokeFrom.EXPLORE,
|
||||
],
|
||||
)
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
def test_stop_task_with_different_invoke_sources(
|
||||
self, mock_graph_engine_manager, mock_app_queue_manager, invoke_from
|
||||
):
|
||||
"""Test stop_task behavior with different invoke sources.
|
||||
|
||||
Verifies that the method works correctly regardless of the invoke source.
|
||||
"""
|
||||
# Arrange
|
||||
task_id = "task-789"
|
||||
user_id = "user-999"
|
||||
app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
# Act
|
||||
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
|
||||
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
def test_stop_task_legacy_mechanism_called_even_if_graph_engine_fails(
|
||||
self, mock_app_queue_manager, mock_graph_engine_manager
|
||||
):
|
||||
"""Test that legacy Redis flag is set even if GraphEngine fails.
|
||||
|
||||
This ensures backward compatibility: the legacy mechanism should complete
|
||||
before attempting the GraphEngine command, so the stop flag is set
|
||||
regardless of GraphEngine success.
|
||||
"""
|
||||
# Arrange
|
||||
task_id = "task-123"
|
||||
invoke_from = InvokeFrom.WEB_APP
|
||||
user_id = "user-456"
|
||||
app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
# Simulate GraphEngine failure
|
||||
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
|
||||
|
||||
# Act & Assert - should raise the exception since it's not caught
|
||||
with pytest.raises(Exception, match="GraphEngine error"):
|
||||
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
|
||||
|
||||
# Verify legacy mechanism was still called before the exception
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
Loading…
x
Reference in New Issue
Block a user