diff --git a/ingestion/src/metadata/clients/aws_client.py b/ingestion/src/metadata/clients/aws_client.py index 7ffac01a39a..c8647f361e7 100644 --- a/ingestion/src/metadata/clients/aws_client.py +++ b/ingestion/src/metadata/clients/aws_client.py @@ -12,11 +12,14 @@ Module containing AWS Client """ from enum import Enum -from typing import Any +from typing import Any, Optional import boto3 from boto3 import Session +from pydantic import BaseModel +from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials +from metadata.ingestion.models.custom_pydantic import CustomSecretStr from metadata.utils.logger import utils_logger logger = utils_logger() @@ -30,16 +33,25 @@ class AWSServices(Enum): QUICKSIGHT = "quicksight" +class AWSAssumeRoleException(Exception): + """ + Exception class to handle assume role related issues + """ + + +class AWSAssumeRoleCredentialWrapper(BaseModel): + + accessKeyId: str + secretAccessKey: CustomSecretStr + sessionToken: Optional[str] + + class AWSClient: """ AWSClient creates a boto3 Session client based on AWSCredentials. 
""" def __init__(self, config: "AWSCredentials"): - # local import to avoid the creation of circular dependencies with CustomSecretStr - from metadata.generated.schema.security.credentials.awsCredentials import ( # pylint: disable=import-outside-toplevel - AWSCredentials, - ) self.config = ( config @@ -47,33 +59,86 @@ class AWSClient: else (AWSCredentials.parse_obj(config) if config else config) ) - def _get_session(self) -> Session: - if ( - self.config.awsAccessKeyId - and self.config.awsSecretAccessKey - and self.config.awsSessionToken - ): - return Session( - aws_access_key_id=self.config.awsAccessKeyId, - aws_secret_access_key=self.config.awsSecretAccessKey.get_secret_value(), - aws_session_token=self.config.awsSessionToken, - region_name=self.config.awsRegion, + @staticmethod + def get_assume_role_config( + config: AWSCredentials, + ) -> Optional[AWSAssumeRoleCredentialWrapper]: + """ + Get temporary credentials from assumed role + """ + session = AWSClient._get_session( + config.awsAccessKeyId, + config.awsSecretAccessKey, + config.awsSessionToken, + config.awsRegion, + config.profileName, + ) + sts_client = session.client("sts") + resp = None + if config.assumeRoleSourceIdentity: + resp = sts_client.assume_role( + RoleArn=config.assumeRoleArn, + RoleSessionName=config.assumeRoleSessionName, + SourceIdentity=config.assumeRoleSourceIdentity, ) - if self.config.awsAccessKeyId and self.config.awsSecretAccessKey: - return Session( - aws_access_key_id=self.config.awsAccessKeyId, - aws_secret_access_key=self.config.awsSecretAccessKey.get_secret_value(), - region_name=self.config.awsRegion, + else: + resp = sts_client.assume_role( + RoleArn=config.assumeRoleArn, + RoleSessionName=config.assumeRoleSessionName, ) - if self.config.awsRegion: - return Session(region_name=self.config.awsRegion) - return Session() + + if resp: + credentials = resp.get("Credentials", {}) + return AWSAssumeRoleCredentialWrapper( + accessKeyId=credentials.get("AccessKeyId"), + 
secretAccessKey=credentials.get("SecretAccessKey"), + sessionToken=credentials.get("SessionToken"), + ) + return None + + @staticmethod + def _get_session( + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + aws_region, + profile=None, + ) -> Session: + return Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key.get_secret_value() + if aws_secret_access_key + else None, + aws_session_token=aws_session_token, + region_name=aws_region, + profile_name=profile, + ) + + def create_session(self) -> Session: + if self.config.assumeRoleArn: + assume_creds = AWSClient.get_assume_role_config(self.config) + if assume_creds: + return AWSClient._get_session( + assume_creds.accessKeyId, + assume_creds.secretAccessKey, + assume_creds.sessionToken, + self.config.awsRegion, + self.config.profileName, + ) + + return AWSClient._get_session( + self.config.awsAccessKeyId, + self.config.awsSecretAccessKey, + self.config.awsSessionToken, + self.config.awsRegion, + self.config.profileName, + ) def get_client(self, service_name: str) -> Any: # initialize the client depending on the AWSCredentials passed if self.config is not None: logger.info(f"Getting AWS client for service [{service_name}]") - session = self._get_session() + session = self.create_session() if self.config.endPointURL is not None: return session.client( service_name=service_name, endpoint_url=self.config.endPointURL @@ -85,7 +150,7 @@ class AWSClient: return boto3.client(service_name=service_name) def get_resource(self, service_name: str) -> Any: - session = self._get_session() + session = self.create_session() if self.config.endPointURL is not None: return session.resource( service_name=service_name, endpoint_url=self.config.endPointURL diff --git a/ingestion/src/metadata/ingestion/models/custom_pydantic.py b/ingestion/src/metadata/ingestion/models/custom_pydantic.py index 66cbe6f8c4e..f0ff5a80998 100644 --- 
a/ingestion/src/metadata/ingestion/models/custom_pydantic.py +++ b/ingestion/src/metadata/ingestion/models/custom_pydantic.py @@ -20,7 +20,6 @@ from pydantic.utils import update_not_none from pydantic.validators import constr_length_validator, str_validator from metadata.utils.logger import ingestion_logger -from metadata.utils.secrets.secrets_manager_factory import SecretsManagerFactory logger = ingestion_logger() @@ -76,6 +75,11 @@ class CustomSecretStr(SecretStr): return str(self) def get_secret_value(self, skip_secret_manager: bool = False) -> str: + # Importing inside function to avoid circular import error + from metadata.utils.secrets.secrets_manager_factory import ( # pylint: disable=import-outside-toplevel,cyclic-import + SecretsManagerFactory, + ) + if ( not skip_secret_manager and self._secret_value.startswith("secret:") diff --git a/ingestion/src/metadata/ingestion/source/database/athena/connection.py b/ingestion/src/metadata/ingestion/source/database/athena/connection.py index 90cd60a5621..fa16859336a 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/athena/connection.py @@ -17,6 +17,7 @@ from urllib.parse import quote_plus from sqlalchemy.engine import Engine from sqlalchemy.inspection import inspect +from metadata.clients.aws_client import AWSClient from metadata.generated.schema.entity.services.connections.database.athenaConnection import ( AthenaConnection, ) @@ -32,11 +33,24 @@ from metadata.ingestion.connections.test_connections import ( def get_connection_url(connection: AthenaConnection) -> str: + """ + Method to get connection url + """ + aws_access_key_id = connection.awsConfig.awsAccessKeyId + aws_secret_access_key = connection.awsConfig.awsSecretAccessKey + aws_session_token = connection.awsConfig.awsSessionToken + if connection.awsConfig.assumeRoleArn: + assume_configs = AWSClient.get_assume_role_config(connection.awsConfig) + if assume_configs: + 
aws_access_key_id = assume_configs.accessKeyId + aws_secret_access_key = assume_configs.secretAccessKey + aws_session_token = assume_configs.sessionToken + url = f"{connection.scheme.value}://" - if connection.awsConfig.awsAccessKeyId: - url += connection.awsConfig.awsAccessKeyId - if connection.awsConfig.awsSecretAccessKey: - url += f":{connection.awsConfig.awsSecretAccessKey.get_secret_value()}" + if aws_access_key_id: + url += aws_access_key_id + if aws_secret_access_key: + url += f":{aws_secret_access_key.get_secret_value()}" else: url += ":" url += f"@athena.{connection.awsConfig.awsRegion}.amazonaws.com:443" @@ -44,8 +58,8 @@ def get_connection_url(connection: AthenaConnection) -> str: url += f"?s3_staging_dir={quote_plus(connection.s3StagingDir)}" if connection.workgroup: url += f"&work_group={connection.workgroup}" - if connection.awsConfig.awsSessionToken: - url += f"&aws_session_token={quote_plus(connection.awsConfig.awsSessionToken)}" + if aws_session_token: + url += f"&aws_session_token={quote_plus(aws_session_token)}" return url diff --git a/openmetadata-spec/src/main/resources/json/schema/security/credentials/awsCredentials.json b/openmetadata-spec/src/main/resources/json/schema/security/credentials/awsCredentials.json index 985c900827d..82600b8563d 100644 --- a/openmetadata-spec/src/main/resources/json/schema/security/credentials/awsCredentials.json +++ b/openmetadata-spec/src/main/resources/json/schema/security/credentials/awsCredentials.json @@ -32,6 +32,27 @@ "description": "EndPoint URL for the AWS", "type": "string", "format": "uri" + }, + "profileName": { + "title": "Profile Name", + "description": "The name of a profile to use with the boto session.", + "type": "string" + }, + "assumeRoleArn": { + "title": "Role Arn for Assume Role", + "description": "The Amazon Resource Name (ARN) of the role to assume. 
Required Field in case of Assume Role", + "type": "string" + }, + "assumeRoleSessionName": { + "title": "Role Session Name for Assume Role", + "description": "An identifier for the assumed role session. Use the role session name to uniquely identify a session when the same role is assumed by different principals or for different reasons. Required Field in case of Assume Role", + "type": "string", + "default": "OpenMetadataSession" + }, + "assumeRoleSourceIdentity": { + "title": "Source Identity for Assume Role", + "description": "The source identity specified by the principal that is calling the AssumeRole operation. Optional Field in case of Assume Role", + "type": "string" + } }, "additionalProperties": false,