mirror of
				https://github.com/datahub-project/datahub.git
				synced 2025-11-03 20:27:50 +00:00 
			
		
		
		
	Add support for AWS IAM authentication for RDS PostgreSQL
This commit is contained in:
		
							parent
							
								
									f986315582
								
							
						
					
					
						commit
						21b22afa2b
					
				@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 | 
			
		||||
# This import verifies that the dependencies are available.
 | 
			
		||||
import psycopg2  # noqa: F401
 | 
			
		||||
import sqlalchemy.dialects.postgresql as custom_types
 | 
			
		||||
from sqlalchemy.engine.url import make_url
 | 
			
		||||
 | 
			
		||||
import boto3
 | 
			
		||||
 | 
			
		||||
# GeoAlchemy adds support for PostGIS extensions in SQLAlchemy. In order to
 | 
			
		||||
# activate it, we must import it so that it can hook into SQLAlchemy. While
 | 
			
		||||
@ -123,6 +126,14 @@ class PostgresConfig(BasePostgresConfig):
 | 
			
		||||
            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    use_aws_iam_auth: bool = Field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        description="Whether to use AWS IAM authentication for PostgreSQL. When enabled, username and password are not required.",
 | 
			
		||||
    )
 | 
			
		||||
    aws_region: Optional[str] = Field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        description="AWS region where the PostgreSQL instance is located. Required when use_aws_iam_auth is True.",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@platform_name("Postgres")
 | 
			
		||||
@ -140,6 +151,7 @@ class PostgresSource(SQLAlchemySource):
 | 
			
		||||
    - Column types associated with each table
 | 
			
		||||
    - Also supports PostGIS extensions
 | 
			
		||||
    - Table, row, and column statistics via optional SQL profiling
 | 
			
		||||
    - AWS IAM authentication support for RDS PostgreSQL
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    config: PostgresConfig
 | 
			
		||||
@ -150,10 +162,30 @@ class PostgresSource(SQLAlchemySource):
 | 
			
		||||
    def get_platform(self):
 | 
			
		||||
        return "postgres"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def create(cls, config_dict, ctx):
 | 
			
		||||
        config = PostgresConfig.parse_obj(config_dict)
 | 
			
		||||
        return cls(config, ctx)
 | 
			
		||||
    def _get_aws_iam_token(self) -> str:
 | 
			
		||||
        """Get an AWS IAM authentication token for PostgreSQL."""
 | 
			
		||||
        if not self.config.aws_region:
 | 
			
		||||
            raise ValueError("aws_region is required when use_aws_iam_auth is True")
 | 
			
		||||
 | 
			
		||||
        # Create RDS client using default credential provider chain
 | 
			
		||||
        rds_client = boto3.client("rds", region_name=self.config.aws_region)
 | 
			
		||||
 | 
			
		||||
        url_obj = make_url(self.config.get_sql_alchemy_url(
 | 
			
		||||
            database=self.config.database or self.config.initial_database
 | 
			
		||||
        )) 
 | 
			
		||||
        
 | 
			
		||||
        host = url_obj.host
 | 
			
		||||
        port = url_obj.port
 | 
			
		||||
        user = url_obj.username
 | 
			
		||||
        
 | 
			
		||||
        # Generate the authentication token
 | 
			
		||||
        token = rds_client.generate_db_auth_token(
 | 
			
		||||
            DBHostname=host,
 | 
			
		||||
            Port=port,
 | 
			
		||||
            DBUsername=user,
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
        return token
 | 
			
		||||
 | 
			
		||||
    def get_inspectors(self) -> Iterable[Inspector]:
 | 
			
		||||
        # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
 | 
			
		||||
@ -161,6 +193,14 @@ class PostgresSource(SQLAlchemySource):
 | 
			
		||||
            database=self.config.database or self.config.initial_database
 | 
			
		||||
        )
 | 
			
		||||
        logger.debug(f"sql_alchemy_url={url}")
 | 
			
		||||
 | 
			
		||||
        # If AWS IAM auth is enabled, get the token and use it as the password
 | 
			
		||||
        if self.config.use_aws_iam_auth:
 | 
			
		||||
            token = self._get_aws_iam_token()
 | 
			
		||||
            url_obj = make_url(url)
 | 
			
		||||
            url_obj = url_obj.set(password=token)
 | 
			
		||||
            url = str(url_obj)
 | 
			
		||||
 | 
			
		||||
        engine = create_engine(url, **self.config.options)
 | 
			
		||||
        with engine.connect() as conn:
 | 
			
		||||
            if self.config.database or self.config.sqlalchemy_uri:
 | 
			
		||||
@ -176,6 +216,11 @@ class PostgresSource(SQLAlchemySource):
 | 
			
		||||
                    if not self.config.database_pattern.allowed(db["datname"]):
 | 
			
		||||
                        continue
 | 
			
		||||
                    url = self.config.get_sql_alchemy_url(database=db["datname"])
 | 
			
		||||
                    if self.config.use_aws_iam_auth:
 | 
			
		||||
                        token = self._get_aws_iam_token()
 | 
			
		||||
                        url_obj = make_url(url)
 | 
			
		||||
                        url_obj = url_obj.set(password=token)
 | 
			
		||||
                        url = str(url_obj)
 | 
			
		||||
                    with create_engine(url, **self.config.options).connect() as conn:
 | 
			
		||||
                        inspector = inspect(conn)
 | 
			
		||||
                        yield inspector
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user