feat(ingest): aws - support extra args to role config (#6031)

This commit is contained in:
Harshal Sheth 2022-09-22 17:43:32 -07:00 committed by GitHub
parent 27f28019de
commit 3a6d1d2bf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 38 deletions

View File

@ -333,7 +333,7 @@ mypy_stubs = {
"types-cachetools",
# versions 0.1.13 and 0.1.14 seem to have issues
"types-click==0.1.12",
"boto3-stubs[s3,glue,sagemaker]",
"boto3-stubs[s3,glue,sagemaker,sts]",
"types-tabulate",
# avrogen package requires this
"types-pytz",

View File

@ -19,6 +19,16 @@ class ConfigModel(BaseModel):
) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info.
class PermissiveConfigModel(ConfigModel):
# A permissive config model that allows extra fields.
# This is useful for cases where we want to strongly type certain fields,
# but still allow the user to pass in arbitrary fields that we don't care about.
# It is usually used for argument bags that are passed through to third-party libraries.
class Config:
extra = Extra.allow
class TransformerSemantics(ConfigEnum):
"""Describes semantics for aspect changes"""

View File

@ -1,4 +1,3 @@
from functools import reduce
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import boto3
@ -7,21 +6,39 @@ from botocore.config import Config
from botocore.utils import fix_s3_host
from pydantic.fields import Field
from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.common import (
AllowDenyPattern,
ConfigModel,
PermissiveConfigModel,
)
from datahub.configuration.source_common import EnvBasedSourceConfigBase
if TYPE_CHECKING:
from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client, S3ServiceResource
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sts import STSClient
class AwsAssumeRoleConfig(PermissiveConfigModel):
# Using the PermissiveConfigModel to allow the user to pass additional arguments.
RoleArn: str = Field(
description="ARN of the role to assume.",
)
ExternalId: Optional[str] = Field(
None,
description="External ID to use when assuming the role.",
)
def assume_role(
role_arn: str, aws_region: str, credentials: Optional[dict] = None
role: AwsAssumeRoleConfig,
aws_region: str,
credentials: Optional[dict] = None,
) -> dict:
credentials = credentials or {}
sts_client = boto3.client(
sts_client: "STSClient" = boto3.client(
"sts",
region_name=aws_region,
aws_access_key_id=credentials.get("AccessKeyId"),
@ -29,10 +46,20 @@ def assume_role(
aws_session_token=credentials.get("SessionToken"),
)
assume_role_args: dict = {
**dict(
RoleSessionName="DatahubIngestionSource",
),
**{k: v for k, v in role.dict().items() if v is not None},
}
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName="DatahubIngestionSource"
**assume_role_args,
)
return assumed_role_object["Credentials"]
return dict(assumed_role_object["Credentials"])
AUTODETECT_CREDENTIALS_DOC_LINK = "Can be auto-detected, see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for details."
class AwsConnectionConfig(ConfigModel):
@ -47,25 +74,27 @@ class AwsConnectionConfig(ConfigModel):
aws_access_key_id: Optional[str] = Field(
default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
description=f"AWS access key ID. {AUTODETECT_CREDENTIALS_DOC_LINK}",
)
aws_secret_access_key: Optional[str] = Field(
default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
description=f"AWS secret access key. {AUTODETECT_CREDENTIALS_DOC_LINK}",
)
aws_session_token: Optional[str] = Field(
default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
description=f"AWS session token. {AUTODETECT_CREDENTIALS_DOC_LINK}",
)
aws_role: Optional[Union[str, List[str]]] = Field(
aws_role: Optional[Union[str, List[Union[str, AwsAssumeRoleConfig]]]] = Field(
default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
description="AWS roles to assume. If using the string format, the role ARN can be specified directly. "
"If using the object format, the role can be specified in the RoleArn field and additional available arguments are documented at https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts.html?highlight=assume_role#STS.Client.assume_role",
)
aws_profile: Optional[str] = Field(
default=None,
description="Named AWS profile to use, if not set the default will be used",
description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used",
)
aws_region: str = Field(description="AWS region code.")
aws_endpoint_url: Optional[str] = Field(
default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html",
@ -75,43 +104,58 @@ class AwsConnectionConfig(ConfigModel):
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html",
)
def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
if not self.aws_role:
return []
elif isinstance(self.aws_role, str):
return [AwsAssumeRoleConfig(RoleArn=self.aws_role)]
else:
assert isinstance(self.aws_role, list)
return [
AwsAssumeRoleConfig(RoleArn=role) if isinstance(role, str) else role
for role in self.aws_role
]
def get_session(self) -> Session:
if (
self.aws_access_key_id
and self.aws_secret_access_key
and self.aws_session_token
):
return Session(
if self.aws_access_key_id and self.aws_secret_access_key:
session = Session(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region_name=self.aws_region,
)
elif self.aws_access_key_id and self.aws_secret_access_key:
return Session(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
region_name=self.aws_region,
elif self.aws_profile:
session = Session(
region_name=self.aws_region, profile_name=self.aws_profile
)
elif self.aws_role:
if isinstance(self.aws_role, str):
credentials = assume_role(self.aws_role, self.aws_region)
else:
credentials = reduce(
lambda new_credentials, role_arn: assume_role(
role_arn, self.aws_region, new_credentials
),
self.aws_role,
{},
else:
# Use boto3's credential autodetection.
session = Session(region_name=self.aws_region)
if self._normalized_aws_roles():
# Use existing session credentials to start the chain of role assumption.
current_credentials = session.get_credentials()
credentials = {
"AccessKeyId": current_credentials.access_key,
"SecretAccessKey": current_credentials.secret_key,
"SessionToken": current_credentials.token,
}
for role in self._normalized_aws_roles():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)
return Session(
session = Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=self.aws_region,
)
else:
return Session(region_name=self.aws_region, profile_name=self.aws_profile)
return session
def get_credentials(self) -> Dict[str, str]:
credentials = self.get_session().get_credentials()