Fixes ISSUE-19095: auto refresh boto credentials (#19098)

When using using the assumeRole connection option with jobs involving
AWS services, the default 1 hour of boto might not cut it and the job
fails, one way of solving this is to refresh the credentials when they
expire, this ensures there is always valid credentials for the job
regardless for how long it runs.

Co-authored-by: Abdallah Serghine <abdallah.serghine@olx.pl>
Co-authored-by: IceS2 <pjt1991@gmail.com>
This commit is contained in:
Abdallah Serghine 2025-01-03 10:51:15 +01:00 committed by GitHub
parent d60327c448
commit 05c57857aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,11 +12,14 @@
Module containing AWS Client Module containing AWS Client
""" """
from enum import Enum from enum import Enum
from typing import Any, Optional from functools import partial
from typing import Any, Callable, Dict, Optional, Type, TypeVar
import boto3 import boto3
from boto3 import Session from boto3 import Session
from pydantic import BaseModel from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
from pydantic import BaseModel, Field
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.ingestion.models.custom_pydantic import CustomSecretStr from metadata.ingestion.models.custom_pydantic import CustomSecretStr
@ -45,9 +48,15 @@ class AWSAssumeRoleException(Exception):
class AWSAssumeRoleCredentialWrapper(BaseModel): class AWSAssumeRoleCredentialWrapper(BaseModel):
accessKeyId: str accessKeyId: str = Field(alias="access_key")
secretAccessKey: CustomSecretStr secretAccessKey: CustomSecretStr = Field(alias="secret_key")
sessionToken: Optional[str] = None sessionToken: Optional[str] = Field(default=None, alias="token")
expiryTime: Optional[str] = Field(alias="expiry_time")
AWSAssumeRoleCredentialFormat = TypeVar(
"AWSAssumeRoleCredentialFormat", AWSAssumeRoleCredentialWrapper, Dict
)
class AWSClient: class AWSClient:
@ -65,7 +74,10 @@ class AWSClient:
@staticmethod @staticmethod
def get_assume_role_config( def get_assume_role_config(
config: AWSCredentials, config: AWSCredentials,
) -> Optional[AWSAssumeRoleCredentialWrapper]: return_type: Type[
AWSAssumeRoleCredentialFormat
] = AWSAssumeRoleCredentialWrapper,
) -> Optional[AWSAssumeRoleCredentialFormat]:
""" """
Get temporary credentials from assumed role Get temporary credentials from assumed role
""" """
@ -91,11 +103,16 @@ class AWSClient:
if resp: if resp:
credentials = resp.get("Credentials", {}) credentials = resp.get("Credentials", {})
return AWSAssumeRoleCredentialWrapper( creds_wrapper = AWSAssumeRoleCredentialWrapper(
accessKeyId=credentials.get("AccessKeyId"), accessKeyId=credentials.get("AccessKeyId"),
secretAccessKey=credentials.get("SecretAccessKey"), secretAccessKey=credentials.get("SecretAccessKey"),
sessionToken=credentials.get("SessionToken"), sessionToken=credentials.get("SessionToken"),
expiryTime=credentials.get("Expiration").isformat(),
) )
if return_type == Dict:
return creds_wrapper.model_dump(by_alias=True)
return creds_wrapper
return None return None
@staticmethod @staticmethod
@ -105,12 +122,25 @@ class AWSClient:
aws_session_token: Optional[str], aws_session_token: Optional[str],
aws_region: str, aws_region: str,
profile=None, profile=None,
refresh_using: Optional[Callable] = None,
) -> Session: ) -> Session:
""" """
The only required param for boto3 is the region. The only required param for boto3 is the region.
The rest of credentials will have fallback strategies based on The rest of credentials will have fallback strategies based on
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials
""" """
if refresh_using:
refreshable_creds = RefreshableCredentials.create_from_metadata(
metadata=refresh_using(),
refresh_using=refresh_using,
method="sts-assume-role",
)
session = get_session()
session._credentials = refreshable_creds # pylint: disable=protected-access
return Session(
botocore_session=session, region_name=aws_region, profile_name=profile
)
return Session( return Session(
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key.get_secret_value() aws_secret_access_key=aws_secret_access_key.get_secret_value()
@ -123,15 +153,16 @@ class AWSClient:
def create_session(self) -> Session: def create_session(self) -> Session:
if self.config.assumeRoleArn: if self.config.assumeRoleArn:
assume_creds = AWSClient.get_assume_role_config(self.config) return AWSClient._get_session(
if assume_creds: None,
return AWSClient._get_session( None,
assume_creds.accessKeyId, None,
assume_creds.secretAccessKey, self.config.awsRegion,
assume_creds.sessionToken, self.config.profileName,
self.config.awsRegion, refresh_using=partial(
self.config.profileName, AWSClient.get_assume_role_config, self.config, Dict
) ),
)
return AWSClient._get_session( return AWSClient._get_session(
self.config.awsAccessKeyId, self.config.awsAccessKeyId,