Fix #10402: Add support for AssumeRole for AWS (#10417)

This commit is contained in:
Mayur Singal 2023-03-08 15:43:33 +05:30 committed by GitHub
parent d41878ec90
commit c199f13ed0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 33 deletions

View File

@ -12,11 +12,14 @@
Module containing AWS Client
"""
from enum import Enum
from typing import Any
from typing import Any, Optional
import boto3
from boto3 import Session
from pydantic import BaseModel
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
from metadata.ingestion.models.custom_pydantic import CustomSecretStr
from metadata.utils.logger import utils_logger
logger = utils_logger()
@ -30,16 +33,25 @@ class AWSServices(Enum):
QUICKSIGHT = "quicksight"
class AWSAssumeRoleException(Exception):
"""
Exception class to handle assume role related issues
"""
class AWSAssumeRoleCredentialWrapper(BaseModel):
accessKeyId: str
secretAccessKey: CustomSecretStr
sessionToken: Optional[str]
class AWSClient:
"""
AWSClient creates a boto3 Session client based on AWSCredentials.
"""
def __init__(self, config: "AWSCredentials"):
# local import to avoid the creation of circular dependencies with CustomSecretStr
from metadata.generated.schema.security.credentials.awsCredentials import ( # pylint: disable=import-outside-toplevel
AWSCredentials,
)
self.config = (
config
@ -47,33 +59,86 @@ class AWSClient:
else (AWSCredentials.parse_obj(config) if config else config)
)
def _get_session(self) -> Session:
if (
self.config.awsAccessKeyId
and self.config.awsSecretAccessKey
and self.config.awsSessionToken
):
return Session(
aws_access_key_id=self.config.awsAccessKeyId,
aws_secret_access_key=self.config.awsSecretAccessKey.get_secret_value(),
aws_session_token=self.config.awsSessionToken,
region_name=self.config.awsRegion,
@staticmethod
def get_assume_role_config(
config: AWSCredentials,
) -> Optional[AWSAssumeRoleCredentialWrapper]:
"""
Get temporary credentials from assumed role
"""
session = AWSClient._get_session(
config.awsAccessKeyId,
config.awsSecretAccessKey,
config.awsSessionToken,
config.awsRegion,
config.profileName,
)
if self.config.awsAccessKeyId and self.config.awsSecretAccessKey:
return Session(
aws_access_key_id=self.config.awsAccessKeyId,
aws_secret_access_key=self.config.awsSecretAccessKey.get_secret_value(),
region_name=self.config.awsRegion,
sts_client = session.client("sts")
resp = None
if config.assumeRoleSourceIdentity:
resp = sts_client.assume_role(
RoleArn=config.assumeRoleArn,
RoleSessionName=config.assumeRoleSessionName,
SourceIdentity=config.assumeRoleSourceIdentity,
)
else:
resp = sts_client.assume_role(
RoleArn=config.assumeRoleArn,
RoleSessionName=config.assumeRoleSessionName,
)
if resp:
credentials = resp.get("Credentials", {})
return AWSAssumeRoleCredentialWrapper(
accessKeyId=credentials.get("AccessKeyId"),
secretAccessKey=credentials.get("SecretAccessKey"),
sessionToken=credentials.get("SessionToken"),
)
return None
@staticmethod
def _get_session(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region,
profile=None,
) -> Session:
return Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key.get_secret_value()
if aws_secret_access_key
else None,
aws_session_token=aws_session_token,
region_name=aws_region,
profile_name=profile,
)
def create_session(self) -> Session:
if self.config.assumeRoleArn:
assume_creds = AWSClient.get_assume_role_config(self.config)
if assume_creds:
return AWSClient._get_session(
assume_creds.accessKeyId,
assume_creds.secretAccessKey,
assume_creds.sessionToken,
self.config.awsRegion,
self.config.profileName,
)
return AWSClient._get_session(
self.config.awsAccessKeyId,
self.config.awsSecretAccessKey,
self.config.awsSessionToken,
self.config.awsRegion,
self.config.profileName,
)
if self.config.awsRegion:
return Session(region_name=self.config.awsRegion)
return Session()
def get_client(self, service_name: str) -> Any:
# initialize the client depending on the AWSCredentials passed
if self.config is not None:
logger.info(f"Getting AWS client for service [{service_name}]")
session = self._get_session()
session = self.create_session()
if self.config.endPointURL is not None:
return session.client(
service_name=service_name, endpoint_url=self.config.endPointURL
@ -85,7 +150,7 @@ class AWSClient:
return boto3.client(service_name=service_name)
def get_resource(self, service_name: str) -> Any:
session = self._get_session()
session = self.create_session()
if self.config.endPointURL is not None:
return session.resource(
service_name=service_name, endpoint_url=self.config.endPointURL

View File

@ -20,7 +20,6 @@ from pydantic.utils import update_not_none
from pydantic.validators import constr_length_validator, str_validator
from metadata.utils.logger import ingestion_logger
from metadata.utils.secrets.secrets_manager_factory import SecretsManagerFactory
logger = ingestion_logger()
@ -76,6 +75,11 @@ class CustomSecretStr(SecretStr):
return str(self)
def get_secret_value(self, skip_secret_manager: bool = False) -> str:
# Importing inside function to avoid circular import error
from metadata.utils.secrets.secrets_manager_factory import ( # pylint: disable=import-outside-toplevel,cyclic-import
SecretsManagerFactory,
)
if (
not skip_secret_manager
and self._secret_value.startswith("secret:")

View File

@ -17,6 +17,7 @@ from urllib.parse import quote_plus
from sqlalchemy.engine import Engine
from sqlalchemy.inspection import inspect
from metadata.clients.aws_client import AWSClient
from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
AthenaConnection,
)
@ -32,11 +33,24 @@ from metadata.ingestion.connections.test_connections import (
def get_connection_url(connection: AthenaConnection) -> str:
"""
Method to get connection url
"""
aws_access_key_id = connection.awsConfig.awsAccessKeyId
aws_secret_access_key = connection.awsConfig.awsSecretAccessKey
aws_session_token = connection.awsConfig.awsSessionToken
if connection.awsConfig.assumeRoleArn:
assume_configs = AWSClient.get_assume_role_config(connection.awsConfig)
if assume_configs:
aws_access_key_id = assume_configs.accessKeyId
aws_secret_access_key = assume_configs.secretAccessKey
aws_session_token = assume_configs.sessionToken
url = f"{connection.scheme.value}://"
if connection.awsConfig.awsAccessKeyId:
url += connection.awsConfig.awsAccessKeyId
if connection.awsConfig.awsSecretAccessKey:
url += f":{connection.awsConfig.awsSecretAccessKey.get_secret_value()}"
if aws_access_key_id:
url += aws_access_key_id
if aws_secret_access_key:
url += f":{aws_secret_access_key.get_secret_value()}"
else:
url += ":"
url += f"@athena.{connection.awsConfig.awsRegion}.amazonaws.com:443"
@ -44,8 +58,8 @@ def get_connection_url(connection: AthenaConnection) -> str:
url += f"?s3_staging_dir={quote_plus(connection.s3StagingDir)}"
if connection.workgroup:
url += f"&work_group={connection.workgroup}"
if connection.awsConfig.awsSessionToken:
url += f"&aws_session_token={quote_plus(connection.awsConfig.awsSessionToken)}"
if aws_session_token:
url += f"&aws_session_token={quote_plus(aws_session_token)}"
return url

View File

@ -32,6 +32,27 @@
"description": "EndPoint URL for the AWS",
"type": "string",
"format": "uri"
},
"profileName": {
"title": "Profile Name",
"description": "The name of a profile to use with the boto session.",
"type": "string"
},
"assumeRoleArn": {
"title": "Role Arn for Assume Role",
"description": "The Amazon Resource Name (ARN) of the role to assume. Required Field in case of Assume Role",
"type": "string"
},
"assumeRoleSessionName": {
"title": "Role Session Name for Assume Role",
"description": "An identifier for the assumed role session. Use the role session name to uniquely identify a session when the same role is assumed by different principals or for different reasons. Required Field in case of Assume Role",
"type": "string",
"default": "OpenMetadataSession"
},
"assumeRoleSourceIdentity": {
"title": "Source Identity for Assume Role",
"description": "The Amazon Resource Name (ARN) of the role to assume. Optional Field in case of Assume Role",
"type": "string"
}
},
"additionalProperties": false,