Fix #10402: Add support for AssumeRole for AWS (#10417)

This commit is contained in:
Mayur Singal 2023-03-08 15:43:33 +05:30 committed by GitHub
parent d41878ec90
commit c199f13ed0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 33 deletions

View File

@ -12,11 +12,14 @@
Module containing AWS Client Module containing AWS Client
""" """
from enum import Enum from enum import Enum
from typing import Any from typing import Any, Optional
import boto3 import boto3
from boto3 import Session from boto3 import Session
from pydantic import BaseModel
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.ingestion.models.custom_pydantic import CustomSecretStr
from metadata.utils.logger import utils_logger from metadata.utils.logger import utils_logger
logger = utils_logger() logger = utils_logger()
@ -30,16 +33,25 @@ class AWSServices(Enum):
QUICKSIGHT = "quicksight" QUICKSIGHT = "quicksight"
class AWSAssumeRoleException(Exception):
    """
    Error type dedicated to problems encountered while assuming an AWS role.

    It adds nothing on top of ``Exception``; it only gives callers a
    specific type to raise and catch for assume-role failures.
    """
class AWSAssumeRoleCredentialWrapper(BaseModel):
    """
    Container for the temporary credentials obtained by assuming a role.

    The secret key is held as a ``CustomSecretStr`` so it is handled the same
    way as the secret in the regular AWS credentials.
    """

    # Access key id of the temporary credentials
    accessKeyId: str
    # Secret access key of the temporary credentials
    secretAccessKey: CustomSecretStr
    # Session token of the temporary credentials. The explicit ``None``
    # default keeps the field optional without relying on pydantic's
    # implicit handling of ``Optional`` annotations.
    sessionToken: Optional[str] = None
class AWSClient: class AWSClient:
""" """
AWSClient creates a boto3 Session client based on AWSCredentials. AWSClient creates a boto3 Session client based on AWSCredentials.
""" """
def __init__(self, config: "AWSCredentials"): def __init__(self, config: "AWSCredentials"):
# local import to avoid the creation of circular dependencies with CustomSecretStr
from metadata.generated.schema.security.credentials.awsCredentials import ( # pylint: disable=import-outside-toplevel
AWSCredentials,
)
self.config = ( self.config = (
config config
@ -47,33 +59,86 @@ class AWSClient:
else (AWSCredentials.parse_obj(config) if config else config) else (AWSCredentials.parse_obj(config) if config else config)
) )
@staticmethod
def get_assume_role_config(
    config: AWSCredentials,
) -> Optional[AWSAssumeRoleCredentialWrapper]:
    """
    Get temporary credentials from assumed role

    A session is first built from the credentials found in ``config``; its
    STS client is then used to assume ``config.assumeRoleArn`` under the
    session name ``config.assumeRoleSessionName``.

    Args:
        config: AWS credentials holding the base credentials and the
            assume role settings
    Returns:
        The temporary credentials wrapped in an
        AWSAssumeRoleCredentialWrapper, or None if STS gave back an
        empty response
    """
    base_session = AWSClient._get_session(
        config.awsAccessKeyId,
        config.awsSecretAccessKey,
        config.awsSessionToken,
        config.awsRegion,
        config.profileName,
    )

    assume_role_args = {
        "RoleArn": config.assumeRoleArn,
        "RoleSessionName": config.assumeRoleSessionName,
    }
    # SourceIdentity is only sent along when it has been configured
    if config.assumeRoleSourceIdentity:
        assume_role_args["SourceIdentity"] = config.assumeRoleSourceIdentity

    response = base_session.client("sts").assume_role(**assume_role_args)
    if not response:
        return None

    issued = response.get("Credentials", {})
    return AWSAssumeRoleCredentialWrapper(
        accessKeyId=issued.get("AccessKeyId"),
        secretAccessKey=issued.get("SecretAccessKey"),
        sessionToken=issued.get("SessionToken"),
    )
@staticmethod
def _get_session(
    aws_access_key_id,
    aws_secret_access_key,
    aws_session_token,
    aws_region,
    profile=None,
) -> Session:
    """
    Build a boto3 Session out of the given credential pieces.

    Args:
        aws_access_key_id: access key id handed to the session as is
        aws_secret_access_key: secret wrapper exposing ``get_secret_value()``;
            a falsy value results in no secret being passed
        aws_session_token: session token handed to the session as is
        aws_region: used as the session ``region_name``
        profile: used as the session ``profile_name``
    Returns:
        The boto3 Session created from those values
    """
    # Unwrap the secret only when one was actually supplied
    secret_value = None
    if aws_secret_access_key:
        secret_value = aws_secret_access_key.get_secret_value()

    return Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=secret_value,
        aws_session_token=aws_session_token,
        region_name=aws_region,
        profile_name=profile,
    )
def create_session(self) -> Session:
    """
    Create the boto3 Session matching the configured credentials.

    When ``assumeRoleArn`` is set and temporary credentials could be
    obtained for it, the session is built from those temporary
    credentials. Otherwise the credentials from the configuration are
    used directly. Region and profile always come from the configuration.
    """
    settings = self.config

    # Start from the configured credentials...
    key_id = settings.awsAccessKeyId
    secret = settings.awsSecretAccessKey
    token = settings.awsSessionToken

    # ...and swap in the assumed role ones when they are available
    if settings.assumeRoleArn:
        temporary = AWSClient.get_assume_role_config(settings)
        if temporary:
            key_id = temporary.accessKeyId
            secret = temporary.secretAccessKey
            token = temporary.sessionToken

    return AWSClient._get_session(
        key_id,
        secret,
        token,
        settings.awsRegion,
        settings.profileName,
    )
def get_client(self, service_name: str) -> Any: def get_client(self, service_name: str) -> Any:
# initialize the client depending on the AWSCredentials passed # initialize the client depending on the AWSCredentials passed
if self.config is not None: if self.config is not None:
logger.info(f"Getting AWS client for service [{service_name}]") logger.info(f"Getting AWS client for service [{service_name}]")
session = self._get_session() session = self.create_session()
if self.config.endPointURL is not None: if self.config.endPointURL is not None:
return session.client( return session.client(
service_name=service_name, endpoint_url=self.config.endPointURL service_name=service_name, endpoint_url=self.config.endPointURL
@ -85,7 +150,7 @@ class AWSClient:
return boto3.client(service_name=service_name) return boto3.client(service_name=service_name)
def get_resource(self, service_name: str) -> Any: def get_resource(self, service_name: str) -> Any:
session = self._get_session() session = self.create_session()
if self.config.endPointURL is not None: if self.config.endPointURL is not None:
return session.resource( return session.resource(
service_name=service_name, endpoint_url=self.config.endPointURL service_name=service_name, endpoint_url=self.config.endPointURL

View File

@ -20,7 +20,6 @@ from pydantic.utils import update_not_none
from pydantic.validators import constr_length_validator, str_validator from pydantic.validators import constr_length_validator, str_validator
from metadata.utils.logger import ingestion_logger from metadata.utils.logger import ingestion_logger
from metadata.utils.secrets.secrets_manager_factory import SecretsManagerFactory
logger = ingestion_logger() logger = ingestion_logger()
@ -76,6 +75,11 @@ class CustomSecretStr(SecretStr):
return str(self) return str(self)
def get_secret_value(self, skip_secret_manager: bool = False) -> str: def get_secret_value(self, skip_secret_manager: bool = False) -> str:
# Importing inside function to avoid circular import error
from metadata.utils.secrets.secrets_manager_factory import ( # pylint: disable=import-outside-toplevel,cyclic-import
SecretsManagerFactory,
)
if ( if (
not skip_secret_manager not skip_secret_manager
and self._secret_value.startswith("secret:") and self._secret_value.startswith("secret:")

View File

@ -17,6 +17,7 @@ from urllib.parse import quote_plus
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from metadata.clients.aws_client import AWSClient
from metadata.generated.schema.entity.services.connections.database.athenaConnection import ( from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
AthenaConnection, AthenaConnection,
) )
@ -32,11 +33,24 @@ from metadata.ingestion.connections.test_connections import (
def get_connection_url(connection: AthenaConnection) -> str: def get_connection_url(connection: AthenaConnection) -> str:
"""
Method to get connection url
"""
aws_access_key_id = connection.awsConfig.awsAccessKeyId
aws_secret_access_key = connection.awsConfig.awsSecretAccessKey
aws_session_token = connection.awsConfig.awsSessionToken
if connection.awsConfig.assumeRoleArn:
assume_configs = AWSClient.get_assume_role_config(connection.awsConfig)
if assume_configs:
aws_access_key_id = assume_configs.accessKeyId
aws_secret_access_key = assume_configs.secretAccessKey
aws_session_token = assume_configs.sessionToken
url = f"{connection.scheme.value}://" url = f"{connection.scheme.value}://"
if connection.awsConfig.awsAccessKeyId: if aws_access_key_id:
url += connection.awsConfig.awsAccessKeyId url += aws_access_key_id
if connection.awsConfig.awsSecretAccessKey: if aws_secret_access_key:
url += f":{connection.awsConfig.awsSecretAccessKey.get_secret_value()}" url += f":{aws_secret_access_key.get_secret_value()}"
else: else:
url += ":" url += ":"
url += f"@athena.{connection.awsConfig.awsRegion}.amazonaws.com:443" url += f"@athena.{connection.awsConfig.awsRegion}.amazonaws.com:443"
@ -44,8 +58,8 @@ def get_connection_url(connection: AthenaConnection) -> str:
url += f"?s3_staging_dir={quote_plus(connection.s3StagingDir)}" url += f"?s3_staging_dir={quote_plus(connection.s3StagingDir)}"
if connection.workgroup: if connection.workgroup:
url += f"&work_group={connection.workgroup}" url += f"&work_group={connection.workgroup}"
if connection.awsConfig.awsSessionToken: if aws_session_token:
url += f"&aws_session_token={quote_plus(connection.awsConfig.awsSessionToken)}" url += f"&aws_session_token={quote_plus(aws_session_token)}"
return url return url

View File

@ -32,6 +32,27 @@
"description": "EndPoint URL for the AWS", "description": "EndPoint URL for the AWS",
"type": "string", "type": "string",
"format": "uri" "format": "uri"
},
"profileName": {
"title": "Profile Name",
"description": "The name of a profile to use with the boto session.",
"type": "string"
},
"assumeRoleArn": {
"title": "Role Arn for Assume Role",
"description": "The Amazon Resource Name (ARN) of the role to assume. Required Field in case of Assume Role",
"type": "string"
},
"assumeRoleSessionName": {
"title": "Role Session Name for Assume Role",
"description": "An identifier for the assumed role session. Use the role session name to uniquely identify a session when the same role is assumed by different principals or for different reasons. Required Field in case of Assume Role",
"type": "string",
"default": "OpenMetadataSession"
},
"assumeRoleSourceIdentity": {
"title": "Source Identity for Assume Role",
"description": "The Amazon Resource Name (ARN) of the role to assume. Optional Field in case of Assume Role",
"type": "string"
} }
}, },
"additionalProperties": false, "additionalProperties": false,