diff --git a/ingestion/src/metadata/clients/aws_client.py b/ingestion/src/metadata/clients/aws_client.py index ca40ff1bb67..e550c41e22a 100644 --- a/ingestion/src/metadata/clients/aws_client.py +++ b/ingestion/src/metadata/clients/aws_client.py @@ -11,6 +11,7 @@ """ Module containing AWS Client """ +import datetime from enum import Enum from functools import partial from typing import Any, Callable, Dict, Optional, Type, TypeVar @@ -47,11 +48,22 @@ class AWSAssumeRoleException(Exception): """ +class AWSAssumeRoleCredentialResponse(BaseModel): + AccessKeyId: str = Field() + SecretAccessKey: str = Field() + SessionToken: Optional[str] = Field( + default=None, + ) + Expiration: Optional[datetime.datetime] = None + + class AWSAssumeRoleCredentialWrapper(BaseModel): - accessKeyId: str = Field(alias="access_key") - secretAccessKey: CustomSecretStr = Field(alias="secret_key") - sessionToken: Optional[str] = Field(default=None, alias="token") - expiryTime: Optional[str] = Field(alias="expiry_time") + accessKeyId: str = Field() + secretAccessKey: CustomSecretStr = Field() + sessionToken: Optional[str] = Field( + default=None, + ) + expiryTime: Optional[str] = Field() AWSAssumeRoleCredentialFormat = TypeVar( @@ -102,12 +114,14 @@ class AWSClient: ) if resp: - credentials = resp.get("Credentials", {}) + credentials: AWSAssumeRoleCredentialResponse = ( + AWSAssumeRoleCredentialResponse(**resp.get("Credentials", {})) + ) creds_wrapper = AWSAssumeRoleCredentialWrapper( - accessKeyId=credentials.get("AccessKeyId"), - secretAccessKey=credentials.get("SecretAccessKey"), - sessionToken=credentials.get("SessionToken"), - expiryTime=credentials.get("Expiration").isformat(), + accessKeyId=credentials.AccessKeyId, + secretAccessKey=credentials.SecretAccessKey, + sessionToken=credentials.SessionToken, + expiryTime=credentials.Expiration.isoformat(), ) if return_type == Dict: return creds_wrapper.model_dump(by_alias=True) @@ -143,9 +157,11 @@ class AWSClient: return Session( aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key.get_secret_value() - if aws_secret_access_key - else None, + aws_secret_access_key=( + aws_secret_access_key.get_secret_value() + if aws_secret_access_key + else None + ), aws_session_token=aws_session_token, region_name=aws_region, profile_name=profile,