mirror of
https://github.com/langgenius/dify.git
synced 2025-08-03 23:08:06 +00:00
381 lines
13 KiB
Python
381 lines
13 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from flask import Flask
|
|
from flask_login import LoginManager, UserMixin
|
|
|
|
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
|
from controllers.console.workspace.error import AccountNotInitializedError
|
|
from controllers.console.wraps import (
|
|
account_initialization_required,
|
|
cloud_edition_billing_rate_limit_check,
|
|
cloud_edition_billing_resource_check,
|
|
enterprise_license_required,
|
|
only_edition_cloud,
|
|
only_edition_enterprise,
|
|
only_edition_self_hosted,
|
|
setup_required,
|
|
)
|
|
from models.account import AccountStatus
|
|
from services.feature_service import LicenseStatus
|
|
|
|
|
|
class MockUser(UserMixin):
    """Minimal ``UserMixin`` subclass used as the logged-in user in these tests.

    Instances expose two plain attributes: ``id`` (the value passed to the
    constructor) and ``current_tenant_id`` (always the string ``"tenant123"``).
    """

    def __init__(self, user_id: str) -> None:
        # Every mock user is attached to the same fixed placeholder tenant.
        self.current_tenant_id = "tenant123"
        self.id = user_id

    def get_id(self) -> str:
        """Return the identifier exactly as it was supplied at construction."""
        return self.id
|
|
|
|
|
|
def create_app_with_login():
    """Build a Flask app with a ``LoginManager`` attached.

    The app gets a fixed test ``SECRET_KEY`` and a user loader that wraps
    whatever user id it receives in a ``MockUser``.

    Returns:
        The configured ``Flask`` application.
    """
    flask_app = Flask(__name__)
    flask_app.config.update(SECRET_KEY="test-secret-key")

    manager = LoginManager()
    manager.init_app(flask_app)

    def _load_mock_user(user_id: str):
        # Any id resolves to a user; no lookup is performed.
        return MockUser(user_id)

    # Register the loader by calling the registration hook directly instead of
    # using it as a decorator.
    manager.user_loader(_load_mock_user)

    return flask_app
|
|
|
|
|
|
class TestAccountInitialization:
    """Tests for the ``account_initialization_required`` decorator."""

    def test_should_allow_initialized_account(self):
        """A user whose status is ACTIVE reaches the wrapped view."""
        # Arrange
        active_user = MagicMock(status=AccountStatus.ACTIVE)

        @account_initialization_required
        def protected_view():
            return "success"

        user_patch = patch("controllers.console.wraps.current_user", active_user)

        # Act
        with user_patch:
            outcome = protected_view()

        # Assert
        assert outcome == "success"

    def test_should_reject_uninitialized_account(self):
        """A user whose status is UNINITIALIZED triggers AccountNotInitializedError."""
        # Arrange
        pending_user = MagicMock(status=AccountStatus.UNINITIALIZED)

        @account_initialization_required
        def protected_view():
            return "success"

        user_patch = patch("controllers.console.wraps.current_user", pending_user)

        # Act & Assert
        with user_patch, pytest.raises(AccountNotInitializedError):
            protected_view()
|
|
|
|
|
|
class TestEditionChecks:
    """Tests for the edition-gating decorators (cloud, enterprise, self-hosted)."""

    def test_only_edition_cloud_allows_cloud_edition(self):
        """``only_edition_cloud`` lets the view run when EDITION is "CLOUD"."""

        # Arrange
        @only_edition_cloud
        def cloud_view():
            return "cloud_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "CLOUD")

        # Act
        with edition_patch:
            outcome = cloud_view()

        # Assert
        assert outcome == "cloud_success"

    def test_only_edition_cloud_rejects_other_editions(self):
        """``only_edition_cloud`` raises an error with code 404 when EDITION is "SELF_HOSTED"."""
        # Arrange
        flask_app = Flask(__name__)

        @only_edition_cloud
        def cloud_view():
            return "cloud_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act & Assert
        with flask_app.test_request_context(), edition_patch:
            with pytest.raises(Exception) as exc_info:
                cloud_view()
            # Checked after the raises block so the assertion actually runs.
            assert exc_info.value.code == 404

    def test_only_edition_enterprise_allows_when_enabled(self):
        """``only_edition_enterprise`` lets the view run when ENTERPRISE_ENABLED is True."""

        # Arrange
        @only_edition_enterprise
        def enterprise_view():
            return "enterprise_success"

        enterprise_patch = patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True)

        # Act
        with enterprise_patch:
            outcome = enterprise_view()

        # Assert
        assert outcome == "enterprise_success"

    def test_only_edition_self_hosted_allows_self_hosted(self):
        """``only_edition_self_hosted`` lets the view run when EDITION is "SELF_HOSTED"."""

        # Arrange
        @only_edition_self_hosted
        def self_hosted_view():
            return "self_hosted_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act
        with edition_patch:
            outcome = self_hosted_view()

        # Assert
        assert outcome == "self_hosted_success"
|
|
|
|
|
|
class TestBillingResourceLimits:
    """Tests for the ``cloud_edition_billing_resource_check`` decorator."""

    @staticmethod
    def _billing_features(quota_attr: str, *, limit: int, size: int) -> MagicMock:
        """Return a mock features object with billing enabled and one quota filled in.

        ``quota_attr`` names the attribute on the features mock (for example
        ``"members"``) whose ``limit`` and ``size`` are set.
        """
        features = MagicMock()
        features.billing.enabled = True
        quota = getattr(features, quota_attr)
        quota.limit = limit
        quota.size = size
        return features

    def test_should_allow_when_under_resource_limit(self):
        """With members size 5 of limit 10 the wrapped view runs normally."""
        # Arrange
        features = self._billing_features("members", limit=10, size=5)

        @cloud_edition_billing_resource_check("members")
        def add_member():
            return "member_added"

        user_patch = patch("controllers.console.wraps.current_user")
        features_patch = patch("controllers.console.wraps.FeatureService.get_features", return_value=features)

        # Act
        with user_patch, features_patch:
            outcome = add_member()

        # Assert
        assert outcome == "member_added"

    def test_should_reject_when_over_resource_limit(self):
        """With members size equal to the limit the call fails with code 403."""
        # Arrange
        flask_app = create_app_with_login()
        features = self._billing_features("members", limit=10, size=10)

        @cloud_edition_billing_resource_check("members")
        def add_member():
            return "member_added"

        user_patch = patch("controllers.console.wraps.current_user", MockUser("test_user"))
        features_patch = patch("controllers.console.wraps.FeatureService.get_features", return_value=features)

        # Act & Assert
        with flask_app.test_request_context(), user_patch, features_patch:
            with pytest.raises(Exception) as exc_info:
                add_member()
            assert exc_info.value.code == 403
            assert "members has reached the limit" in str(exc_info.value.description)

    def test_should_check_source_for_documents_limit(self):
        """The documents check depends on the ``source`` query parameter.

        With the upload quota exhausted, ``source=datasets`` is rejected with
        code 403 while ``source=other`` still reaches the view.
        """
        # Arrange
        flask_app = create_app_with_login()
        features = self._billing_features("documents_upload_quota", limit=100, size=100)

        @cloud_edition_billing_resource_check("documents")
        def upload_document():
            return "document_uploaded"

        # Test 1: Should reject when source is datasets
        user_patch = patch("controllers.console.wraps.current_user", MockUser("test_user"))
        features_patch = patch("controllers.console.wraps.FeatureService.get_features", return_value=features)
        with flask_app.test_request_context("/?source=datasets"), user_patch, features_patch:
            with pytest.raises(Exception) as exc_info:
                upload_document()
            assert exc_info.value.code == 403

        # Test 2: Should allow when source is not datasets
        # Fresh patch objects (and a fresh user) are built for the second scenario.
        user_patch = patch("controllers.console.wraps.current_user", MockUser("test_user"))
        features_patch = patch("controllers.console.wraps.FeatureService.get_features", return_value=features)
        with flask_app.test_request_context("/?source=other"), user_patch, features_patch:
            outcome = upload_document()
            assert outcome == "document_uploaded"
|
|
|
|
|
|
class TestRateLimiting:
    """Tests for the ``cloud_edition_billing_rate_limit_check`` decorator."""

    @patch("controllers.console.wraps.redis_client")
    @patch("controllers.console.wraps.db")
    def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
        """A request count below the limit reaches the view and touches redis once each way."""
        # Arrange
        rate_limit = MagicMock(enabled=True, limit=10)
        mock_redis.zcard.return_value = 5  # 5 requests in window

        @cloud_edition_billing_rate_limit_check("knowledge")
        def knowledge_request():
            return "knowledge_success"

        user_patch = patch("controllers.console.wraps.current_user")
        rate_limit_patch = patch(
            "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=rate_limit
        )

        # Act
        with user_patch, rate_limit_patch:
            outcome = knowledge_request()

        # Assert
        assert outcome == "knowledge_success"
        mock_redis.zadd.assert_called_once()
        mock_redis.zremrangebyscore.assert_called_once()

    @patch("controllers.console.wraps.redis_client")
    @patch("controllers.console.wraps.db")
    def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
        """A request count above the limit fails with code 403 and writes one log row."""
        # Arrange
        flask_app = create_app_with_login()
        rate_limit = MagicMock(enabled=True, limit=10, subscription_plan="pro")
        mock_redis.zcard.return_value = 11  # Over limit

        # Dedicated session mock so add/commit calls can be asserted on directly.
        session = MagicMock()
        mock_db.session = session

        @cloud_edition_billing_rate_limit_check("knowledge")
        def knowledge_request():
            return "knowledge_success"

        user_patch = patch("controllers.console.wraps.current_user", MockUser("test_user"))
        rate_limit_patch = patch(
            "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=rate_limit
        )

        # Act & Assert
        with flask_app.test_request_context(), user_patch, rate_limit_patch:
            with pytest.raises(Exception) as exc_info:
                knowledge_request()

            # Verify error
            assert exc_info.value.code == 403
            assert "rate limit" in str(exc_info.value.description)

            # Verify rate limit log was created
            session.add.assert_called_once()
            session.commit.assert_called_once()
|
|
|
|
|
|
class TestSystemSetup:
    """Tests for the ``setup_required`` decorator."""

    @patch("controllers.console.wraps.db")
    def test_should_allow_when_setup_complete(self, mock_db):
        """When the mocked setup query returns a row the view runs normally."""
        # Arrange
        first_row = mock_db.session.query.return_value.first
        first_row.return_value = MagicMock()  # Setup exists

        @setup_required
        def admin_view():
            return "admin_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act
        with edition_patch:
            outcome = admin_view()

        # Assert
        assert outcome == "admin_success"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.wraps.os.environ.get")
    def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
        """No setup row plus a non-empty environment lookup raises NotInitValidateError."""
        # Arrange
        first_row = mock_db.session.query.return_value.first
        first_row.return_value = None  # No setup
        mock_environ_get.return_value = "some_password"

        @setup_required
        def admin_view():
            return "admin_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act & Assert
        with edition_patch, pytest.raises(NotInitValidateError):
            admin_view()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.wraps.os.environ.get")
    def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
        """No setup row plus an environment lookup returning None raises NotSetupError."""
        # Arrange
        first_row = mock_db.session.query.return_value.first
        first_row.return_value = None  # No setup
        mock_environ_get.return_value = None  # No INIT_PASSWORD

        @setup_required
        def admin_view():
            return "admin_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act & Assert
        with edition_patch, pytest.raises(NotSetupError):
            admin_view()
|
|
|
|
|
|
class TestEnterpriseLicense:
    """Tests for the ``enterprise_license_required`` decorator."""

    @staticmethod
    def _system_features(status) -> MagicMock:
        """Return a mock system-features object whose ``license.status`` is ``status``."""
        features = MagicMock()
        features.license.status = status
        return features

    def test_should_allow_with_valid_license(self):
        """An ACTIVE license status lets the wrapped view run."""
        # Arrange
        features = self._system_features(LicenseStatus.ACTIVE)

        @enterprise_license_required
        def enterprise_feature():
            return "enterprise_success"

        features_patch = patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features)

        # Act
        with features_patch:
            outcome = enterprise_feature()

        # Assert
        assert outcome == "enterprise_success"

    @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
    def test_should_reject_with_invalid_license(self, invalid_status):
        """INACTIVE, EXPIRED and LOST statuses each raise UnauthorizedAndForceLogout."""
        # Arrange
        features = self._system_features(invalid_status)

        @enterprise_license_required
        def enterprise_feature():
            return "enterprise_success"

        features_patch = patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features)

        # Act & Assert
        with features_patch:
            with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
                enterprise_feature()
            assert "license is invalid" in str(exc_info.value)
|