diff --git a/ingestion/src/metadata/clients/aws_client.py b/ingestion/src/metadata/clients/aws_client.py index 8e87c600f96..ca40ff1bb67 100644 --- a/ingestion/src/metadata/clients/aws_client.py +++ b/ingestion/src/metadata/clients/aws_client.py @@ -12,11 +12,14 @@ Module containing AWS Client """ from enum import Enum -from typing import Any, Optional +from functools import partial +from typing import Any, Callable, Dict, Optional, Type, TypeVar import boto3 from boto3 import Session -from pydantic import BaseModel +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session +from pydantic import BaseModel, Field from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials from metadata.ingestion.models.custom_pydantic import CustomSecretStr @@ -45,9 +48,15 @@ class AWSAssumeRoleException(Exception): class AWSAssumeRoleCredentialWrapper(BaseModel): - accessKeyId: str - secretAccessKey: CustomSecretStr - sessionToken: Optional[str] = None + accessKeyId: str = Field(alias="access_key") + secretAccessKey: CustomSecretStr = Field(alias="secret_key") + sessionToken: Optional[str] = Field(default=None, alias="token") + expiryTime: Optional[str] = Field(alias="expiry_time") + + +AWSAssumeRoleCredentialFormat = TypeVar( + "AWSAssumeRoleCredentialFormat", AWSAssumeRoleCredentialWrapper, Dict +) class AWSClient: @@ -65,7 +74,10 @@ class AWSClient: @staticmethod def get_assume_role_config( config: AWSCredentials, - ) -> Optional[AWSAssumeRoleCredentialWrapper]: + return_type: Type[ + AWSAssumeRoleCredentialFormat + ] = AWSAssumeRoleCredentialWrapper, + ) -> Optional[AWSAssumeRoleCredentialFormat]: """ Get temporary credentials from assumed role """ @@ -91,11 +103,16 @@ class AWSClient: if resp: credentials = resp.get("Credentials", {}) - return AWSAssumeRoleCredentialWrapper( + creds_wrapper = AWSAssumeRoleCredentialWrapper( accessKeyId=credentials.get("AccessKeyId"), secretAccessKey=credentials.get("SecretAccessKey"), sessionToken=credentials.get("SessionToken"), + expiryTime=credentials.get("Expiration").isformat(), ) + if return_type == Dict: + return creds_wrapper.model_dump(by_alias=True) + return creds_wrapper + return None @staticmethod @@ -105,12 +122,25 @@ class AWSClient: aws_session_token: Optional[str], aws_region: str, profile=None, + refresh_using: Optional[Callable] = None, ) -> Session: """ The only required param for boto3 is the region. The rest of credentials will have fallback strategies based on https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials """ + if refresh_using: + refreshable_creds = RefreshableCredentials.create_from_metadata( + metadata=refresh_using(), + refresh_using=refresh_using, + method="sts-assume-role", + ) + session = get_session() + session._credentials = refreshable_creds # pylint: disable=protected-access + return Session( + botocore_session=session, region_name=aws_region, profile_name=profile + ) + return Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key.get_secret_value() @@ -123,15 +153,16 @@ class AWSClient: def create_session(self) -> Session: if self.config.assumeRoleArn: - assume_creds = AWSClient.get_assume_role_config(self.config) - if assume_creds: - return AWSClient._get_session( - assume_creds.accessKeyId, - assume_creds.secretAccessKey, - assume_creds.sessionToken, - self.config.awsRegion, - self.config.profileName, - ) + return AWSClient._get_session( + None, + None, + None, + self.config.awsRegion, + self.config.profileName, + refresh_using=partial( + AWSClient.get_assume_role_config, self.config, Dict + ), + ) return AWSClient._get_session( self.config.awsAccessKeyId,