diff --git a/metadata-ingestion/tests/unit/test_aws_common.py b/metadata-ingestion/tests/unit/test_aws_common.py index 9291fb9113..dd1f06cf9b 100644 --- a/metadata-ingestion/tests/unit/test_aws_common.py +++ b/metadata-ingestion/tests/unit/test_aws_common.py @@ -17,6 +17,14 @@ from datahub.ingestion.source.aws.aws_common import ( ) +@pytest.fixture +def mock_disable_ec2_metadata(): + """Disable EC2 metadata detection""" + with patch("requests.put") as mock_put: + mock_put.return_value.status_code = 404 + yield mock_put + + @pytest.fixture def mock_aws_config(): return AwsConnectionConfig( @@ -27,17 +35,19 @@ def mock_aws_config(): class TestAwsCommon: - def test_environment_detection_no_environment(self): + def test_environment_detection_no_environment(self, mock_disable_ec2_metadata): """Test environment detection when no AWS environment is present""" with patch.dict(os.environ, {}, clear=True): assert detect_aws_environment() == AwsEnvironment.UNKNOWN - def test_environment_detection_lambda(self): + def test_environment_detection_lambda(self, mock_disable_ec2_metadata): """Test Lambda environment detection""" with patch.dict(os.environ, {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}): assert detect_aws_environment() == AwsEnvironment.LAMBDA - def test_environment_detection_lambda_cloudformation(self): + def test_environment_detection_lambda_cloudformation( + self, mock_disable_ec2_metadata + ): """Test CloudFormation Lambda environment detection""" with patch.dict( os.environ, @@ -48,7 +58,7 @@ class TestAwsCommon: ): assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION - def test_environment_detection_eks(self): + def test_environment_detection_eks(self, mock_disable_ec2_metadata): """Test EKS environment detection""" with patch.dict( os.environ, @@ -59,19 +69,19 @@ class TestAwsCommon: ): assert detect_aws_environment() == AwsEnvironment.EKS - def test_environment_detection_app_runner(self): + def test_environment_detection_app_runner(self, mock_disable_ec2_metadata): """Test App Runner environment detection""" with patch.dict(os.environ, {"AWS_APP_RUNNER_SERVICE_ID": "service-id"}): assert detect_aws_environment() == AwsEnvironment.APP_RUNNER - def test_environment_detection_ecs(self): + def test_environment_detection_ecs(self, mock_disable_ec2_metadata): """Test ECS environment detection""" with patch.dict( os.environ, {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"} ): assert detect_aws_environment() == AwsEnvironment.ECS - def test_environment_detection_beanstalk(self): + def test_environment_detection_beanstalk(self, mock_disable_ec2_metadata): """Test Elastic Beanstalk environment detection""" with patch.dict(os.environ, {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"}): assert detect_aws_environment() == AwsEnvironment.BEANSTALK @@ -103,6 +113,7 @@ class TestAwsCommon: @patch("requests.put") def test_is_running_on_ec2(self, mock_put, mock_get): """Test EC2 instance detection with IMDSv2""" + # Explicitly mock EC2 metadata responses mock_put.return_value.status_code = 200 mock_put.return_value.text = "token123" mock_get.return_value.status_code = 200 @@ -322,7 +333,9 @@ class TestAwsCommon: ), ], ) - def test_environment_detection_parametrized(self, env_vars, expected_environment): + def test_environment_detection_parametrized( + self, mock_disable_ec2_metadata, env_vars, expected_environment + ): """Parametrized test for environment detection with different configurations""" with patch.dict(os.environ, env_vars, clear=True): assert detect_aws_environment() == expected_environment