mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-07-23 09:22:18 +00:00
parent
d41878ec90
commit
c199f13ed0
@ -12,11 +12,14 @@
|
||||
Module containing AWS Client
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
from boto3 import Session
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metadata.generated.schema.security.credentials.awsCredentials import AWSCredentials
|
||||
from metadata.ingestion.models.custom_pydantic import CustomSecretStr
|
||||
from metadata.utils.logger import utils_logger
|
||||
|
||||
logger = utils_logger()
|
||||
@ -30,16 +33,25 @@ class AWSServices(Enum):
|
||||
QUICKSIGHT = "quicksight"
|
||||
|
||||
|
||||
class AWSAssumeRoleException(Exception):
|
||||
"""
|
||||
Exception class to handle assume role related issues
|
||||
"""
|
||||
|
||||
|
||||
class AWSAssumeRoleCredentialWrapper(BaseModel):
|
||||
|
||||
accessKeyId: str
|
||||
secretAccessKey: CustomSecretStr
|
||||
sessionToken: Optional[str]
|
||||
|
||||
|
||||
class AWSClient:
|
||||
"""
|
||||
AWSClient creates a boto3 Session client based on AWSCredentials.
|
||||
"""
|
||||
|
||||
def __init__(self, config: "AWSCredentials"):
|
||||
# local import to avoid the creation of circular dependencies with CustomSecretStr
|
||||
from metadata.generated.schema.security.credentials.awsCredentials import ( # pylint: disable=import-outside-toplevel
|
||||
AWSCredentials,
|
||||
)
|
||||
|
||||
self.config = (
|
||||
config
|
||||
@ -47,33 +59,86 @@ class AWSClient:
|
||||
else (AWSCredentials.parse_obj(config) if config else config)
|
||||
)
|
||||
|
||||
def _get_session(self) -> Session:
|
||||
if (
|
||||
self.config.awsAccessKeyId
|
||||
and self.config.awsSecretAccessKey
|
||||
and self.config.awsSessionToken
|
||||
):
|
||||
return Session(
|
||||
aws_access_key_id=self.config.awsAccessKeyId,
|
||||
aws_secret_access_key=self.config.awsSecretAccessKey.get_secret_value(),
|
||||
aws_session_token=self.config.awsSessionToken,
|
||||
region_name=self.config.awsRegion,
|
||||
@staticmethod
|
||||
def get_assume_role_config(
|
||||
config: AWSCredentials,
|
||||
) -> Optional[AWSAssumeRoleCredentialWrapper]:
|
||||
"""
|
||||
Get temporary credentials from assumed role
|
||||
"""
|
||||
session = AWSClient._get_session(
|
||||
config.awsAccessKeyId,
|
||||
config.awsSecretAccessKey,
|
||||
config.awsSessionToken,
|
||||
config.awsRegion,
|
||||
config.profileName,
|
||||
)
|
||||
if self.config.awsAccessKeyId and self.config.awsSecretAccessKey:
|
||||
return Session(
|
||||
aws_access_key_id=self.config.awsAccessKeyId,
|
||||
aws_secret_access_key=self.config.awsSecretAccessKey.get_secret_value(),
|
||||
region_name=self.config.awsRegion,
|
||||
sts_client = session.client("sts")
|
||||
resp = None
|
||||
if config.assumeRoleSourceIdentity:
|
||||
resp = sts_client.assume_role(
|
||||
RoleArn=config.assumeRoleArn,
|
||||
RoleSessionName=config.assumeRoleSessionName,
|
||||
SourceIdentity=config.assumeRoleSourceIdentity,
|
||||
)
|
||||
else:
|
||||
resp = sts_client.assume_role(
|
||||
RoleArn=config.assumeRoleArn,
|
||||
RoleSessionName=config.assumeRoleSessionName,
|
||||
)
|
||||
|
||||
if resp:
|
||||
credentials = resp.get("Credentials", {})
|
||||
return AWSAssumeRoleCredentialWrapper(
|
||||
accessKeyId=credentials.get("AccessKeyId"),
|
||||
secretAccessKey=credentials.get("SecretAccessKey"),
|
||||
sessionToken=credentials.get("SessionToken"),
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_session(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_session_token,
|
||||
aws_region,
|
||||
profile=None,
|
||||
) -> Session:
|
||||
return Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key.get_secret_value()
|
||||
if aws_secret_access_key
|
||||
else None,
|
||||
aws_session_token=aws_session_token,
|
||||
region_name=aws_region,
|
||||
profile_name=profile,
|
||||
)
|
||||
|
||||
def create_session(self) -> Session:
|
||||
if self.config.assumeRoleArn:
|
||||
assume_creds = AWSClient.get_assume_role_config(self.config)
|
||||
if assume_creds:
|
||||
return AWSClient._get_session(
|
||||
assume_creds.accessKeyId,
|
||||
assume_creds.secretAccessKey,
|
||||
assume_creds.sessionToken,
|
||||
self.config.awsRegion,
|
||||
self.config.profileName,
|
||||
)
|
||||
|
||||
return AWSClient._get_session(
|
||||
self.config.awsAccessKeyId,
|
||||
self.config.awsSecretAccessKey,
|
||||
self.config.awsSessionToken,
|
||||
self.config.awsRegion,
|
||||
self.config.profileName,
|
||||
)
|
||||
if self.config.awsRegion:
|
||||
return Session(region_name=self.config.awsRegion)
|
||||
return Session()
|
||||
|
||||
def get_client(self, service_name: str) -> Any:
|
||||
# initialize the client depending on the AWSCredentials passed
|
||||
if self.config is not None:
|
||||
logger.info(f"Getting AWS client for service [{service_name}]")
|
||||
session = self._get_session()
|
||||
session = self.create_session()
|
||||
if self.config.endPointURL is not None:
|
||||
return session.client(
|
||||
service_name=service_name, endpoint_url=self.config.endPointURL
|
||||
@ -85,7 +150,7 @@ class AWSClient:
|
||||
return boto3.client(service_name=service_name)
|
||||
|
||||
def get_resource(self, service_name: str) -> Any:
|
||||
session = self._get_session()
|
||||
session = self.create_session()
|
||||
if self.config.endPointURL is not None:
|
||||
return session.resource(
|
||||
service_name=service_name, endpoint_url=self.config.endPointURL
|
||||
|
@ -20,7 +20,6 @@ from pydantic.utils import update_not_none
|
||||
from pydantic.validators import constr_length_validator, str_validator
|
||||
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
from metadata.utils.secrets.secrets_manager_factory import SecretsManagerFactory
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
@ -76,6 +75,11 @@ class CustomSecretStr(SecretStr):
|
||||
return str(self)
|
||||
|
||||
def get_secret_value(self, skip_secret_manager: bool = False) -> str:
|
||||
# Importing inside function to avoid circular import error
|
||||
from metadata.utils.secrets.secrets_manager_factory import ( # pylint: disable=import-outside-toplevel,cyclic-import
|
||||
SecretsManagerFactory,
|
||||
)
|
||||
|
||||
if (
|
||||
not skip_secret_manager
|
||||
and self._secret_value.startswith("secret:")
|
||||
|
@ -17,6 +17,7 @@ from urllib.parse import quote_plus
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from metadata.clients.aws_client import AWSClient
|
||||
from metadata.generated.schema.entity.services.connections.database.athenaConnection import (
|
||||
AthenaConnection,
|
||||
)
|
||||
@ -32,11 +33,24 @@ from metadata.ingestion.connections.test_connections import (
|
||||
|
||||
|
||||
def get_connection_url(connection: AthenaConnection) -> str:
|
||||
"""
|
||||
Method to get connection url
|
||||
"""
|
||||
aws_access_key_id = connection.awsConfig.awsAccessKeyId
|
||||
aws_secret_access_key = connection.awsConfig.awsSecretAccessKey
|
||||
aws_session_token = connection.awsConfig.awsSessionToken
|
||||
if connection.awsConfig.assumeRoleArn:
|
||||
assume_configs = AWSClient.get_assume_role_config(connection.awsConfig)
|
||||
if assume_configs:
|
||||
aws_access_key_id = assume_configs.accessKeyId
|
||||
aws_secret_access_key = assume_configs.secretAccessKey
|
||||
aws_session_token = assume_configs.sessionToken
|
||||
|
||||
url = f"{connection.scheme.value}://"
|
||||
if connection.awsConfig.awsAccessKeyId:
|
||||
url += connection.awsConfig.awsAccessKeyId
|
||||
if connection.awsConfig.awsSecretAccessKey:
|
||||
url += f":{connection.awsConfig.awsSecretAccessKey.get_secret_value()}"
|
||||
if aws_access_key_id:
|
||||
url += aws_access_key_id
|
||||
if aws_secret_access_key:
|
||||
url += f":{aws_secret_access_key.get_secret_value()}"
|
||||
else:
|
||||
url += ":"
|
||||
url += f"@athena.{connection.awsConfig.awsRegion}.amazonaws.com:443"
|
||||
@ -44,8 +58,8 @@ def get_connection_url(connection: AthenaConnection) -> str:
|
||||
url += f"?s3_staging_dir={quote_plus(connection.s3StagingDir)}"
|
||||
if connection.workgroup:
|
||||
url += f"&work_group={connection.workgroup}"
|
||||
if connection.awsConfig.awsSessionToken:
|
||||
url += f"&aws_session_token={quote_plus(connection.awsConfig.awsSessionToken)}"
|
||||
if aws_session_token:
|
||||
url += f"&aws_session_token={quote_plus(aws_session_token)}"
|
||||
|
||||
return url
|
||||
|
||||
|
@ -32,6 +32,27 @@
|
||||
"description": "EndPoint URL for the AWS",
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
},
|
||||
"profileName": {
|
||||
"title": "Profile Name",
|
||||
"description": "The name of a profile to use with the boto session.",
|
||||
"type": "string"
|
||||
},
|
||||
"assumeRoleArn": {
|
||||
"title": "Role Arn for Assume Role",
|
||||
"description": "The Amazon Resource Name (ARN) of the role to assume. Required Field in case of Assume Role",
|
||||
"type": "string"
|
||||
},
|
||||
"assumeRoleSessionName": {
|
||||
"title": "Role Session Name for Assume Role",
|
||||
"description": "An identifier for the assumed role session. Use the role session name to uniquely identify a session when the same role is assumed by different principals or for different reasons. Required Field in case of Assume Role",
|
||||
"type": "string",
|
||||
"default": "OpenMetadataSession"
|
||||
},
|
||||
"assumeRoleSourceIdentity": {
|
||||
"title": "Source Identity for Assume Role",
|
||||
"description": "The Amazon Resource Name (ARN) of the role to assume. Optional Field in case of Assume Role",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
Loading…
x
Reference in New Issue
Block a user