diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 149bc21661..166c15249f 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -333,7 +333,7 @@ mypy_stubs = { "types-cachetools", # versions 0.1.13 and 0.1.14 seem to have issues "types-click==0.1.12", - "boto3-stubs[s3,glue,sagemaker]", + "boto3-stubs[s3,glue,sagemaker,sts]", "types-tabulate", # avrogen package requires this "types-pytz", diff --git a/metadata-ingestion/src/datahub/configuration/common.py b/metadata-ingestion/src/datahub/configuration/common.py index bb070fc0d3..6ce2a8e24f 100644 --- a/metadata-ingestion/src/datahub/configuration/common.py +++ b/metadata-ingestion/src/datahub/configuration/common.py @@ -19,6 +19,16 @@ class ConfigModel(BaseModel): ) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info. +class PermissiveConfigModel(ConfigModel): + # A permissive config model that allows extra fields. + # This is useful for cases where we want to strongly type certain fields, + # but still allow the user to pass in arbitrary fields that we don't care about. + # It is usually used for argument bags that are passed through to third-party libraries. + + class Config: + extra = Extra.allow + + class TransformerSemantics(ConfigEnum): """Describes semantics for aspect changes""" diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index d4c123ede1..5d4659633c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -1,4 +1,3 @@ -from functools import reduce from typing import TYPE_CHECKING, Dict, List, Optional, Union import boto3 @@ -7,21 +6,39 @@ from botocore.config import Config from botocore.utils import fix_s3_host from pydantic.fields import Field -from datahub.configuration.common import AllowDenyPattern, ConfigModel +from datahub.configuration.common import ( + AllowDenyPattern, + ConfigModel, + PermissiveConfigModel, +) from datahub.configuration.source_common import EnvBasedSourceConfigBase if TYPE_CHECKING: - from mypy_boto3_glue import GlueClient from mypy_boto3_s3 import S3Client, S3ServiceResource from mypy_boto3_sagemaker import SageMakerClient + from mypy_boto3_sts import STSClient + + +class AwsAssumeRoleConfig(PermissiveConfigModel): + # Using the PermissiveConfigModel to allow the user to pass additional arguments. + + RoleArn: str = Field( + description="ARN of the role to assume.", + ) + ExternalId: Optional[str] = Field( + None, + description="External ID to use when assuming the role.", + ) def assume_role( - role_arn: str, aws_region: str, credentials: Optional[dict] = None + role: AwsAssumeRoleConfig, + aws_region: str, + credentials: Optional[dict] = None, ) -> dict: credentials = credentials or {} - sts_client = boto3.client( + sts_client: "STSClient" = boto3.client( "sts", region_name=aws_region, aws_access_key_id=credentials.get("AccessKeyId"), @@ -29,10 +46,20 @@ def assume_role( aws_session_token=credentials.get("SessionToken"), ) + assume_role_args: dict = { + **dict( + RoleSessionName="DatahubIngestionSource", + ), + **{k: v for k, v in role.dict().items() if v is not None}, + } + assumed_role_object = sts_client.assume_role( - RoleArn=role_arn, RoleSessionName="DatahubIngestionSource" + **assume_role_args, ) - return assumed_role_object["Credentials"] + return dict(assumed_role_object["Credentials"]) + + +AUTODETECT_CREDENTIALS_DOC_LINK = "Can be auto-detected, see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for details." class AwsConnectionConfig(ConfigModel): @@ -47,25 +74,27 @@ class AwsConnectionConfig(ConfigModel): aws_access_key_id: Optional[str] = Field( default=None, - description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html", + description=f"AWS access key ID. {AUTODETECT_CREDENTIALS_DOC_LINK}", ) aws_secret_access_key: Optional[str] = Field( default=None, - description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html", + description=f"AWS secret access key. {AUTODETECT_CREDENTIALS_DOC_LINK}", ) aws_session_token: Optional[str] = Field( default=None, - description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html", + description=f"AWS session token. {AUTODETECT_CREDENTIALS_DOC_LINK}", ) - aws_role: Optional[Union[str, List[str]]] = Field( + aws_role: Optional[Union[str, List[Union[str, AwsAssumeRoleConfig]]]] = Field( default=None, - description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html", + description="AWS roles to assume. If using the string format, the role ARN can be specified directly. " + "If using the object format, the role can be specified in the RoleArn field and additional available arguments are documented at https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts.html?highlight=assume_role#STS.Client.assume_role", ) aws_profile: Optional[str] = Field( default=None, - description="Named AWS profile to use, if not set the default will be used", + description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used", ) aws_region: str = Field(description="AWS region code.") + aws_endpoint_url: Optional[str] = Field( default=None, description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html", @@ -75,43 +104,58 @@ class AwsConnectionConfig(ConfigModel): description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html", ) + def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]: + if not self.aws_role: + return [] + elif isinstance(self.aws_role, str): + return [AwsAssumeRoleConfig(RoleArn=self.aws_role)] + else: + assert isinstance(self.aws_role, list) + return [ + AwsAssumeRoleConfig(RoleArn=role) if isinstance(role, str) else role + for role in self.aws_role + ] + def get_session(self) -> Session: - if ( - self.aws_access_key_id - and self.aws_secret_access_key - and self.aws_session_token - ): - return Session( + if self.aws_access_key_id and self.aws_secret_access_key: + session = Session( aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, aws_session_token=self.aws_session_token, region_name=self.aws_region, ) - elif self.aws_access_key_id and self.aws_secret_access_key: - return Session( - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - region_name=self.aws_region, + elif self.aws_profile: + session = Session( + region_name=self.aws_region, profile_name=self.aws_profile ) - elif self.aws_role: - if isinstance(self.aws_role, str): - credentials = assume_role(self.aws_role, self.aws_region) - else: - credentials = reduce( - lambda new_credentials, role_arn: assume_role( - role_arn, self.aws_region, new_credentials - ), - self.aws_role, - {}, + else: + # Use boto3's credential autodetection. + session = Session(region_name=self.aws_region) + + if self._normalized_aws_roles(): + # Use existing session credentials to start the chain of role assumption. + current_credentials = session.get_credentials() + credentials = { + "AccessKeyId": current_credentials.access_key, + "SecretAccessKey": current_credentials.secret_key, + "SessionToken": current_credentials.token, + } + + for role in self._normalized_aws_roles(): + credentials = assume_role( + role, + self.aws_region, + credentials=credentials, ) - return Session( + + session = Session( aws_access_key_id=credentials["AccessKeyId"], aws_secret_access_key=credentials["SecretAccessKey"], aws_session_token=credentials["SessionToken"], region_name=self.aws_region, ) - else: - return Session(region_name=self.aws_region, profile_name=self.aws_profile) + + return session def get_credentials(self) -> Dict[str, str]: credentials = self.get_session().get_credentials()