""" Tests for edge cases and error paths in the masking framework to improve coverage. """ import logging import threading import pytest from datahub.masking.bootstrap import ( get_bootstrap_error, initialize_secret_masking, is_bootstrapped, shutdown_secret_masking, ) from datahub.masking.logging_utils import ( get_masking_safe_logger, reset_masking_safe_loggers, ) from datahub.masking.masking_filter import SecretMaskingFilter from datahub.masking.secret_registry import SecretRegistry def test_imports_from_init(): """Test that all exports from __init__.py are accessible.""" # Import entire module to ensure __init__.py is executed import datahub.masking as masking_module # Verify all exports exist assert hasattr(masking_module, "SecretMaskingFilter") assert hasattr(masking_module, "StreamMaskingWrapper") assert hasattr(masking_module, "install_masking_filter") assert hasattr(masking_module, "uninstall_masking_filter") assert hasattr(masking_module, "SecretRegistry") assert hasattr(masking_module, "is_masking_enabled") assert hasattr(masking_module, "initialize_secret_masking") assert hasattr(masking_module, "get_masking_safe_logger") class TestMaskingFilterEdgeCases: """Test edge cases in masking filter.""" def setup_method(self): shutdown_secret_masking() SecretRegistry.reset_instance() def teardown_method(self): shutdown_secret_masking() SecretRegistry.reset_instance() def test_pattern_rebuild_with_concurrent_modifications(self): """Test pattern rebuild when secrets are modified during rebuild.""" registry = SecretRegistry.get_instance() masking_filter = SecretMaskingFilter(registry) # Add many secrets for i in range(100): registry.register_secret(f"SECRET_{i}", f"value_{i}") # Force pattern rebuild masking_filter._check_and_rebuild_pattern() # Verify pattern was rebuilt assert masking_filter._last_version > 0 def test_masking_with_very_long_message(self): """Test masking with messages exceeding max_message_size.""" registry = SecretRegistry.get_instance() registry.register_secret("PASSWORD", "secret123") masking_filter = SecretMaskingFilter(registry, max_message_size=100) # Create a very long message long_message = "x" * 200 + "secret123" + "y" * 200 record = logging.LogRecord( name="test", level=logging.INFO, pathname="", lineno=0, msg=long_message, args=(), exc_info=None, ) masking_filter.filter(record) # Message should be truncated and masked masked = record.getMessage() assert "secret123" not in masked assert len(masked) <= 100 + 50 # max_message_size + some buffer for redaction def test_masking_with_formatted_args(self): """Test masking with % formatting args.""" registry = SecretRegistry.get_instance() registry.register_secret("PASSWORD", "secret123") masking_filter = SecretMaskingFilter(registry) # Test with tuple args record = logging.LogRecord( name="test", level=logging.INFO, pathname="", lineno=0, msg="Password: %s, User: %s", args=("secret123", "admin"), exc_info=None, ) masking_filter.filter(record) # Both msg and args should be processed assert "secret123" not in str(record.args) assert "***REDACTED:PASSWORD***" in record.getMessage() class TestBootstrapEdgeCases: """Test edge cases in bootstrap.""" def setup_method(self): shutdown_secret_masking() def teardown_method(self): shutdown_secret_masking() def test_double_initialization(self): """Test that double initialization is handled gracefully.""" initialize_secret_masking() assert is_bootstrapped() # Second initialization should be no-op initialize_secret_masking() assert is_bootstrapped() def test_force_reinitialization(self): """Test force re-initialization.""" initialize_secret_masking() assert is_bootstrapped() # Force re-init initialize_secret_masking(force=True) assert is_bootstrapped() def test_bootstrap_error_cleared_on_success(self): """Test that bootstrap error is cleared after successful init.""" initialize_secret_masking() assert get_bootstrap_error() is None class TestLoggingUtils: """Test logging utilities.""" def test_masking_safe_logger_multiple_calls(self): """Test that getting the same logger multiple times doesn't add duplicate handlers.""" logger1 = get_masking_safe_logger("test.logger") handler_count_1 = len(logger1.handlers) logger2 = get_masking_safe_logger("test.logger") handler_count_2 = len(logger2.handlers) # Should be the same logger assert logger1 is logger2 # Should not have duplicate handlers assert handler_count_1 == handler_count_2 def test_reset_masking_safe_loggers(self): """Test resetting masking-safe loggers.""" # Create a masking-safe logger logger = get_masking_safe_logger("datahub.masking.test") assert not logger.propagate assert len(logger.handlers) > 0 # Reset reset_masking_safe_loggers() # Logger should be reset assert logger.propagate assert len(logger.handlers) == 0 class TestSecretRegistryEdgeCases: """Test edge cases in secret registry.""" def setup_method(self): SecretRegistry.reset_instance() def teardown_method(self): SecretRegistry.reset_instance() def test_register_empty_secret_name(self): """Test registering secret with empty name.""" registry = SecretRegistry.get_instance() # Empty name should be handled gracefully registry.register_secret("", "longsecretvalue123") # Empty names are actually accepted by the registry assert registry.has_secret("") assert registry.get_count() == 1 def test_register_very_short_secret(self): """Test registering very short secret values.""" registry = SecretRegistry.get_instance() # Very short secrets (< 8 chars) are still registered registry.register_secret("SHORT", "abc") # Short secrets are registered (no minimum length enforced) assert registry.has_secret("SHORT") assert registry.get_count() == 1 def test_concurrent_registration(self): """Test concurrent secret registration.""" registry = SecretRegistry.get_instance() errors = [] def register_secrets(start_idx: int) -> None: try: for i in range(start_idx, start_idx + 100): registry.register_secret(f"SECRET_{i}", f"secret_value_{i}") except Exception as e: errors.append(e) # Start multiple threads registering secrets concurrently threads = [] for i in range(5): t = threading.Thread(target=register_secrets, args=(i * 100,)) threads.append(t) t.start() # Wait for all threads for t in threads: t.join() # Should not have any errors assert len(errors) == 0 # All secrets should be registered assert registry.get_count() == 500 if __name__ == "__main__": pytest.main([__file__, "-v"])