mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-05 21:29:48 +00:00
471 lines
17 KiB
Python
471 lines
17 KiB
Python
import os
|
|
from datetime import datetime, timedelta, timezone
|
|
from http import HTTPStatus
|
|
from unittest.mock import Mock, patch
|
|
|
|
import requests
|
|
|
|
from datahub.configuration.common import AllowDenyPattern
|
|
from datahub.ingestion.source.aws.aws_common import (
|
|
AwsAssumeRoleConfig,
|
|
AwsEnvironment,
|
|
AwsServicePrincipal,
|
|
AwsSourceConfig,
|
|
assume_role,
|
|
detect_aws_environment,
|
|
get_current_identity,
|
|
get_instance_metadata_token,
|
|
is_running_on_ec2,
|
|
)
|
|
|
|
|
|
class TestAwsEnvironment:
    """Checks the string values carried by the AwsEnvironment enum."""

    def test_enum_values(self) -> None:
        """Each AwsEnvironment member exposes its expected ``.value``."""
        # Pairs (rather than a dict) so that every member is checked on its
        # own, in declaration order.
        expected_pairs = (
            (AwsEnvironment.EC2, "EC2"),
            (AwsEnvironment.ECS, "ECS"),
            (AwsEnvironment.EKS, "EKS"),
            (AwsEnvironment.LAMBDA, "LAMBDA"),
            (AwsEnvironment.APP_RUNNER, "APP_RUNNER"),
            (AwsEnvironment.BEANSTALK, "ELASTIC_BEANSTALK"),
            (AwsEnvironment.CLOUD_FORMATION, "CLOUD_FORMATION"),
            (AwsEnvironment.UNKNOWN, "UNKNOWN"),
        )
        for member, expected in expected_pairs:
            assert member.value == expected
|
|
|
|
|
|
class TestAwsServicePrincipal:
    """Checks the string values carried by the AwsServicePrincipal enum."""

    def test_enum_values(self) -> None:
        """Each AwsServicePrincipal member exposes its expected ``.value``."""
        expected_pairs = (
            (AwsServicePrincipal.LAMBDA, "lambda.amazonaws.com"),
            (AwsServicePrincipal.EKS, "eks.amazonaws.com"),
            (AwsServicePrincipal.APP_RUNNER, "apprunner.amazonaws.com"),
            (AwsServicePrincipal.ECS, "ecs.amazonaws.com"),
            (
                AwsServicePrincipal.ELASTIC_BEANSTALK,
                "elasticbeanstalk.amazonaws.com",
            ),
            (AwsServicePrincipal.EC2, "ec2.amazonaws.com"),
        )
        for member, expected in expected_pairs:
            assert member.value == expected
|
|
|
|
|
|
class TestAwsAssumeRoleConfig:
    """Covers construction and dict export of AwsAssumeRoleConfig."""

    def test_init_required_fields(self) -> None:
        """Only RoleArn supplied: it is stored and ExternalId is None."""
        arn = "arn:aws:iam::123456789012:role/TestRole"

        role_cfg = AwsAssumeRoleConfig(RoleArn=arn)

        assert role_cfg.RoleArn == arn
        assert role_cfg.ExternalId is None

    def test_init_all_fields(self) -> None:
        """Both RoleArn and ExternalId supplied: both are stored as given."""
        arn = "arn:aws:iam::123456789012:role/TestRole"
        external_id = "external-id-123"

        role_cfg = AwsAssumeRoleConfig(RoleArn=arn, ExternalId=external_id)

        assert role_cfg.RoleArn == arn
        assert role_cfg.ExternalId == external_id

    def test_dict_method(self) -> None:
        """dict() exposes RoleArn and ExternalId under those same keys."""
        arn = "arn:aws:iam::123456789012:role/TestRole"
        external_id = "external-id-123"

        exported = AwsAssumeRoleConfig(RoleArn=arn, ExternalId=external_id).dict()

        assert exported["RoleArn"] == arn
        assert exported["ExternalId"] == external_id
|
|
|
|
|
|
class TestMetadataFunctions:
    """Exercises the instance-metadata helpers with HTTP calls mocked out."""

    @patch("requests.put")
    def test_get_instance_metadata_token_success(self, mock_put: Mock) -> None:
        """An OK response makes the helper return the response text."""
        mock_put.return_value = Mock(status_code=HTTPStatus.OK, text="test-token-123")

        fetched_token = get_instance_metadata_token()

        assert fetched_token == "test-token-123"
        # Exactly one PUT, to the token endpoint, with the TTL header.
        ttl_header = {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}
        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers=ttl_header,
            timeout=1,
        )

    @patch("requests.put")
    def test_get_instance_metadata_token_failure(self, mock_put: Mock) -> None:
        """A RequestException from the PUT results in a None token."""
        mock_put.side_effect = requests.exceptions.RequestException("Network error")

        assert get_instance_metadata_token() is None

    @patch("requests.put")
    def test_get_instance_metadata_token_bad_status(self, mock_put: Mock) -> None:
        """A NOT_FOUND response results in a None token."""
        mock_put.return_value = Mock(status_code=HTTPStatus.NOT_FOUND)

        assert get_instance_metadata_token() is None

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    @patch("requests.get")
    def test_is_running_on_ec2_success(self, mock_get: Mock, mock_token: Mock) -> None:
        """With a token and an OK metadata response, detection is True."""
        mock_token.return_value = "test-token"
        mock_get.return_value = Mock(status_code=HTTPStatus.OK)

        on_ec2 = is_running_on_ec2()

        assert on_ec2 is True
        # The token from the (mocked) token helper must be forwarded as a header.
        token_header = {"X-aws-ec2-metadata-token": "test-token"}
        mock_get.assert_called_once_with(
            "http://169.254.169.254/latest/meta-data/instance-id",
            headers=token_header,
            timeout=1,
        )

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    def test_is_running_on_ec2_no_token(self, mock_token: Mock) -> None:
        """When the token helper yields None, detection is False."""
        mock_token.return_value = None

        on_ec2 = is_running_on_ec2()

        assert on_ec2 is False

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    @patch("requests.get")
    def test_is_running_on_ec2_request_failure(
        self, mock_get: Mock, mock_token: Mock
    ) -> None:
        """A RequestException from the metadata GET makes detection False."""
        mock_token.return_value = "test-token"
        mock_get.side_effect = requests.exceptions.RequestException("Network error")

        on_ec2 = is_running_on_ec2()

        assert on_ec2 is False
|
|
|
|
|
|
class TestDetectAwsEnvironment:
    """Drives detect_aws_environment with controlled environment variables.

    Every test replaces ``os.environ`` wholesale (``clear=True``) so that
    variables from the host running the tests cannot influence the result.
    """

    def test_detect_lambda_environment(self) -> None:
        """AWS_LAMBDA_FUNCTION_NAME alone is reported as LAMBDA."""
        env = {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.LAMBDA

    def test_detect_cloud_formation_environment(self) -> None:
        """A Lambda name plus a CloudFormation execution env is CLOUD_FORMATION."""
        env = {
            "AWS_LAMBDA_FUNCTION_NAME": "test-function",
            "AWS_EXECUTION_ENV": "CloudFormation-custom-resource",
        }
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION

    def test_detect_eks_environment(self) -> None:
        """A web-identity token file plus a role ARN is reported as EKS."""
        env = {
            "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/eks.amazonaws.com/serviceaccount/token",
            "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/eks-role",
        }
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.EKS

    def test_detect_app_runner_environment(self) -> None:
        """AWS_APP_RUNNER_SERVICE_ID is reported as APP_RUNNER."""
        env = {"AWS_APP_RUNNER_SERVICE_ID": "service-123"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.APP_RUNNER

    def test_detect_ecs_environment_v4(self) -> None:
        """ECS_CONTAINER_METADATA_URI_V4 is reported as ECS."""
        env = {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4/metadata"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.ECS

    def test_detect_ecs_environment_v3(self) -> None:
        """ECS_CONTAINER_METADATA_URI (v3 variable) is reported as ECS."""
        env = {"ECS_CONTAINER_METADATA_URI": "http://169.254.170.2/v3/metadata"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.ECS

    def test_detect_beanstalk_environment(self) -> None:
        """ELASTIC_BEANSTALK_ENVIRONMENT_NAME is reported as BEANSTALK."""
        env = {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "test-env"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.BEANSTALK

    @patch("datahub.ingestion.source.aws.aws_common.is_running_on_ec2")
    def test_detect_ec2_environment(self, mock_is_ec2: Mock) -> None:
        """Empty environment with the EC2 probe returning True gives EC2."""
        mock_is_ec2.return_value = True

        with patch.dict(os.environ, {}, clear=True):
            assert detect_aws_environment() == AwsEnvironment.EC2

    @patch("datahub.ingestion.source.aws.aws_common.is_running_on_ec2")
    def test_detect_unknown_environment(self, mock_is_ec2: Mock) -> None:
        """Empty environment with the EC2 probe returning False gives UNKNOWN."""
        mock_is_ec2.return_value = False

        with patch.dict(os.environ, {}, clear=True):
            assert detect_aws_environment() == AwsEnvironment.UNKNOWN
|
|
|
|
|
|
class TestGetCurrentIdentity:
    """Checks the (role ARN, source) pair returned by get_current_identity.

    The detected environment is forced through a patched
    ``detect_aws_environment`` in every test.
    """

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("datahub.ingestion.source.aws.aws_common.get_lambda_role_arn")
    def test_get_lambda_identity(
        self, mock_lambda_role: Mock, mock_detect: Mock
    ) -> None:
        """LAMBDA: the ARN comes from the patched get_lambda_role_arn."""
        lambda_arn = "arn:aws:iam::123456789012:role/lambda-role"
        mock_detect.return_value = AwsEnvironment.LAMBDA
        mock_lambda_role.return_value = lambda_arn

        arn, principal = get_current_identity()

        assert arn == lambda_arn
        assert principal == AwsServicePrincipal.LAMBDA.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    def test_get_eks_identity(self, mock_detect: Mock) -> None:
        """EKS: the ARN matches the AWS_ROLE_ARN environment variable."""
        eks_arn = "arn:aws:iam::123456789012:role/eks-role"
        mock_detect.return_value = AwsEnvironment.EKS

        with patch.dict(os.environ, {"AWS_ROLE_ARN": eks_arn}):
            arn, principal = get_current_identity()

        assert arn == eks_arn
        assert principal == AwsServicePrincipal.EKS.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("boto3.client")
    def test_get_app_runner_identity(
        self, mock_boto_client: Mock, mock_detect: Mock
    ) -> None:
        """APP_RUNNER: the ARN matches the stubbed get_caller_identity reply."""
        caller_arn = "arn:aws:sts::123456789012:assumed-role/app-runner-role"
        mock_detect.return_value = AwsEnvironment.APP_RUNNER
        # boto3.client(...) hands back this stub in place of a real client.
        sts_stub = Mock(get_caller_identity=Mock(return_value={"Arn": caller_arn}))
        mock_boto_client.return_value = sts_stub

        arn, principal = get_current_identity()

        assert arn == caller_arn
        assert principal == AwsServicePrincipal.APP_RUNNER.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("requests.get")
    def test_get_ecs_identity(self, mock_get: Mock, mock_detect: Mock) -> None:
        """ECS: the ARN matches the TaskARN in the stubbed metadata JSON."""
        task_arn = "arn:aws:ecs:us-east-1:123456789012:task/task-id"
        mock_detect.return_value = AwsEnvironment.ECS
        mock_get.return_value = Mock(
            status_code=HTTPStatus.OK,
            json=Mock(return_value={"TaskARN": task_arn}),
        )

        metadata_env = {
            "ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4/metadata"
        }
        with patch.dict(os.environ, metadata_env):
            arn, principal = get_current_identity()

        assert arn == task_arn
        assert principal == AwsServicePrincipal.ECS.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    def test_get_unknown_identity(self, mock_detect: Mock) -> None:
        """UNKNOWN: both elements of the returned pair are None."""
        mock_detect.return_value = AwsEnvironment.UNKNOWN

        arn, principal = get_current_identity()

        assert arn is None
        assert principal is None
|
|
|
|
|
|
class TestAssumeRole:
    """Checks assume_role against a stubbed boto3 STS client."""

    @staticmethod
    def _stub_sts_client() -> Mock:
        """Return a Mock STS client whose assume_role yields fixed credentials."""
        sts_stub = Mock()
        sts_stub.assume_role.return_value = {
            "Credentials": {
                "AccessKeyId": "AKIA123456789",
                "SecretAccessKey": "secret123",
                "SessionToken": "token123",
                "Expiration": datetime.now(timezone.utc) + timedelta(hours=1),
            }
        }
        return sts_stub

    @patch("boto3.client")
    def test_assume_role_success(self, mock_boto_client: Mock) -> None:
        """The stubbed credentials are returned and the role/session are passed."""
        sts_stub = self._stub_sts_client()
        mock_boto_client.return_value = sts_stub
        target_arn = "arn:aws:iam::123456789012:role/test-role"

        creds = assume_role(AwsAssumeRoleConfig(RoleArn=target_arn), "us-east-1")

        assert creds["AccessKeyId"] == "AKIA123456789"
        assert creds["SecretAccessKey"] == "secret123"
        assert creds["SessionToken"] == "token123"

        sts_stub.assume_role.assert_called_once()
        passed_kwargs = sts_stub.assume_role.call_args.kwargs
        assert passed_kwargs["RoleArn"] == target_arn
        assert passed_kwargs["RoleSessionName"] == "DatahubIngestionSource"

    @patch("boto3.client")
    def test_assume_role_with_external_id(self, mock_boto_client: Mock) -> None:
        """A configured ExternalId is forwarded to the STS assume_role call."""
        sts_stub = self._stub_sts_client()
        mock_boto_client.return_value = sts_stub
        role_with_external_id = AwsAssumeRoleConfig(
            RoleArn="arn:aws:iam::123456789012:role/test-role",
            ExternalId="external-123",
        )

        assume_role(role_with_external_id, "us-east-1")

        passed_kwargs = sts_stub.assume_role.call_args.kwargs
        assert passed_kwargs["ExternalId"] == "external-123"

    @patch("boto3.client")
    def test_assume_role_with_existing_credentials(
        self, mock_boto_client: Mock
    ) -> None:
        """Supplied credentials are used to construct the STS client."""
        mock_boto_client.return_value = self._stub_sts_client()
        role_cfg = AwsAssumeRoleConfig(
            RoleArn="arn:aws:iam::123456789012:role/test-role"
        )
        prior_credentials = {
            "AccessKeyId": "EXISTING123",
            "SecretAccessKey": "existingsecret",
            "SessionToken": "existingtoken",
        }

        assume_role(role_cfg, "us-east-1", prior_credentials)

        # Each credential entry must reach boto3.client under its keyword name.
        mock_boto_client.assert_called_once_with(
            "sts",
            region_name="us-east-1",
            aws_access_key_id="EXISTING123",
            aws_secret_access_key="existingsecret",
            aws_session_token="existingtoken",
        )
|
|
|
|
|
|
class TestAwsSourceConfig:
    """Covers defaults, custom patterns and connection fields of AwsSourceConfig."""

    def test_init_with_defaults(self) -> None:
        """No arguments: connection fields are None, patterns allow everything."""
        source_cfg = AwsSourceConfig()

        # Connection-level fields
        assert source_cfg.aws_access_key_id is None
        assert source_cfg.aws_region is None

        # Pattern fields
        for pattern in (source_cfg.database_pattern, source_cfg.table_pattern):
            assert pattern == AllowDenyPattern.allow_all()

    def test_init_with_custom_patterns(self) -> None:
        """Explicit region and patterns are stored as given."""
        databases = AllowDenyPattern(allow=["test_*"], deny=["temp_*"])
        tables = AllowDenyPattern(allow=["prod_*"])

        source_cfg = AwsSourceConfig(
            aws_region="us-east-1",
            database_pattern=databases,
            table_pattern=tables,
        )

        assert source_cfg.aws_region == "us-east-1"
        assert source_cfg.database_pattern == databases
        assert source_cfg.table_pattern == tables

    def test_inheritance_from_aws_connection_config(self) -> None:
        """The config exposes the client/session getters and credential fields."""
        source_cfg = AwsSourceConfig(
            aws_access_key_id="AKIA123456789", aws_secret_access_key="secret123"
        )

        # Getter methods expected to be available on the config object
        for method_name in ("get_session", "get_s3_client", "get_glue_client"):
            assert hasattr(source_cfg, method_name)

        # Credential fields round-trip unchanged
        assert source_cfg.aws_access_key_id == "AKIA123456789"
        assert source_cfg.aws_secret_access_key == "secret123"
|