from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock, patch import pytest from datahub.ingestion.source.aws.aws_common import ( AwsConnectionConfig, RDSIAMTokenManager, generate_rds_iam_token, ) class TestGenerateRDSIAMToken: def test_generate_token_success(self): mock_client = MagicMock() full_token = "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser&X-Amz-Algorithm=AWS4-HMAC-SHA256" mock_client.generate_db_auth_token.return_value = full_token mock_session = MagicMock() mock_session.client.return_value = mock_client aws_config = AwsConnectionConfig(aws_region="us-west-2") with patch.object( AwsConnectionConfig, "get_session", return_value=mock_session ): token = generate_rds_iam_token( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) assert token == full_token mock_client.generate_db_auth_token.assert_called_with( DBHostname="test.rds.amazonaws.com", Port=5432, DBUsername="testuser" ) def test_generate_token_no_credentials_error(self): from botocore.exceptions import NoCredentialsError mock_session = MagicMock() mock_session.client.side_effect = NoCredentialsError() aws_config = AwsConnectionConfig(aws_region="us-west-2") with ( patch.object(AwsConnectionConfig, "get_session", return_value=mock_session), pytest.raises(ValueError, match="AWS credentials not found"), ): generate_rds_iam_token( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) def test_generate_token_client_error(self): from botocore.exceptions import ClientError mock_client = MagicMock() error_response = { "Error": {"Code": "InvalidParameterValue"}, "ResponseMetadata": {"HTTPStatusCode": 400}, } mock_client.generate_db_auth_token.side_effect = ClientError( error_response, # type: ignore "generate_db_auth_token", ) mock_session = MagicMock() mock_session.client.return_value = mock_client aws_config = AwsConnectionConfig(aws_region="us-west-2") with ( patch.object(AwsConnectionConfig, "get_session", return_value=mock_session), pytest.raises(ValueError, match="Failed to generate RDS IAM token"), ): generate_rds_iam_token( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) def test_generate_token_returns_full_url(self): """Test that full presigned URL is returned from boto3.""" mock_client = MagicMock() full_token = ( "database-1.cluster-xxx.us-west-2.rds.amazonaws.com:3306/" "?Action=connect&DBUser=iam-user&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIOSFODNN7EXAMPLE" ) mock_client.generate_db_auth_token.return_value = full_token mock_session = MagicMock() mock_session.client.return_value = mock_client aws_config = AwsConnectionConfig(aws_region="us-west-2") with patch.object( AwsConnectionConfig, "get_session", return_value=mock_session ): token = generate_rds_iam_token( endpoint="database-1.cluster-xxx.us-west-2.rds.amazonaws.com", username="iam-user", port=3306, aws_config=aws_config, ) assert token == full_token assert "database-1.cluster-xxx.us-west-2.rds.amazonaws.com" in token assert "3306" in token assert "DBUser=iam-user" in token assert "X-Amz-Algorithm" in token class TestRDSIAMTokenManager: def test_init(self): aws_config = AwsConnectionConfig(aws_region="us-west-2") manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, refresh_threshold_minutes=5, ) assert manager.refresh_threshold == timedelta(minutes=5) assert manager.endpoint == "test.rds.amazonaws.com" assert manager.username == "testuser" assert manager.port == 5432 def test_needs_refresh_no_token(self): aws_config = AwsConnectionConfig(aws_region="us-west-2") manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) assert manager._needs_refresh() is True def test_needs_refresh_token_expired(self): aws_config = AwsConnectionConfig(aws_region="us-west-2") manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) manager._current_token = "old-token" manager._token_expires_at = datetime.now(timezone.utc) - timedelta(minutes=1) assert manager._needs_refresh() is True def test_needs_refresh_token_valid(self): aws_config = AwsConnectionConfig(aws_region="us-west-2") manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) manager._current_token = "valid-token" manager._token_expires_at = datetime.now(timezone.utc) + timedelta(minutes=12) assert manager._needs_refresh() is False def test_get_token_refresh_needed(self): aws_config = AwsConnectionConfig(aws_region="us-west-2") # Token with X-Amz-Date and X-Amz-Expires for parsing full_token = ( "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser" "&X-Amz-Date=20250101T120000Z&X-Amz-Expires=900" ) with patch( "datahub.ingestion.source.aws.aws_common.generate_rds_iam_token" ) as mock_gen: mock_gen.return_value = full_token manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) token = manager.get_token() assert token == full_token assert manager._current_token == full_token assert manager._token_expires_at is not None mock_gen.assert_called_once() def test_parse_token_expiry(self): """Test parsing X-Amz-Date and X-Amz-Expires from token URL.""" aws_config = AwsConnectionConfig(aws_region="us-west-2") manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) # Token issued at 2025-01-01 12:00:00 UTC, expires in 900 seconds (15 minutes) token = ( "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser" "&X-Amz-Date=20250101T120000Z&X-Amz-Expires=900" ) expiry = manager._parse_token_expiry(token) expected_expiry = datetime(2025, 1, 1, 12, 15, 0, tzinfo=timezone.utc) assert expiry == expected_expiry def test_parse_token_expiry_missing_date(self): """Test error handling when X-Amz-Date is missing.""" aws_config = AwsConnectionConfig(aws_region="us-west-2") manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) token = "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser&X-Amz-Expires=900" with pytest.raises(ValueError, match="Missing X-Amz-Date"): manager._parse_token_expiry(token) def test_get_token_automatically_refreshes_expired_token(self): """Test that get_token() automatically refreshes when token is expired.""" aws_config = AwsConnectionConfig(aws_region="us-west-2") old_token = ( "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser" "&X-Amz-Date=20250101T120000Z&X-Amz-Expires=900" ) new_token = ( "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser" "&X-Amz-Date=20250101T130000Z&X-Amz-Expires=900" ) with patch( "datahub.ingestion.source.aws.aws_common.generate_rds_iam_token" ) as mock_gen: mock_gen.return_value = new_token manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) # Set an expired token manually manager._current_token = old_token manager._token_expires_at = datetime.now(timezone.utc) - timedelta( minutes=1 ) # Call get_token() - should automatically refresh token = manager.get_token() # Verify new token was generated assert token == new_token assert manager._current_token == new_token mock_gen.assert_called_once_with( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) def test_get_token_reuses_valid_token(self): """Test that get_token() reuses token when still valid.""" aws_config = AwsConnectionConfig(aws_region="us-west-2") valid_token = ( "test.rds.amazonaws.com:5432/?Action=connect&DBUser=testuser" "&X-Amz-Date=20250101T120000Z&X-Amz-Expires=900" ) with patch( "datahub.ingestion.source.aws.aws_common.generate_rds_iam_token" ) as mock_gen: manager = RDSIAMTokenManager( endpoint="test.rds.amazonaws.com", username="testuser", port=5432, aws_config=aws_config, ) # Set a valid token that won't expire soon manager._current_token = valid_token manager._token_expires_at = datetime.now(timezone.utc) + timedelta( minutes=20 ) # Call get_token() - should reuse existing token token = manager.get_token() # Verify existing token was returned without generating new one assert token == valid_token assert manager._current_token == valid_token mock_gen.assert_not_called()