mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-28 02:17:53 +00:00
fix(ingest/sagemaker): ensure consistent STS token usage with refresh mechanism (#11170)
Co-authored-by: Aseem Bansal <asmbansal2@gmail.com>
This commit is contained in:
parent
dc30c0a0b7
commit
50ed448861
@ -1,3 +1,4 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import boto3
|
||||
@ -73,6 +74,8 @@ class AwsConnectionConfig(ConfigModel):
|
||||
- dbt source
|
||||
"""
|
||||
|
||||
_credentials_expiration: Optional[datetime] = None
|
||||
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description=f"AWS access key ID. {AUTODETECT_CREDENTIALS_DOC_LINK}",
|
||||
@ -115,6 +118,11 @@ class AwsConnectionConfig(ConfigModel):
|
||||
description="Advanced AWS configuration options. These are passed directly to [botocore.config.Config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html).",
|
||||
)
|
||||
|
||||
def allowed_cred_refresh(self) -> bool:
    """Tell whether credential refresh is allowed for this configuration.

    Returns:
        True when ``_normalized_aws_roles()`` yields a non-empty (truthy)
        result, i.e. at least one role is configured; False otherwise.
    """
    # Refresh only applies when there is a role to assume.
    return bool(self._normalized_aws_roles())
|
||||
|
||||
def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
|
||||
if not self.aws_role:
|
||||
return []
|
||||
@ -153,11 +161,14 @@ class AwsConnectionConfig(ConfigModel):
|
||||
}
|
||||
|
||||
for role in self._normalized_aws_roles():
|
||||
credentials = assume_role(
|
||||
role,
|
||||
self.aws_region,
|
||||
credentials=credentials,
|
||||
)
|
||||
if self._should_refresh_credentials():
|
||||
credentials = assume_role(
|
||||
role,
|
||||
self.aws_region,
|
||||
credentials=credentials,
|
||||
)
|
||||
if isinstance(credentials["Expiration"], datetime):
|
||||
self._credentials_expiration = credentials["Expiration"]
|
||||
|
||||
session = Session(
|
||||
aws_access_key_id=credentials["AccessKeyId"],
|
||||
@ -168,6 +179,12 @@ class AwsConnectionConfig(ConfigModel):
|
||||
|
||||
return session
|
||||
|
||||
def _should_refresh_credentials(self) -> bool:
|
||||
if self._credentials_expiration is None:
|
||||
return True
|
||||
remaining_time = self._credentials_expiration - datetime.now(timezone.utc)
|
||||
return remaining_time < timedelta(minutes=5)
|
||||
|
||||
def get_credentials(self) -> Dict[str, Optional[str]]:
|
||||
credentials = self.get_session().get_credentials()
|
||||
if credentials is not None:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, Dict, Iterable, List, Optional
|
||||
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional
|
||||
|
||||
from datahub.ingestion.api.common import PipelineContext
|
||||
from datahub.ingestion.api.decorators import (
|
||||
@ -33,6 +33,9 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
|
||||
StatefulIngestionSourceBase,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
|
||||
|
||||
@platform_name("SageMaker")
|
||||
@config_class(SagemakerSourceConfig)
|
||||
@ -56,6 +59,7 @@ class SagemakerSource(StatefulIngestionSourceBase):
|
||||
self.report = SagemakerSourceReport()
|
||||
self.sagemaker_client = config.sagemaker_client
|
||||
self.env = config.env
|
||||
self.client_factory = ClientFactory(config)
|
||||
|
||||
@classmethod
|
||||
def create(cls, config_dict, ctx):
|
||||
@ -92,7 +96,7 @@ class SagemakerSource(StatefulIngestionSourceBase):
|
||||
# extract jobs if specified
|
||||
if self.source_config.extract_jobs is not False:
|
||||
job_processor = JobProcessor(
|
||||
sagemaker_client=self.sagemaker_client,
|
||||
sagemaker_client=self.client_factory.get_client,
|
||||
env=self.env,
|
||||
report=self.report,
|
||||
job_type_filter=self.source_config.extract_jobs,
|
||||
@ -118,3 +122,15 @@ class SagemakerSource(StatefulIngestionSourceBase):
|
||||
|
||||
def get_report(self):
|
||||
return self.report
|
||||
|
||||
|
||||
class ClientFactory:
    """Hands out the SageMaker client to use for a given source config.

    The value of ``config.sagemaker_client`` is read once at construction
    time and kept as a fallback. Whether that stored client or a freshly
    read one is returned depends on ``config.allowed_cred_refresh()``.
    """

    def __init__(self, config: SagemakerSourceConfig):
        self.config = config
        # Snapshot taken up front; used when refresh is not allowed.
        self._cached_client = config.sagemaker_client

    def get_client(self) -> "SageMakerClient":
        """Return the client that callers should use right now.

        Returns:
            The client stored at construction time when
            ``config.allowed_cred_refresh()`` is false; otherwise the
            current value of ``config.sagemaker_client``, read again on
            every call.
        """
        if not self.config.allowed_cred_refresh():
            return self._cached_client

        # Always fetch the client dynamically with auto-refresh logic
        # NOTE(review): presumably re-reading the attribute is what lets the
        # config supply refreshed credentials - confirm against the config.
        return self.config.sagemaker_client
|
||||
|
||||
@ -4,6 +4,7 @@ from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Iterable,
|
||||
@ -147,7 +148,7 @@ class JobProcessor:
|
||||
"""
|
||||
|
||||
# boto3 SageMaker client
|
||||
sagemaker_client: "SageMakerClient"
|
||||
sagemaker_client: Callable[[], "SageMakerClient"]
|
||||
env: str
|
||||
report: SagemakerSourceReport
|
||||
# config filter for specific job types to ingest (see metadata-ingestion README)
|
||||
@ -170,8 +171,7 @@ class JobProcessor:
|
||||
|
||||
def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
|
||||
jobs = []
|
||||
|
||||
paginator = self.sagemaker_client.get_paginator(job_spec.list_command)
|
||||
paginator = self.sagemaker_client().get_paginator(job_spec.list_command)
|
||||
for page in paginator.paginate():
|
||||
page_jobs: List[Any] = page[job_spec.list_key]
|
||||
|
||||
@ -269,7 +269,7 @@ class JobProcessor:
|
||||
describe_command = job_type_to_info[job_type].describe_command
|
||||
describe_name_key = job_type_to_info[job_type].describe_name_key
|
||||
|
||||
return getattr(self.sagemaker_client, describe_command)(
|
||||
return getattr(self.sagemaker_client(), describe_command)(
|
||||
**{describe_name_key: job_name}
|
||||
)
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from botocore.stub import Stubber
|
||||
from freezegun import freeze_time
|
||||
|
||||
@ -220,8 +222,17 @@ def test_sagemaker_ingest(tmp_path, pytestconfig):
|
||||
{"ModelName": "the-second-model"},
|
||||
)
|
||||
|
||||
mce_objects = [wu.metadata for wu in sagemaker_source_instance.get_workunits()]
|
||||
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)
|
||||
# Patch the client factory's get_client method to return the stubbed client for jobs
|
||||
with patch.object(
|
||||
sagemaker_source_instance.client_factory,
|
||||
"get_client",
|
||||
return_value=sagemaker_source_instance.sagemaker_client,
|
||||
):
|
||||
# Run the test and generate the MCEs
|
||||
mce_objects = [
|
||||
wu.metadata for wu in sagemaker_source_instance.get_workunits()
|
||||
]
|
||||
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)
|
||||
|
||||
# Verify the output.
|
||||
test_resources_dir = pytestconfig.rootpath / "tests/unit/sagemaker"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user