mirror of
https://github.com/datahub-project/datahub.git
synced 2025-07-04 15:50:14 +00:00
342 lines
12 KiB
Python
342 lines
12 KiB
Python
import json
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import boto3
|
|
import pytest
|
|
from moto import mock_iam, mock_lambda, mock_sts
|
|
|
|
from datahub.ingestion.source.aws.aws_common import (
|
|
AwsConnectionConfig,
|
|
AwsEnvironment,
|
|
detect_aws_environment,
|
|
get_current_identity,
|
|
get_instance_metadata_token,
|
|
get_instance_role_arn,
|
|
is_running_on_ec2,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_disable_ec2_metadata():
|
|
"""Disable EC2 metadata detection"""
|
|
with patch("requests.put") as mock_put:
|
|
mock_put.return_value.status_code = 404
|
|
yield mock_put
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_aws_config():
|
|
return AwsConnectionConfig(
|
|
aws_access_key_id="test-key",
|
|
aws_secret_access_key="test-secret",
|
|
aws_region="us-east-1",
|
|
)
|
|
|
|
|
|
class TestAwsCommon:
|
|
def test_environment_detection_no_environment(self, mock_disable_ec2_metadata):
|
|
"""Test environment detection when no AWS environment is present"""
|
|
with patch.dict(os.environ, {}, clear=True):
|
|
assert detect_aws_environment() == AwsEnvironment.UNKNOWN
|
|
|
|
def test_environment_detection_lambda(self, mock_disable_ec2_metadata):
|
|
"""Test Lambda environment detection"""
|
|
with patch.dict(os.environ, {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}):
|
|
assert detect_aws_environment() == AwsEnvironment.LAMBDA
|
|
|
|
def test_environment_detection_lambda_cloudformation(
|
|
self, mock_disable_ec2_metadata
|
|
):
|
|
"""Test CloudFormation Lambda environment detection"""
|
|
with patch.dict(
|
|
os.environ,
|
|
{
|
|
"AWS_LAMBDA_FUNCTION_NAME": "test-function",
|
|
"AWS_EXECUTION_ENV": "CloudFormation.xxx",
|
|
},
|
|
):
|
|
assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION
|
|
|
|
def test_environment_detection_eks(self, mock_disable_ec2_metadata):
|
|
"""Test EKS environment detection"""
|
|
with patch.dict(
|
|
os.environ,
|
|
{
|
|
"AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/token",
|
|
"AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/test-role",
|
|
},
|
|
):
|
|
assert detect_aws_environment() == AwsEnvironment.EKS
|
|
|
|
def test_environment_detection_app_runner(self, mock_disable_ec2_metadata):
|
|
"""Test App Runner environment detection"""
|
|
with patch.dict(os.environ, {"AWS_APP_RUNNER_SERVICE_ID": "service-id"}):
|
|
assert detect_aws_environment() == AwsEnvironment.APP_RUNNER
|
|
|
|
def test_environment_detection_ecs(self, mock_disable_ec2_metadata):
|
|
"""Test ECS environment detection"""
|
|
with patch.dict(
|
|
os.environ, {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"}
|
|
):
|
|
assert detect_aws_environment() == AwsEnvironment.ECS
|
|
|
|
def test_environment_detection_beanstalk(self, mock_disable_ec2_metadata):
|
|
"""Test Elastic Beanstalk environment detection"""
|
|
with patch.dict(os.environ, {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"}):
|
|
assert detect_aws_environment() == AwsEnvironment.BEANSTALK
|
|
|
|
@patch("requests.put")
|
|
def test_ec2_metadata_token(self, mock_put):
|
|
"""Test EC2 metadata token retrieval"""
|
|
mock_put.return_value.status_code = 200
|
|
mock_put.return_value.text = "token123"
|
|
|
|
token = get_instance_metadata_token()
|
|
assert token == "token123"
|
|
|
|
mock_put.assert_called_once_with(
|
|
"http://169.254.169.254/latest/api/token",
|
|
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
|
|
timeout=1,
|
|
)
|
|
|
|
@patch("requests.put")
|
|
def test_ec2_metadata_token_failure(self, mock_put):
|
|
"""Test EC2 metadata token failure case"""
|
|
mock_put.return_value.status_code = 404
|
|
|
|
token = get_instance_metadata_token()
|
|
assert token is None
|
|
|
|
@patch("requests.get")
|
|
@patch("requests.put")
|
|
def test_is_running_on_ec2(self, mock_put, mock_get):
|
|
"""Test EC2 instance detection with IMDSv2"""
|
|
# Explicitly mock EC2 metadata responses
|
|
mock_put.return_value.status_code = 200
|
|
mock_put.return_value.text = "token123"
|
|
mock_get.return_value.status_code = 200
|
|
|
|
assert is_running_on_ec2() is True
|
|
|
|
mock_put.assert_called_once_with(
|
|
"http://169.254.169.254/latest/api/token",
|
|
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
|
|
timeout=1,
|
|
)
|
|
mock_get.assert_called_once_with(
|
|
"http://169.254.169.254/latest/meta-data/instance-id",
|
|
headers={"X-aws-ec2-metadata-token": "token123"},
|
|
timeout=1,
|
|
)
|
|
|
|
@patch("requests.get")
|
|
@patch("requests.put")
|
|
def test_is_running_on_ec2_failure(self, mock_put, mock_get):
|
|
"""Test EC2 instance detection failure"""
|
|
mock_put.return_value.status_code = 404
|
|
assert is_running_on_ec2() is False
|
|
|
|
mock_put.return_value.status_code = 200
|
|
mock_put.return_value.text = "token123"
|
|
mock_get.return_value.status_code = 404
|
|
assert is_running_on_ec2() is False
|
|
|
|
@mock_sts
|
|
@mock_lambda
|
|
@mock_iam
|
|
def test_get_current_identity_lambda(self):
|
|
"""Test getting identity in Lambda environment"""
|
|
with patch.dict(
|
|
os.environ,
|
|
{
|
|
"AWS_LAMBDA_FUNCTION_NAME": "test-function",
|
|
"AWS_DEFAULT_REGION": "us-east-1",
|
|
},
|
|
):
|
|
# Create IAM role first with proper trust policy
|
|
iam_client = boto3.client("iam", region_name="us-east-1")
|
|
trust_policy = {
|
|
"Version": "2012-10-17",
|
|
"Statement": [
|
|
{
|
|
"Effect": "Allow",
|
|
"Principal": {"Service": "lambda.amazonaws.com"},
|
|
"Action": "sts:AssumeRole",
|
|
}
|
|
],
|
|
}
|
|
iam_client.create_role(
|
|
RoleName="test-role", AssumeRolePolicyDocument=json.dumps(trust_policy)
|
|
)
|
|
|
|
lambda_client = boto3.client("lambda", region_name="us-east-1")
|
|
lambda_client.create_function(
|
|
FunctionName="test-function",
|
|
Runtime="python3.8",
|
|
Role="arn:aws:iam::123456789012:role/test-role",
|
|
Handler="index.handler",
|
|
Code={"ZipFile": b"def handler(event, context): pass"},
|
|
)
|
|
|
|
role_arn, source = get_current_identity()
|
|
assert source == "lambda.amazonaws.com"
|
|
assert role_arn == "arn:aws:iam::123456789012:role/test-role"
|
|
|
|
@patch("requests.get")
|
|
@patch("requests.put")
|
|
@mock_sts
|
|
def test_get_instance_role_arn_success(self, mock_put, mock_get):
|
|
"""Test getting EC2 instance role ARN"""
|
|
mock_put.return_value.status_code = 200
|
|
mock_put.return_value.text = "token123"
|
|
mock_get.return_value.status_code = 200
|
|
mock_get.return_value.text = "test-role"
|
|
|
|
with patch("boto3.client") as mock_boto:
|
|
mock_sts = MagicMock()
|
|
mock_sts.get_caller_identity.return_value = {
|
|
"Arn": "arn:aws:sts::123456789012:assumed-role/test-role/instance"
|
|
}
|
|
mock_boto.return_value = mock_sts
|
|
|
|
role_arn = get_instance_role_arn()
|
|
assert (
|
|
role_arn == "arn:aws:sts::123456789012:assumed-role/test-role/instance"
|
|
)
|
|
|
|
@mock_sts
|
|
def test_aws_connection_config_basic(self, mock_aws_config):
|
|
"""Test basic AWS connection configuration"""
|
|
session = mock_aws_config.get_session()
|
|
creds = session.get_credentials()
|
|
assert creds.access_key == "test-key"
|
|
assert creds.secret_key == "test-secret"
|
|
|
|
@mock_sts
|
|
def test_aws_connection_config_with_session_token(self):
|
|
"""Test AWS connection with session token"""
|
|
config = AwsConnectionConfig(
|
|
aws_access_key_id="test-key",
|
|
aws_secret_access_key="test-secret",
|
|
aws_session_token="test-token",
|
|
aws_region="us-east-1",
|
|
)
|
|
|
|
session = config.get_session()
|
|
creds = session.get_credentials()
|
|
assert creds.token == "test-token"
|
|
|
|
@mock_sts
|
|
def test_aws_connection_config_role_assumption(self):
|
|
"""Test AWS connection with role assumption"""
|
|
config = AwsConnectionConfig(
|
|
aws_access_key_id="test-key",
|
|
aws_secret_access_key="test-secret",
|
|
aws_region="us-east-1",
|
|
aws_role="arn:aws:iam::123456789012:role/test-role",
|
|
)
|
|
|
|
with patch(
|
|
"datahub.ingestion.source.aws.aws_common.get_current_identity"
|
|
) as mock_identity:
|
|
mock_identity.return_value = (None, None)
|
|
session = config.get_session()
|
|
creds = session.get_credentials()
|
|
assert creds is not None
|
|
|
|
@mock_sts
|
|
def test_aws_connection_config_skip_role_assumption(self):
|
|
"""Test AWS connection skipping role assumption when already in role"""
|
|
config = AwsConnectionConfig(
|
|
aws_region="us-east-1",
|
|
aws_role="arn:aws:iam::123456789012:role/current-role",
|
|
)
|
|
|
|
with patch(
|
|
"datahub.ingestion.source.aws.aws_common.get_current_identity"
|
|
) as mock_identity:
|
|
mock_identity.return_value = (
|
|
"arn:aws:iam::123456789012:role/current-role",
|
|
"ec2.amazonaws.com",
|
|
)
|
|
session = config.get_session()
|
|
assert session is not None
|
|
|
|
@mock_sts
|
|
def test_aws_connection_config_multiple_roles(self):
|
|
"""Test AWS connection with multiple role assumption"""
|
|
config = AwsConnectionConfig(
|
|
aws_access_key_id="test-key",
|
|
aws_secret_access_key="test-secret",
|
|
aws_region="us-east-1",
|
|
aws_role=[
|
|
"arn:aws:iam::123456789012:role/role1",
|
|
"arn:aws:iam::123456789012:role/role2",
|
|
],
|
|
)
|
|
|
|
with patch(
|
|
"datahub.ingestion.source.aws.aws_common.get_current_identity"
|
|
) as mock_identity:
|
|
mock_identity.return_value = (None, None)
|
|
session = config.get_session()
|
|
assert session is not None
|
|
|
|
def test_aws_connection_config_validation_error(self):
|
|
"""Test AWS connection validation"""
|
|
with patch.dict(
|
|
"os.environ",
|
|
{
|
|
"AWS_ACCESS_KEY_ID": "test-key",
|
|
# Deliberately missing AWS_SECRET_ACCESS_KEY
|
|
"AWS_DEFAULT_REGION": "us-east-1",
|
|
},
|
|
clear=True,
|
|
):
|
|
config = AwsConnectionConfig() # Let it pick up from environment
|
|
session = config.get_session()
|
|
with pytest.raises(
|
|
Exception,
|
|
match="Partial credentials found in env, missing: AWS_SECRET_ACCESS_KEY",
|
|
):
|
|
session.get_credentials()
|
|
|
|
@pytest.mark.parametrize(
|
|
"env_vars,expected_environment",
|
|
[
|
|
({}, AwsEnvironment.UNKNOWN),
|
|
({"AWS_LAMBDA_FUNCTION_NAME": "test"}, AwsEnvironment.LAMBDA),
|
|
(
|
|
{
|
|
"AWS_LAMBDA_FUNCTION_NAME": "test",
|
|
"AWS_EXECUTION_ENV": "CloudFormation",
|
|
},
|
|
AwsEnvironment.CLOUD_FORMATION,
|
|
),
|
|
(
|
|
{
|
|
"AWS_WEB_IDENTITY_TOKEN_FILE": "/token",
|
|
"AWS_ROLE_ARN": "arn:aws:iam::123:role/test",
|
|
},
|
|
AwsEnvironment.EKS,
|
|
),
|
|
({"AWS_APP_RUNNER_SERVICE_ID": "service-123"}, AwsEnvironment.APP_RUNNER),
|
|
(
|
|
{"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2"},
|
|
AwsEnvironment.ECS,
|
|
),
|
|
(
|
|
{"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"},
|
|
AwsEnvironment.BEANSTALK,
|
|
),
|
|
],
|
|
)
|
|
def test_environment_detection_parametrized(
|
|
self, mock_disable_ec2_metadata, env_vars, expected_environment
|
|
):
|
|
"""Parametrized test for environment detection with different configurations"""
|
|
with patch.dict(os.environ, env_vars, clear=True):
|
|
assert detect_aws_environment() == expected_environment
|