Add support for AWS IAM authentication for RDS PostgreSQL

This commit is contained in:
Brock Griffey 2025-04-28 21:32:50 -04:00
parent f986315582
commit 21b22afa2b

View File

@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
# This import verifies that the dependencies are available.
import psycopg2 # noqa: F401
import sqlalchemy.dialects.postgresql as custom_types
from sqlalchemy.engine.url import make_url
import boto3
# GeoAlchemy adds support for PostGIS extensions in SQLAlchemy. In order to
# activate it, we must import it so that it can hook into SQLAlchemy. While
@ -123,6 +126,14 @@ class PostgresConfig(BasePostgresConfig):
"Note: this is not used if `database` or `sqlalchemy_uri` are provided."
),
)
use_aws_iam_auth: bool = Field(
default=False,
description="Whether to use AWS IAM authentication for PostgreSQL. When enabled, username and password are not required.",
)
aws_region: Optional[str] = Field(
default=None,
description="AWS region where the PostgreSQL instance is located. Required when use_aws_iam_auth is True.",
)
@platform_name("Postgres")
@ -140,6 +151,7 @@ class PostgresSource(SQLAlchemySource):
- Column types associated with each table
- Also supports PostGIS extensions
- Table, row, and column statistics via optional SQL profiling
- AWS IAM authentication support for RDS PostgreSQL
"""
config: PostgresConfig
@ -150,10 +162,30 @@ class PostgresSource(SQLAlchemySource):
def get_platform(self):
return "postgres"
@classmethod
def create(cls, config_dict, ctx):
config = PostgresConfig.parse_obj(config_dict)
return cls(config, ctx)
def _get_aws_iam_token(self) -> str:
"""Get an AWS IAM authentication token for PostgreSQL."""
if not self.config.aws_region:
raise ValueError("aws_region is required when use_aws_iam_auth is True")
# Create RDS client using default credential provider chain
rds_client = boto3.client("rds", region_name=self.config.aws_region)
url_obj = make_url(self.config.get_sql_alchemy_url(
database=self.config.database or self.config.initial_database
))
host = url_obj.host
port = url_obj.port
user = url_obj.username
# Generate the authentication token
token = rds_client.generate_db_auth_token(
DBHostname=host,
Port=port,
DBUsername=user,
)
return token
def get_inspectors(self) -> Iterable[Inspector]:
# Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
@ -161,6 +193,14 @@ class PostgresSource(SQLAlchemySource):
database=self.config.database or self.config.initial_database
)
logger.debug(f"sql_alchemy_url={url}")
# If AWS IAM auth is enabled, get the token and use it as the password
if self.config.use_aws_iam_auth:
token = self._get_aws_iam_token()
url_obj = make_url(url)
url_obj = url_obj.set(password=token)
url = str(url_obj)
engine = create_engine(url, **self.config.options)
with engine.connect() as conn:
if self.config.database or self.config.sqlalchemy_uri:
@ -176,6 +216,11 @@ class PostgresSource(SQLAlchemySource):
if not self.config.database_pattern.allowed(db["datname"]):
continue
url = self.config.get_sql_alchemy_url(database=db["datname"])
if self.config.use_aws_iam_auth:
token = self._get_aws_iam_token()
url_obj = make_url(url)
url_obj = url_obj.set(password=token)
url = str(url_obj)
with create_engine(url, **self.config.options).connect() as conn:
inspector = inspect(conn)
yield inspector