From 05c57857aaf6e70896df99a23b5eac805a3b8ccd Mon Sep 17 00:00:00 2001 From: Abdallah Serghine <76706155+KylixSerg@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:51:15 +0100 Subject: [PATCH] Fixes ISSUE-19095: auto refresh boto credentials (#19098) When using the assumeRole connection option with jobs involving AWS services, the default 1-hour lifetime of the boto credentials might not be enough, and the job fails. One way of solving this is to refresh the credentials when they expire; this ensures there are always valid credentials for the job regardless of how long it runs. Co-authored-by: Abdallah Serghine Co-authored-by: IceS2 --- ingestion/src/metadata/clients/aws_client.py | 63 +++++++++++++++----- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/ingestion/src/metadata/clients/aws_client.py b/ingestion/src/metadata/clients/aws_client.py index 8e87c600f96..ca40ff1bb67 100644 --- a/ingestion/src/metadata/clients/aws_client.py +++ b/ingestion/src/metadata/clients/aws_client.py @@ -12,11 +12,14 @@ Module containing AWS Client """ from enum import Enum -from typing import Any, Optional +from functools import partial +from typing import Any, Callable, Dict, Optional, Type, TypeVar import boto3 from boto3 import Session -from pydantic import BaseModel +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session +from pydantic import BaseModel, Field from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials from metadata.ingestion.models.custom_pydantic import CustomSecretStr @@ -45,9 +48,15 @@ class AWSAssumeRoleException(Exception): class AWSAssumeRoleCredentialWrapper(BaseModel): - accessKeyId: str - secretAccessKey: CustomSecretStr - sessionToken: Optional[str] = None + accessKeyId: str = Field(alias="access_key") + secretAccessKey: CustomSecretStr = Field(alias="secret_key") + sessionToken: Optional[str] = Field(default=None, alias="token") + expiryTime: Optional[str] = Field(alias="expiry_time") +
+ +AWSAssumeRoleCredentialFormat = TypeVar( + "AWSAssumeRoleCredentialFormat", AWSAssumeRoleCredentialWrapper, Dict +) class AWSClient: @@ -65,7 +74,10 @@ class AWSClient: @staticmethod def get_assume_role_config( config: AWSCredentials, - ) -> Optional[AWSAssumeRoleCredentialWrapper]: + return_type: Type[ + AWSAssumeRoleCredentialFormat + ] = AWSAssumeRoleCredentialWrapper, + ) -> Optional[AWSAssumeRoleCredentialFormat]: """ Get temporary credentials from assumed role """ @@ -91,11 +103,16 @@ class AWSClient: if resp: credentials = resp.get("Credentials", {}) - return AWSAssumeRoleCredentialWrapper( + creds_wrapper = AWSAssumeRoleCredentialWrapper( accessKeyId=credentials.get("AccessKeyId"), secretAccessKey=credentials.get("SecretAccessKey"), sessionToken=credentials.get("SessionToken"), + expiryTime=credentials.get("Expiration").isoformat(), ) + if return_type == Dict: + return creds_wrapper.model_dump(by_alias=True) + return creds_wrapper + return None @staticmethod @@ -105,12 +122,25 @@ class AWSClient: aws_session_token: Optional[str], aws_region: str, profile=None, + refresh_using: Optional[Callable] = None, ) -> Session: """ The only required param for boto3 is the region.
The rest of credentials will have fallback strategies based on https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials """ + if refresh_using: + refreshable_creds = RefreshableCredentials.create_from_metadata( + metadata=refresh_using(), + refresh_using=refresh_using, + method="sts-assume-role", + ) + session = get_session() + session._credentials = refreshable_creds # pylint: disable=protected-access + return Session( + botocore_session=session, region_name=aws_region, profile_name=profile + ) + return Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key.get_secret_value() @@ -123,15 +153,16 @@ class AWSClient: def create_session(self) -> Session: if self.config.assumeRoleArn: - assume_creds = AWSClient.get_assume_role_config(self.config) - if assume_creds: - return AWSClient._get_session( - assume_creds.accessKeyId, - assume_creds.secretAccessKey, - assume_creds.sessionToken, - self.config.awsRegion, - self.config.profileName, - ) + return AWSClient._get_session( + None, + None, + None, + self.config.awsRegion, + self.config.profileName, + refresh_using=partial( + AWSClient.get_assume_role_config, self.config, Dict + ), + ) return AWSClient._get_session( self.config.awsAccessKeyId,