from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_login import LoginManager, UserMixin from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, enterprise_license_required, only_edition_cloud, only_edition_enterprise, only_edition_self_hosted, setup_required, ) from models.account import AccountStatus from services.feature_service import LicenseStatus class MockUser(UserMixin): """Simple User class for testing.""" def __init__(self, user_id: str): self.id = user_id self.current_tenant_id = "tenant123" def get_id(self) -> str: return self.id def create_app_with_login(): """Create a Flask app with LoginManager configured.""" app = Flask(__name__) app.config["SECRET_KEY"] = "test-secret-key" login_manager = LoginManager() login_manager.init_app(app) @login_manager.user_loader def load_user(user_id: str): return MockUser(user_id) return app class TestAccountInitialization: """Test account initialization decorator""" def test_should_allow_initialized_account(self): """Test that initialized accounts can access protected views""" # Arrange mock_user = MagicMock() mock_user.status = AccountStatus.ACTIVE @account_initialization_required def protected_view(): return "success" # Act with patch("controllers.console.wraps.current_user", mock_user): result = protected_view() # Assert assert result == "success" def test_should_reject_uninitialized_account(self): """Test that uninitialized accounts raise AccountNotInitializedError""" # Arrange mock_user = MagicMock() mock_user.status = AccountStatus.UNINITIALIZED @account_initialization_required def protected_view(): return "success" # Act & Assert with patch("controllers.console.wraps.current_user", mock_user): with pytest.raises(AccountNotInitializedError): protected_view() class TestEditionChecks: """Test edition-specific decorators""" def test_only_edition_cloud_allows_cloud_edition(self): """Test cloud edition decorator allows CLOUD edition""" # Arrange @only_edition_cloud def cloud_view(): return "cloud_success" # Act with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"): result = cloud_view() # Assert assert result == "cloud_success" def test_only_edition_cloud_rejects_other_editions(self): """Test cloud edition decorator rejects non-CLOUD editions""" # Arrange app = Flask(__name__) @only_edition_cloud def cloud_view(): return "cloud_success" # Act & Assert with app.test_request_context(): with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): with pytest.raises(Exception) as exc_info: cloud_view() assert exc_info.value.code == 404 def test_only_edition_enterprise_allows_when_enabled(self): """Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True""" # Arrange @only_edition_enterprise def enterprise_view(): return "enterprise_success" # Act with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True): result = enterprise_view() # Assert assert result == "enterprise_success" def test_only_edition_self_hosted_allows_self_hosted(self): """Test self-hosted edition decorator allows SELF_HOSTED edition""" # Arrange @only_edition_self_hosted def self_hosted_view(): return "self_hosted_success" # Act with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): result = self_hosted_view() # Assert assert result == "self_hosted_success" class TestBillingResourceLimits: """Test billing resource limit decorators""" def test_should_allow_when_under_resource_limit(self): """Test that requests are allowed when under resource limits""" # Arrange mock_features = MagicMock() mock_features.billing.enabled = True mock_features.members.limit = 10 mock_features.members.size = 5 @cloud_edition_billing_resource_check("members") def add_member(): return "member_added" # Act with patch("controllers.console.wraps.current_user"): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = add_member() # Assert assert result == "member_added" def test_should_reject_when_over_resource_limit(self): """Test that requests are rejected when over resource limits""" # Arrange app = create_app_with_login() mock_features = MagicMock() mock_features.billing.enabled = True mock_features.members.limit = 10 mock_features.members.size = 10 @cloud_edition_billing_resource_check("members") def add_member(): return "member_added" # Act & Assert with app.test_request_context(): with patch("controllers.console.wraps.current_user", MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: add_member() assert exc_info.value.code == 403 assert "members has reached the limit" in str(exc_info.value.description) def test_should_check_source_for_documents_limit(self): """Test document limit checks request source""" # Arrange app = create_app_with_login() mock_features = MagicMock() mock_features.billing.enabled = True mock_features.documents_upload_quota.limit = 100 mock_features.documents_upload_quota.size = 100 @cloud_edition_billing_resource_check("documents") def upload_document(): return "document_uploaded" # Test 1: Should reject when source is datasets with app.test_request_context("/?source=datasets"): with patch("controllers.console.wraps.current_user", MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: upload_document() assert exc_info.value.code == 403 # Test 2: Should allow when source is not datasets with app.test_request_context("/?source=other"): with patch("controllers.console.wraps.current_user", MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = upload_document() assert result == "document_uploaded" class TestRateLimiting: """Test rate limiting decorator""" @patch("controllers.console.wraps.redis_client") @patch("controllers.console.wraps.db") def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis): """Test that requests within rate limit are allowed""" # Arrange mock_rate_limit = MagicMock() mock_rate_limit.enabled = True mock_rate_limit.limit = 10 mock_redis.zcard.return_value = 5 # 5 requests in window @cloud_edition_billing_rate_limit_check("knowledge") def knowledge_request(): return "knowledge_success" # Act with patch("controllers.console.wraps.current_user"): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): result = knowledge_request() # Assert assert result == "knowledge_success" mock_redis.zadd.assert_called_once() mock_redis.zremrangebyscore.assert_called_once() @patch("controllers.console.wraps.redis_client") @patch("controllers.console.wraps.db") def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis): """Test that requests over rate limit are rejected and logged""" # Arrange app = create_app_with_login() mock_rate_limit = MagicMock() mock_rate_limit.enabled = True mock_rate_limit.limit = 10 mock_rate_limit.subscription_plan = "pro" mock_redis.zcard.return_value = 11 # Over limit mock_session = MagicMock() mock_db.session = mock_session @cloud_edition_billing_rate_limit_check("knowledge") def knowledge_request(): return "knowledge_success" # Act & Assert with app.test_request_context(): with patch("controllers.console.wraps.current_user", MockUser("test_user")): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): with pytest.raises(Exception) as exc_info: knowledge_request() # Verify error assert exc_info.value.code == 403 assert "rate limit" in str(exc_info.value.description) # Verify rate limit log was created mock_session.add.assert_called_once() mock_session.commit.assert_called_once() class TestSystemSetup: """Test system setup decorator""" @patch("controllers.console.wraps.db") def test_should_allow_when_setup_complete(self, mock_db): """Test that requests are allowed when setup is complete""" # Arrange mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists @setup_required def admin_view(): return "admin_success" # Act with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): result = admin_view() # Assert assert result == "admin_success" @patch("controllers.console.wraps.db") @patch("controllers.console.wraps.os.environ.get") def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" # Arrange mock_db.session.query.return_value.first.return_value = None # No setup mock_environ_get.return_value = "some_password" @setup_required def admin_view(): return "admin_success" # Act & Assert with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): with pytest.raises(NotInitValidateError): admin_view() @patch("controllers.console.wraps.db") @patch("controllers.console.wraps.os.environ.get") def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): """Test NotSetupError when no INIT_PASSWORD and setup not complete""" # Arrange mock_db.session.query.return_value.first.return_value = None # No setup mock_environ_get.return_value = None # No INIT_PASSWORD @setup_required def admin_view(): return "admin_success" # Act & Assert with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): with pytest.raises(NotSetupError): admin_view() class TestEnterpriseLicense: """Test enterprise license decorator""" def test_should_allow_with_valid_license(self): """Test that valid licenses allow access""" # Arrange mock_settings = MagicMock() mock_settings.license.status = LicenseStatus.ACTIVE @enterprise_license_required def enterprise_feature(): return "enterprise_success" # Act with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings): result = enterprise_feature() # Assert assert result == "enterprise_success" @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]) def test_should_reject_with_invalid_license(self, invalid_status): """Test that invalid licenses raise UnauthorizedAndForceLogout""" # Arrange mock_settings = MagicMock() mock_settings.license.status = invalid_status @enterprise_license_required def enterprise_feature(): return "enterprise_success" # Act & Assert with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings): with pytest.raises(UnauthorizedAndForceLogout) as exc_info: enterprise_feature() assert "license is invalid" in str(exc_info.value)