Tamas Nemeth 6787d83fa6
fix(ingest/unity-catalog): Speed up tag extraction by caching platform resources (#14376)
This PR reworks the external entity API for platform resources.
2025-08-11 22:51:25 +02:00

471 lines
17 KiB
Python

import os
from datetime import datetime, timedelta, timezone
from http import HTTPStatus
from unittest.mock import Mock, patch
import requests
from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.source.aws.aws_common import (
AwsAssumeRoleConfig,
AwsEnvironment,
AwsServicePrincipal,
AwsSourceConfig,
assume_role,
detect_aws_environment,
get_current_identity,
get_instance_metadata_token,
is_running_on_ec2,
)
class TestAwsEnvironment:
    """Checks the string values carried by the AwsEnvironment enum."""

    def test_enum_values(self) -> None:
        """Each AwsEnvironment member must expose its expected string value."""
        # (member, expected value) pairs, checked in declaration order.
        expectations = [
            (AwsEnvironment.EC2, "EC2"),
            (AwsEnvironment.ECS, "ECS"),
            (AwsEnvironment.EKS, "EKS"),
            (AwsEnvironment.LAMBDA, "LAMBDA"),
            (AwsEnvironment.APP_RUNNER, "APP_RUNNER"),
            (AwsEnvironment.BEANSTALK, "ELASTIC_BEANSTALK"),
            (AwsEnvironment.CLOUD_FORMATION, "CLOUD_FORMATION"),
            (AwsEnvironment.UNKNOWN, "UNKNOWN"),
        ]
        for member, expected in expectations:
            assert member.value == expected
class TestAwsServicePrincipal:
    """Checks the service-principal strings carried by AwsServicePrincipal."""

    def test_enum_values(self) -> None:
        """Each AwsServicePrincipal member must expose its expected string value."""
        # (member, expected principal) pairs, checked in order.
        expectations = [
            (AwsServicePrincipal.LAMBDA, "lambda.amazonaws.com"),
            (AwsServicePrincipal.EKS, "eks.amazonaws.com"),
            (AwsServicePrincipal.APP_RUNNER, "apprunner.amazonaws.com"),
            (AwsServicePrincipal.ECS, "ecs.amazonaws.com"),
            (AwsServicePrincipal.ELASTIC_BEANSTALK, "elasticbeanstalk.amazonaws.com"),
            (AwsServicePrincipal.EC2, "ec2.amazonaws.com"),
        ]
        for member, expected in expectations:
            assert member.value == expected
class TestAwsAssumeRoleConfig:
    """Covers construction and dict export of AwsAssumeRoleConfig."""

    def test_init_required_fields(self) -> None:
        """Supplying only RoleArn leaves ExternalId as None."""
        role_arn = "arn:aws:iam::123456789012:role/TestRole"

        cfg = AwsAssumeRoleConfig(RoleArn=role_arn)

        assert cfg.RoleArn == role_arn
        assert cfg.ExternalId is None

    def test_init_all_fields(self) -> None:
        """Both RoleArn and ExternalId are stored as given."""
        role_arn = "arn:aws:iam::123456789012:role/TestRole"
        external_id = "external-id-123"

        cfg = AwsAssumeRoleConfig(RoleArn=role_arn, ExternalId=external_id)

        assert cfg.RoleArn == role_arn
        assert cfg.ExternalId == external_id

    def test_dict_method(self) -> None:
        """dict() exposes RoleArn and ExternalId under their field names."""
        role_arn = "arn:aws:iam::123456789012:role/TestRole"
        external_id = "external-id-123"

        exported = AwsAssumeRoleConfig(
            RoleArn=role_arn,
            ExternalId=external_id,
        ).dict()

        assert exported["RoleArn"] == role_arn
        assert exported["ExternalId"] == external_id
class TestMetadataFunctions:
    """Exercises the EC2 instance-metadata helpers with HTTP calls mocked out."""

    @patch("requests.put")
    def test_get_instance_metadata_token_success(self, mock_put: Mock) -> None:
        """An OK response yields the response text as the token."""
        mock_put.return_value = Mock(status_code=HTTPStatus.OK, text="test-token-123")

        assert get_instance_metadata_token() == "test-token-123"
        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )

    @patch("requests.put")
    def test_get_instance_metadata_token_failure(self, mock_put: Mock) -> None:
        """A RequestException raised by the PUT results in None."""
        mock_put.side_effect = requests.exceptions.RequestException("Network error")

        assert get_instance_metadata_token() is None

    @patch("requests.put")
    def test_get_instance_metadata_token_bad_status(self, mock_put: Mock) -> None:
        """A NOT_FOUND response results in None."""
        mock_put.return_value = Mock(status_code=HTTPStatus.NOT_FOUND)

        assert get_instance_metadata_token() is None

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    @patch("requests.get")
    def test_is_running_on_ec2_success(self, mock_get: Mock, mock_token: Mock) -> None:
        """With a token and an OK metadata response, EC2 is detected."""
        mock_token.return_value = "test-token"
        mock_get.return_value = Mock(status_code=HTTPStatus.OK)

        assert is_running_on_ec2() is True
        mock_get.assert_called_once_with(
            "http://169.254.169.254/latest/meta-data/instance-id",
            headers={"X-aws-ec2-metadata-token": "test-token"},
            timeout=1,
        )

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    def test_is_running_on_ec2_no_token(self, mock_token: Mock) -> None:
        """Without a metadata token, EC2 is not detected."""
        mock_token.return_value = None

        assert is_running_on_ec2() is False

    @patch("datahub.ingestion.source.aws.aws_common.get_instance_metadata_token")
    @patch("requests.get")
    def test_is_running_on_ec2_request_failure(
        self, mock_get: Mock, mock_token: Mock
    ) -> None:
        """A RequestException raised by the GET means EC2 is not detected."""
        mock_token.return_value = "test-token"
        mock_get.side_effect = requests.exceptions.RequestException("Network error")

        assert is_running_on_ec2() is False
class TestDetectAwsEnvironment:
    """Drives detect_aws_environment through controlled os.environ contents."""

    def test_detect_lambda_environment(self) -> None:
        """AWS_LAMBDA_FUNCTION_NAME alone is reported as LAMBDA."""
        env = {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.LAMBDA

    def test_detect_cloud_formation_environment(self) -> None:
        """A Lambda function name plus a CloudFormation execution env is CLOUD_FORMATION."""
        env = {
            "AWS_LAMBDA_FUNCTION_NAME": "test-function",
            "AWS_EXECUTION_ENV": "CloudFormation-custom-resource",
        }
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION

    def test_detect_eks_environment(self) -> None:
        """A web-identity token file plus a role ARN (IRSA) is reported as EKS."""
        env = {
            "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/eks.amazonaws.com/serviceaccount/token",
            "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/eks-role",
        }
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.EKS

    def test_detect_app_runner_environment(self) -> None:
        """AWS_APP_RUNNER_SERVICE_ID is reported as APP_RUNNER."""
        env = {"AWS_APP_RUNNER_SERVICE_ID": "service-123"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.APP_RUNNER

    def test_detect_ecs_environment_v4(self) -> None:
        """The v4 container metadata URI variable is reported as ECS."""
        env = {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4/metadata"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.ECS

    def test_detect_ecs_environment_v3(self) -> None:
        """The v3 container metadata URI variable is reported as ECS."""
        env = {"ECS_CONTAINER_METADATA_URI": "http://169.254.170.2/v3/metadata"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.ECS

    def test_detect_beanstalk_environment(self) -> None:
        """ELASTIC_BEANSTALK_ENVIRONMENT_NAME is reported as BEANSTALK."""
        env = {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "test-env"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.BEANSTALK

    @patch("datahub.ingestion.source.aws.aws_common.is_running_on_ec2")
    def test_detect_ec2_environment(self, mock_is_ec2: Mock) -> None:
        """With an empty environment and a positive EC2 probe, the result is EC2."""
        mock_is_ec2.return_value = True
        # Empty the environment so only the (mocked) EC2 probe can match.
        with patch.dict(os.environ, {}, clear=True):
            assert detect_aws_environment() == AwsEnvironment.EC2

    @patch("datahub.ingestion.source.aws.aws_common.is_running_on_ec2")
    def test_detect_unknown_environment(self, mock_is_ec2: Mock) -> None:
        """With an empty environment and a negative EC2 probe, the result is UNKNOWN."""
        mock_is_ec2.return_value = False
        with patch.dict(os.environ, {}, clear=True):
            assert detect_aws_environment() == AwsEnvironment.UNKNOWN
class TestGetCurrentIdentity:
    """Checks the (role ARN, source) pair from get_current_identity per environment."""

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("datahub.ingestion.source.aws.aws_common.get_lambda_role_arn")
    def test_get_lambda_identity(
        self, mock_lambda_role: Mock, mock_detect: Mock
    ) -> None:
        """In LAMBDA, the ARN comes from get_lambda_role_arn."""
        expected_arn = "arn:aws:iam::123456789012:role/lambda-role"
        mock_detect.return_value = AwsEnvironment.LAMBDA
        mock_lambda_role.return_value = expected_arn

        arn, principal = get_current_identity()

        assert arn == expected_arn
        assert principal == AwsServicePrincipal.LAMBDA.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    def test_get_eks_identity(self, mock_detect: Mock) -> None:
        """In EKS, the ARN is the value of the AWS_ROLE_ARN variable."""
        expected_arn = "arn:aws:iam::123456789012:role/eks-role"
        mock_detect.return_value = AwsEnvironment.EKS

        with patch.dict(os.environ, {"AWS_ROLE_ARN": expected_arn}):
            arn, principal = get_current_identity()

        assert arn == expected_arn
        assert principal == AwsServicePrincipal.EKS.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("boto3.client")
    def test_get_app_runner_identity(
        self, mock_boto_client: Mock, mock_detect: Mock
    ) -> None:
        """In APP_RUNNER, the ARN is the 'Arn' from the caller identity."""
        expected_arn = "arn:aws:sts::123456789012:assumed-role/app-runner-role"
        mock_detect.return_value = AwsEnvironment.APP_RUNNER
        # boto3.client(...) hands back this stand-in client.
        sts_stub = Mock()
        sts_stub.get_caller_identity.return_value = {"Arn": expected_arn}
        mock_boto_client.return_value = sts_stub

        arn, principal = get_current_identity()

        assert arn == expected_arn
        assert principal == AwsServicePrincipal.APP_RUNNER.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    @patch("requests.get")
    def test_get_ecs_identity(self, mock_get: Mock, mock_detect: Mock) -> None:
        """In ECS, the ARN is the 'TaskARN' from the metadata response JSON."""
        expected_arn = "arn:aws:ecs:us-east-1:123456789012:task/task-id"
        mock_detect.return_value = AwsEnvironment.ECS
        metadata_response = Mock(status_code=HTTPStatus.OK)
        metadata_response.json.return_value = {"TaskARN": expected_arn}
        mock_get.return_value = metadata_response

        ecs_env = {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4/metadata"}
        with patch.dict(os.environ, ecs_env):
            arn, principal = get_current_identity()

        assert arn == expected_arn
        assert principal == AwsServicePrincipal.ECS.value

    @patch("datahub.ingestion.source.aws.aws_common.detect_aws_environment")
    def test_get_unknown_identity(self, mock_detect: Mock) -> None:
        """In UNKNOWN, both elements of the returned pair are None."""
        mock_detect.return_value = AwsEnvironment.UNKNOWN

        arn, principal = get_current_identity()

        assert arn is None
        assert principal is None
class TestAssumeRole:
    """Exercises assume_role against a mocked boto3 STS client."""

    @staticmethod
    def _make_sts_stub() -> Mock:
        """Build a stand-in STS client whose assume_role returns fixed credentials."""
        sts_stub = Mock()
        sts_stub.assume_role.return_value = {
            "Credentials": {
                "AccessKeyId": "AKIA123456789",
                "SecretAccessKey": "secret123",
                "SessionToken": "token123",
                "Expiration": datetime.now(timezone.utc) + timedelta(hours=1),
            }
        }
        return sts_stub

    @patch("boto3.client")
    def test_assume_role_success(self, mock_boto_client: Mock) -> None:
        """The returned credentials and the assume_role keyword arguments are as expected."""
        sts_stub = self._make_sts_stub()
        mock_boto_client.return_value = sts_stub
        role_arn = "arn:aws:iam::123456789012:role/test-role"

        credentials = assume_role(AwsAssumeRoleConfig(RoleArn=role_arn), "us-east-1")

        assert credentials["AccessKeyId"] == "AKIA123456789"
        assert credentials["SecretAccessKey"] == "secret123"
        assert credentials["SessionToken"] == "token123"
        sts_stub.assume_role.assert_called_once()
        passed_kwargs = sts_stub.assume_role.call_args.kwargs
        assert passed_kwargs["RoleArn"] == role_arn
        assert passed_kwargs["RoleSessionName"] == "DatahubIngestionSource"

    @patch("boto3.client")
    def test_assume_role_with_external_id(self, mock_boto_client: Mock) -> None:
        """A configured ExternalId is forwarded to the STS assume_role call."""
        sts_stub = self._make_sts_stub()
        mock_boto_client.return_value = sts_stub
        cfg = AwsAssumeRoleConfig(
            RoleArn="arn:aws:iam::123456789012:role/test-role",
            ExternalId="external-123",
        )

        assume_role(cfg, "us-east-1")

        passed_kwargs = sts_stub.assume_role.call_args.kwargs
        assert passed_kwargs["ExternalId"] == "external-123"

    @patch("boto3.client")
    def test_assume_role_with_existing_credentials(
        self, mock_boto_client: Mock
    ) -> None:
        """Supplied credentials are used to construct the STS client."""
        mock_boto_client.return_value = self._make_sts_stub()
        cfg = AwsAssumeRoleConfig(RoleArn="arn:aws:iam::123456789012:role/test-role")
        prior_credentials = {
            "AccessKeyId": "EXISTING123",
            "SecretAccessKey": "existingsecret",
            "SessionToken": "existingtoken",
        }

        assume_role(cfg, "us-east-1", prior_credentials)

        mock_boto_client.assert_called_once_with(
            "sts",
            region_name="us-east-1",
            aws_access_key_id="EXISTING123",
            aws_secret_access_key="existingsecret",
            aws_session_token="existingtoken",
        )
class TestAwsSourceConfig:
    """Covers defaults, custom patterns and connection members of AwsSourceConfig."""

    def test_init_with_defaults(self) -> None:
        """A bare AwsSourceConfig has unset credentials/region and allow-all patterns."""
        cfg = AwsSourceConfig()

        # Connection-level fields
        assert cfg.aws_access_key_id is None
        assert cfg.aws_region is None

        # Pattern fields
        assert cfg.database_pattern == AllowDenyPattern.allow_all()
        assert cfg.table_pattern == AllowDenyPattern.allow_all()

    def test_init_with_custom_patterns(self) -> None:
        """Explicit region and allow/deny patterns are stored as given."""
        databases = AllowDenyPattern(allow=["test_*"], deny=["temp_*"])
        tables = AllowDenyPattern(allow=["prod_*"])

        cfg = AwsSourceConfig(
            aws_region="us-east-1",
            database_pattern=databases,
            table_pattern=tables,
        )

        assert cfg.aws_region == "us-east-1"
        assert cfg.database_pattern == databases
        assert cfg.table_pattern == tables

    def test_inheritance_from_aws_connection_config(self) -> None:
        """AwsSourceConfig exposes the connection helper methods and credential fields."""
        cfg = AwsSourceConfig(
            aws_access_key_id="AKIA123456789", aws_secret_access_key="secret123"
        )

        # Client/session factory methods must be present on the config
        for method_name in ("get_session", "get_s3_client", "get_glue_client"):
            assert hasattr(cfg, method_name)

        # Credential fields round-trip unchanged
        assert cfg.aws_access_key_id == "AKIA123456789"
        assert cfg.aws_secret_access_key == "secret123"