fix(ingest/sagemaker): ensure consistent STS token usage with refresh mechanism (#11170)

Co-authored-by: Aseem Bansal <asmbansal2@gmail.com>
This commit is contained in:
sagar-salvi-apptware 2024-08-22 15:42:13 +05:30 committed by GitHub
parent dc30c0a0b7
commit 50ed448861
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 13 deletions

View File

@ -1,3 +1,4 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import boto3
@ -73,6 +74,8 @@ class AwsConnectionConfig(ConfigModel):
- dbt source
"""
_credentials_expiration: Optional[datetime] = None
aws_access_key_id: Optional[str] = Field(
default=None,
description=f"AWS access key ID. {AUTODETECT_CREDENTIALS_DOC_LINK}",
@ -115,6 +118,11 @@ class AwsConnectionConfig(ConfigModel):
description="Advanced AWS configuration options. These are passed directly to [botocore.config.Config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html).",
)
def allowed_cred_refresh(self) -> bool:
    """Report whether credential refresh applies to this configuration.

    Returns:
        True when ``_normalized_aws_roles()`` yields a non-empty (truthy)
        result, i.e. at least one role is configured; False otherwise.
    """
    # The truthiness of the normalized role list is the whole decision,
    # so convert it directly instead of branching to literal True/False.
    return bool(self._normalized_aws_roles())
def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
if not self.aws_role:
return []
@ -153,11 +161,14 @@ class AwsConnectionConfig(ConfigModel):
}
for role in self._normalized_aws_roles():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)
if self._should_refresh_credentials():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)
if isinstance(credentials["Expiration"], datetime):
self._credentials_expiration = credentials["Expiration"]
session = Session(
aws_access_key_id=credentials["AccessKeyId"],
@ -168,6 +179,12 @@ class AwsConnectionConfig(ConfigModel):
return session
def _should_refresh_credentials(self) -> bool:
    """Decide whether the stored credentials need to be (re)obtained.

    Returns:
        True when no expiration has been recorded yet, or when the recorded
        expiration is less than five minutes away (including already past);
        False otherwise.
    """
    expires_at = self._credentials_expiration
    if expires_at is None:
        # Nothing recorded yet, so credentials must be fetched.
        return True

    refresh_window = timedelta(minutes=5)
    # NOTE(review): the subtraction uses an aware UTC "now", so it assumes
    # the stored expiration is timezone-aware as well — confirm at the setter.
    time_left = expires_at - datetime.now(timezone.utc)
    return time_left < refresh_window
def get_credentials(self) -> Dict[str, Optional[str]]:
credentials = self.get_session().get_credentials()
if credentials is not None:

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
@ -33,6 +33,9 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
@ -56,6 +59,7 @@ class SagemakerSource(StatefulIngestionSourceBase):
self.report = SagemakerSourceReport()
self.sagemaker_client = config.sagemaker_client
self.env = config.env
self.client_factory = ClientFactory(config)
@classmethod
def create(cls, config_dict, ctx):
@ -92,7 +96,7 @@ class SagemakerSource(StatefulIngestionSourceBase):
# extract jobs if specified
if self.source_config.extract_jobs is not False:
job_processor = JobProcessor(
sagemaker_client=self.sagemaker_client,
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
job_type_filter=self.source_config.extract_jobs,
@ -118,3 +122,15 @@ class SagemakerSource(StatefulIngestionSourceBase):
def get_report(self):
return self.report
class ClientFactory:
    """Supplies a SageMaker client on demand for a given source config.

    A client is read from the config once at construction and kept. Each
    call to ``get_client`` then either returns that stored client or reads
    ``config.sagemaker_client`` again, depending on what
    ``config.allowed_cred_refresh()`` reports.
    """

    def __init__(self, config: SagemakerSourceConfig):
        self.config = config
        # Snapshot taken at construction; served when refresh is not allowed.
        self._cached_client = self.config.sagemaker_client

    def get_client(self) -> "SageMakerClient":
        """Return the client to use for the next SageMaker call.

        Returns:
            The value of ``config.sagemaker_client`` read at call time when
            the config allows credential refresh; otherwise the client
            captured at construction.
        """
        refresh_allowed = self.config.allowed_cred_refresh()
        if not refresh_allowed:
            return self._cached_client
        # Re-read from the config on every call; presumably this yields a
        # client backed by refreshed credentials — verify against the config.
        return self.config.sagemaker_client

View File

@ -4,6 +4,7 @@ from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
Iterable,
@ -147,7 +148,7 @@ class JobProcessor:
"""
# boto3 SageMaker client
sagemaker_client: "SageMakerClient"
sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport
# config filter for specific job types to ingest (see metadata-ingestion README)
@ -170,8 +171,7 @@ class JobProcessor:
def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
jobs = []
paginator = self.sagemaker_client.get_paginator(job_spec.list_command)
paginator = self.sagemaker_client().get_paginator(job_spec.list_command)
for page in paginator.paginate():
page_jobs: List[Any] = page[job_spec.list_key]
@ -269,7 +269,7 @@ class JobProcessor:
describe_command = job_type_to_info[job_type].describe_command
describe_name_key = job_type_to_info[job_type].describe_name_key
return getattr(self.sagemaker_client, describe_command)(
return getattr(self.sagemaker_client(), describe_command)(
**{describe_name_key: job_name}
)

View File

@ -1,3 +1,5 @@
from unittest.mock import patch
from botocore.stub import Stubber
from freezegun import freeze_time
@ -220,8 +222,17 @@ def test_sagemaker_ingest(tmp_path, pytestconfig):
{"ModelName": "the-second-model"},
)
mce_objects = [wu.metadata for wu in sagemaker_source_instance.get_workunits()]
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)
# Patch the client factory's get_client method to return the stubbed client for jobs
with patch.object(
sagemaker_source_instance.client_factory,
"get_client",
return_value=sagemaker_source_instance.sagemaker_client,
):
# Run the test and generate the MCEs
mce_objects = [
wu.metadata for wu in sagemaker_source_instance.get_workunits()
]
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)
# Verify the output.
test_resources_dir = pytestconfig.rootpath / "tests/unit/sagemaker"