"""Unit tests for datahub.ingestion.source.aws.aws_common."""

import os
from datetime import datetime, timedelta, timezone
from http import HTTPStatus
from unittest.mock import Mock, patch
import requests
from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.source.aws.aws_common import (
AwsAssumeRoleConfig,
AwsEnvironment,
AwsServicePrincipal,
AwsSourceConfig,
assume_role,
detect_aws_environment,
get_current_identity,
get_instance_metadata_token,
is_running_on_ec2,
)
class TestAwsEnvironment:
    """Checks the string values carried by the AwsEnvironment enum."""

    def test_enum_values(self) -> None:
        """Each AwsEnvironment member must expose its expected string value."""
        # (member, expected value) pairs, checked in declaration order.
        expected_values = [
            (AwsEnvironment.EC2, "EC2"),
            (AwsEnvironment.ECS, "ECS"),
            (AwsEnvironment.EKS, "EKS"),
            (AwsEnvironment.LAMBDA, "LAMBDA"),
            (AwsEnvironment.APP_RUNNER, "APP_RUNNER"),
            (AwsEnvironment.BEANSTALK, "ELASTIC_BEANSTALK"),
            (AwsEnvironment.CLOUD_FORMATION, "CLOUD_FORMATION"),
            (AwsEnvironment.UNKNOWN, "UNKNOWN"),
        ]
        for member, expected in expected_values:
            assert member.value == expected
class TestAwsServicePrincipal:
    """Checks the service-principal strings carried by AwsServicePrincipal."""

    def test_enum_values(self) -> None:
        """Each AwsServicePrincipal member must map to its expected principal."""
        # (member, expected value) pairs, checked in the order listed.
        expected_values = [
            (AwsServicePrincipal.LAMBDA, "lambda.amazonaws.com"),
            (AwsServicePrincipal.EKS, "eks.amazonaws.com"),
            (AwsServicePrincipal.APP_RUNNER, "apprunner.amazonaws.com"),
            (AwsServicePrincipal.ECS, "ecs.amazonaws.com"),
            (
                AwsServicePrincipal.ELASTIC_BEANSTALK,
                "elasticbeanstalk.amazonaws.com",
            ),
            (AwsServicePrincipal.EC2, "ec2.amazonaws.com"),
        ]
        for member, expected in expected_values:
            assert member.value == expected
class TestAwsAssumeRoleConfig:
    """Covers construction and dict export of AwsAssumeRoleConfig."""

    # Shared literals so every test builds the config from the same values.
    ROLE_ARN = "arn:aws:iam::123456789012:role/TestRole"
    EXTERNAL_ID = "external-id-123"

    def test_init_required_fields(self) -> None:
        """Only RoleArn is supplied; ExternalId must come back as None."""
        cfg = AwsAssumeRoleConfig(RoleArn=self.ROLE_ARN)

        assert cfg.RoleArn == "arn:aws:iam::123456789012:role/TestRole"
        assert cfg.ExternalId is None

    def test_init_all_fields(self) -> None:
        """Both RoleArn and ExternalId are supplied and read back unchanged."""
        cfg = AwsAssumeRoleConfig(
            RoleArn=self.ROLE_ARN,
            ExternalId=self.EXTERNAL_ID,
        )

        assert cfg.RoleArn == "arn:aws:iam::123456789012:role/TestRole"
        assert cfg.ExternalId == "external-id-123"

    def test_dict_method(self) -> None:
        """dict() must expose both fields under their own names."""
        exported = AwsAssumeRoleConfig(
            RoleArn=self.ROLE_ARN,
            ExternalId=self.EXTERNAL_ID,
        ).dict()

        assert exported["RoleArn"] == "arn:aws:iam::123456789012:role/TestRole"
        assert exported["ExternalId"] == "external-id-123"
class TestMetadataFunctions:
    """Exercises the EC2 metadata helpers with the `requests` calls patched."""

    @patch("requests.put")
    def test_get_instance_metadata_token_success(self, mock_put: Mock) -> None:
        """An OK response yields the response text as the token."""
        mock_put.return_value = Mock(status_code=HTTPStatus.OK, text="test-token-123")

        assert get_instance_metadata_token() == "test-token-123"
        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )

    @patch("requests.put")
    def test_get_instance_metadata_token_failure(self, mock_put: Mock) -> None:
        """A RequestException raised by the PUT results in a None token."""
        mock_put.side_effect = requests.exceptions.RequestException("Network error")

        assert get_instance_metadata_token() is None

    @patch("requests.put")
    def test_get_instance_metadata_token_bad_status(self, mock_put: Mock) -> None:
        """A NOT_FOUND response results in a None token."""
        mock_put.return_value = Mock(status_code=HTTPStatus.NOT_FOUND)

        assert get_instance_metadata_token() is None

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    @patch("requests.get")
    def test_is_running_on_ec2_success(self, mock_get: Mock, mock_token: Mock) -> None:
        """With a token and an OK instance-id response, detection is True."""
        mock_token.return_value = "test-token"
        mock_get.return_value = Mock(status_code=HTTPStatus.OK)

        assert is_running_on_ec2() is True
        mock_get.assert_called_once_with(
            "http://169.254.169.254/latest/meta-data/instance-id",
            headers={"X-aws-ec2-metadata-token": "test-token"},
            timeout=1,
        )

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    def test_is_running_on_ec2_no_token(self, mock_token: Mock) -> None:
        """Without a metadata token, detection is False."""
        mock_token.return_value = None

        assert is_running_on_ec2() is False

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    @patch("requests.get")
    def test_is_running_on_ec2_request_failure(
        self, mock_get: Mock, mock_token: Mock
    ) -> None:
        """A RequestException raised by the GET results in False."""
        mock_token.return_value = "test-token"
        mock_get.side_effect = requests.exceptions.RequestException("Network error")

        assert is_running_on_ec2() is False
class TestDetectAwsEnvironment:
    """Drives detect_aws_environment through each environment-variable case."""

    @staticmethod
    def _detect_with_env(env: dict) -> "AwsEnvironment":
        """Run detect_aws_environment with os.environ replaced by exactly `env`."""
        # clear=True removes every pre-existing variable for the duration of
        # the call, so only the keys in `env` are visible to the detector.
        with patch.dict(os.environ, env, clear=True):
            return detect_aws_environment()

    def test_detect_lambda_environment(self) -> None:
        """AWS_LAMBDA_FUNCTION_NAME alone is reported as LAMBDA."""
        detected = self._detect_with_env({"AWS_LAMBDA_FUNCTION_NAME": "test-function"})
        assert detected == AwsEnvironment.LAMBDA

    def test_detect_cloud_formation_environment(self) -> None:
        """A Lambda whose AWS_EXECUTION_ENV names CloudFormation is CLOUD_FORMATION."""
        detected = self._detect_with_env(
            {
                "AWS_LAMBDA_FUNCTION_NAME": "test-function",
                "AWS_EXECUTION_ENV": "CloudFormation-custom-resource",
            }
        )
        assert detected == AwsEnvironment.CLOUD_FORMATION

    def test_detect_eks_environment(self) -> None:
        """A web-identity token file plus role ARN is reported as EKS."""
        detected = self._detect_with_env(
            {
                "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/eks.amazonaws.com/serviceaccount/token",
                "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/eks-role",
            }
        )
        assert detected == AwsEnvironment.EKS

    def test_detect_app_runner_environment(self) -> None:
        """AWS_APP_RUNNER_SERVICE_ID is reported as APP_RUNNER."""
        detected = self._detect_with_env({"AWS_APP_RUNNER_SERVICE_ID": "service-123"})
        assert detected == AwsEnvironment.APP_RUNNER

    def test_detect_ecs_environment_v4(self) -> None:
        """ECS_CONTAINER_METADATA_URI_V4 is reported as ECS."""
        detected = self._detect_with_env(
            {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4/metadata"}
        )
        assert detected == AwsEnvironment.ECS

    def test_detect_ecs_environment_v3(self) -> None:
        """ECS_CONTAINER_METADATA_URI is reported as ECS."""
        detected = self._detect_with_env(
            {"ECS_CONTAINER_METADATA_URI": "http://169.254.170.2/v3/metadata"}
        )
        assert detected == AwsEnvironment.ECS

    def test_detect_beanstalk_environment(self) -> None:
        """ELASTIC_BEANSTALK_ENVIRONMENT_NAME is reported as BEANSTALK."""
        detected = self._detect_with_env(
            {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "test-env"}
        )
        assert detected == AwsEnvironment.BEANSTALK

    @patch("datahub.ingestion.source.aws.aws_common.is_running_on_ec2")
    def test_detect_ec2_environment(self, mock_is_ec2: Mock) -> None:
        """An empty environment with a positive EC2 probe is reported as EC2."""
        mock_is_ec2.return_value = True
        assert self._detect_with_env({}) == AwsEnvironment.EC2

    @patch("datahub.ingestion.source.aws.aws_common.is_running_on_ec2")
    def test_detect_unknown_environment(self, mock_is_ec2: Mock) -> None:
        """An empty environment with a negative EC2 probe is reported as UNKNOWN."""
        mock_is_ec2.return_value = False
        assert self._detect_with_env({}) == AwsEnvironment.UNKNOWN
class TestGetCurrentIdentity:
    """Checks the (role ARN, source) pair from get_current_identity per environment."""

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("datahub.ingestion.source.aws.aws_common.get_lambda_role_arn")
    def test_get_lambda_identity(
        self, mock_lambda_role: Mock, mock_detect: Mock
    ) -> None:
        """In LAMBDA, the ARN comes from get_lambda_role_arn."""
        mock_lambda_role.return_value = "arn:aws:iam::123456789012:role/lambda-role"
        mock_detect.return_value = AwsEnvironment.LAMBDA

        arn, principal = get_current_identity()

        assert arn == "arn:aws:iam::123456789012:role/lambda-role"
        assert principal == AwsServicePrincipal.LAMBDA.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    def test_get_eks_identity(self, mock_detect: Mock) -> None:
        """In EKS, the ARN is the value of the AWS_ROLE_ARN variable."""
        mock_detect.return_value = AwsEnvironment.EKS
        eks_env = {"AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/eks-role"}

        with patch.dict(os.environ, eks_env):
            arn, principal = get_current_identity()

        assert arn == "arn:aws:iam::123456789012:role/eks-role"
        assert principal == AwsServicePrincipal.EKS.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("boto3.client")
    def test_get_app_runner_identity(
        self, mock_boto_client: Mock, mock_detect: Mock
    ) -> None:
        """In APP_RUNNER, the ARN is the one the mocked caller-identity call returns."""
        mock_detect.return_value = AwsEnvironment.APP_RUNNER
        sts_client = Mock()
        sts_client.get_caller_identity.return_value = {
            "Arn": "arn:aws:sts::123456789012:assumed-role/app-runner-role"
        }
        mock_boto_client.return_value = sts_client

        arn, principal = get_current_identity()

        assert arn == "arn:aws:sts::123456789012:assumed-role/app-runner-role"
        assert principal == AwsServicePrincipal.APP_RUNNER.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("requests.get")
    def test_get_ecs_identity(self, mock_get: Mock, mock_detect: Mock) -> None:
        """In ECS, the ARN is the TaskARN field of the mocked metadata response."""
        mock_detect.return_value = AwsEnvironment.ECS
        metadata_response = Mock(status_code=HTTPStatus.OK)
        metadata_response.json.return_value = {
            "TaskARN": "arn:aws:ecs:us-east-1:123456789012:task/task-id"
        }
        mock_get.return_value = metadata_response
        ecs_env = {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4/metadata"}

        with patch.dict(os.environ, ecs_env):
            arn, principal = get_current_identity()

        assert arn == "arn:aws:ecs:us-east-1:123456789012:task/task-id"
        assert principal == AwsServicePrincipal.ECS.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    def test_get_unknown_identity(self, mock_detect: Mock) -> None:
        """In UNKNOWN, both elements of the pair are None."""
        mock_detect.return_value = AwsEnvironment.UNKNOWN

        arn, principal = get_current_identity()

        assert arn is None
        assert principal is None
class TestAssumeRole:
    """Tests for assume_role, with the boto3 STS client replaced by a mock."""

    @staticmethod
    def _stub_sts_client(mock_boto_client: Mock) -> Mock:
        """Make the patched boto3.client return an STS mock with canned credentials.

        Returns the STS mock so tests can inspect how assume_role was invoked.
        """
        sts_client = Mock()
        sts_client.assume_role.return_value = {
            "Credentials": {
                "AccessKeyId": "AKIA123456789",
                "SecretAccessKey": "secret123",
                "SessionToken": "token123",
                "Expiration": datetime.now(timezone.utc) + timedelta(hours=1),
            }
        }
        mock_boto_client.return_value = sts_client
        return sts_client

    @patch("boto3.client")
    def test_assume_role_success(self, mock_boto_client: Mock) -> None:
        """The returned credentials and the STS call arguments are as expected."""
        mock_sts = self._stub_sts_client(mock_boto_client)
        role_config = AwsAssumeRoleConfig(
            RoleArn="arn:aws:iam::123456789012:role/test-role"
        )

        result = assume_role(role_config, "us-east-1")

        assert result["AccessKeyId"] == "AKIA123456789"
        assert result["SecretAccessKey"] == "secret123"
        assert result["SessionToken"] == "token123"
        mock_sts.assume_role.assert_called_once()
        _, call_kwargs = mock_sts.assume_role.call_args
        assert call_kwargs["RoleArn"] == "arn:aws:iam::123456789012:role/test-role"
        assert call_kwargs["RoleSessionName"] == "DatahubIngestionSource"

    @patch("boto3.client")
    def test_assume_role_with_external_id(self, mock_boto_client: Mock) -> None:
        """A configured ExternalId is forwarded to the STS assume_role call."""
        mock_sts = self._stub_sts_client(mock_boto_client)
        role_config = AwsAssumeRoleConfig(
            RoleArn="arn:aws:iam::123456789012:role/test-role",
            ExternalId="external-123",
        )

        assume_role(role_config, "us-east-1")

        # Assert the call happened before unpacking call_args: if it never
        # happened call_args is None, and unpacking it would raise a TypeError
        # instead of a readable assertion failure.
        mock_sts.assume_role.assert_called_once()
        _, call_kwargs = mock_sts.assume_role.call_args
        assert call_kwargs["ExternalId"] == "external-123"

    @patch("boto3.client")
    def test_assume_role_with_existing_credentials(
        self, mock_boto_client: Mock
    ) -> None:
        """Existing credentials are passed through when building the STS client."""
        self._stub_sts_client(mock_boto_client)
        role_config = AwsAssumeRoleConfig(
            RoleArn="arn:aws:iam::123456789012:role/test-role"
        )
        existing_creds = {
            "AccessKeyId": "EXISTING123",
            "SecretAccessKey": "existingsecret",
            "SessionToken": "existingtoken",
        }

        assume_role(role_config, "us-east-1", existing_creds)

        mock_boto_client.assert_called_once_with(
            "sts",
            region_name="us-east-1",
            aws_access_key_id="EXISTING123",
            aws_secret_access_key="existingsecret",
            aws_session_token="existingtoken",
        )
class TestAwsSourceConfig:
    """Covers defaults, custom patterns and inherited members of AwsSourceConfig."""

    def test_init_with_defaults(self) -> None:
        """A config built with no arguments has None connection fields and allow-all patterns."""
        cfg = AwsSourceConfig()
        allow_everything = AllowDenyPattern.allow_all()

        # Connection-level fields are unset by default.
        assert cfg.aws_access_key_id is None
        assert cfg.aws_region is None
        # Filter patterns default to allow-all.
        assert cfg.database_pattern == allow_everything
        assert cfg.table_pattern == allow_everything

    def test_init_with_custom_patterns(self) -> None:
        """Explicit region and patterns are stored as given."""
        databases = AllowDenyPattern(allow=["test_*"], deny=["temp_*"])
        tables = AllowDenyPattern(allow=["prod_*"])

        cfg = AwsSourceConfig(
            aws_region="us-east-1",
            database_pattern=databases,
            table_pattern=tables,
        )

        assert cfg.aws_region == "us-east-1"
        assert cfg.database_pattern == databases
        assert cfg.table_pattern == tables

    def test_inheritance_from_aws_connection_config(self) -> None:
        """The config exposes the client/session getters and the credential fields."""
        cfg = AwsSourceConfig(
            aws_access_key_id="AKIA123456789", aws_secret_access_key="secret123"
        )

        # Client/session factory methods must be present on the instance.
        for method_name in ("get_session", "get_s3_client", "get_glue_client"):
            assert hasattr(cfg, method_name)
        # Credential fields passed to the constructor are readable back.
        assert cfg.aws_access_key_id == "AKIA123456789"
        assert cfg.aws_secret_access_key == "secret123"