mirror of
				https://github.com/datahub-project/datahub.git
				synced 2025-11-04 12:51:23 +00:00 
			
		
		
		
	fix(ingest/postgres): Allow specification of initial engine database; set default database to postgres (#7915)
Co-authored-by: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									94e7e51175
								
							
						
					
					
						commit
						44406f7adf
					
				@ -8,7 +8,7 @@ from datahub.configuration import ConfigModel
 | 
				
			|||||||
from datahub.configuration.pydantic_field_deprecation import pydantic_field_deprecated
 | 
					from datahub.configuration.pydantic_field_deprecation import pydantic_field_deprecated
 | 
				
			||||||
from datahub.configuration.source_common import DatasetLineageProviderConfigBase
 | 
					from datahub.configuration.source_common import DatasetLineageProviderConfigBase
 | 
				
			||||||
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
 | 
					from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
 | 
				
			||||||
from datahub.ingestion.source.sql.postgres import PostgresConfig
 | 
					from datahub.ingestion.source.sql.postgres import BasePostgresConfig
 | 
				
			||||||
from datahub.ingestion.source.state.stateful_ingestion_base import (
 | 
					from datahub.ingestion.source.state.stateful_ingestion_base import (
 | 
				
			||||||
    StatefulLineageConfigMixin,
 | 
					    StatefulLineageConfigMixin,
 | 
				
			||||||
    StatefulProfilingConfigMixin,
 | 
					    StatefulProfilingConfigMixin,
 | 
				
			||||||
@ -60,7 +60,7 @@ class RedshiftUsageConfig(BaseUsageConfig, StatefulUsageConfigMixin):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RedshiftConfig(
 | 
					class RedshiftConfig(
 | 
				
			||||||
    PostgresConfig,
 | 
					    BasePostgresConfig,
 | 
				
			||||||
    DatasetLineageProviderConfigBase,
 | 
					    DatasetLineageProviderConfigBase,
 | 
				
			||||||
    S3DatasetLineageProviderConfigBase,
 | 
					    S3DatasetLineageProviderConfigBase,
 | 
				
			||||||
    RedshiftUsageConfig,
 | 
					    RedshiftUsageConfig,
 | 
				
			||||||
 | 
				
			|||||||
@ -97,22 +97,34 @@ class ViewLineageEntry(BaseModel):
 | 
				
			|||||||
    dependent_schema: str
 | 
					    dependent_schema: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PostgresConfig(BasicSQLAlchemyConfig):
 | 
					class BasePostgresConfig(BasicSQLAlchemyConfig):
 | 
				
			||||||
    # defaults
 | 
					 | 
				
			||||||
    scheme = Field(default="postgresql+psycopg2", description="database scheme")
 | 
					    scheme = Field(default="postgresql+psycopg2", description="database scheme")
 | 
				
			||||||
    schema_pattern = Field(default=AllowDenyPattern(deny=["information_schema"]))
 | 
					    schema_pattern = Field(default=AllowDenyPattern(deny=["information_schema"]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PostgresConfig(BasePostgresConfig):
 | 
				
			||||||
    include_view_lineage = Field(
 | 
					    include_view_lineage = Field(
 | 
				
			||||||
        default=False, description="Include table lineage for views"
 | 
					        default=False, description="Include table lineage for views"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    database_pattern: AllowDenyPattern = Field(
 | 
					    database_pattern: AllowDenyPattern = Field(
 | 
				
			||||||
        default=AllowDenyPattern.allow_all(),
 | 
					        default=AllowDenyPattern.allow_all(),
 | 
				
			||||||
        description="Regex patterns for databases to filter in ingestion.",
 | 
					        description=(
 | 
				
			||||||
 | 
					            "Regex patterns for databases to filter in ingestion. "
 | 
				
			||||||
 | 
					            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    database: Optional[str] = Field(
 | 
					    database: Optional[str] = Field(
 | 
				
			||||||
        default=None,
 | 
					        default=None,
 | 
				
			||||||
        description="database (catalog). If set to Null, all databases will be considered for ingestion.",
 | 
					        description="database (catalog). If set to Null, all databases will be considered for ingestion.",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    initial_database: Optional[str] = Field(
 | 
				
			||||||
 | 
					        default="postgres",
 | 
				
			||||||
 | 
					        description=(
 | 
				
			||||||
 | 
					            "Initial database used to query for the list of databases, when ingesting multiple databases. "
 | 
				
			||||||
 | 
					            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@platform_name("Postgres")
 | 
					@platform_name("Postgres")
 | 
				
			||||||
@ -144,13 +156,14 @@ class PostgresSource(SQLAlchemySource):
 | 
				
			|||||||
        return cls(config, ctx)
 | 
					        return cls(config, ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_inspectors(self) -> Iterable[Inspector]:
 | 
					    def get_inspectors(self) -> Iterable[Inspector]:
 | 
				
			||||||
        # This method can be overridden in the case that you want to dynamically
 | 
					        # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
 | 
				
			||||||
        # run on multiple databases.
 | 
					        url = self.config.get_sql_alchemy_url(
 | 
				
			||||||
        url = self.config.get_sql_alchemy_url()
 | 
					            database=self.config.database or self.config.initial_database
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        logger.debug(f"sql_alchemy_url={url}")
 | 
					        logger.debug(f"sql_alchemy_url={url}")
 | 
				
			||||||
        engine = create_engine(url, **self.config.options)
 | 
					        engine = create_engine(url, **self.config.options)
 | 
				
			||||||
        with engine.connect() as conn:
 | 
					        with engine.connect() as conn:
 | 
				
			||||||
            if self.config.database:
 | 
					            if self.config.database or self.config.sqlalchemy_uri:
 | 
				
			||||||
                inspector = inspect(conn)
 | 
					                inspector = inspect(conn)
 | 
				
			||||||
                yield inspector
 | 
					                yield inspector
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
@ -160,13 +173,12 @@ class PostgresSource(SQLAlchemySource):
 | 
				
			|||||||
                    "SELECT datname from pg_database where datname not in ('template0', 'template1')"
 | 
					                    "SELECT datname from pg_database where datname not in ('template0', 'template1')"
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                for db in databases:
 | 
					                for db in databases:
 | 
				
			||||||
                    if self.config.database_pattern.allowed(db["datname"]):
 | 
					                    if not self.config.database_pattern.allowed(db["datname"]):
 | 
				
			||||||
                        url = self.config.get_sql_alchemy_url(database=db["datname"])
 | 
					                        continue
 | 
				
			||||||
                        with create_engine(
 | 
					                    url = self.config.get_sql_alchemy_url(database=db["datname"])
 | 
				
			||||||
                            url, **self.config.options
 | 
					                    with create_engine(url, **self.config.options).connect() as conn:
 | 
				
			||||||
                        ).connect() as conn:
 | 
					                        inspector = inspect(conn)
 | 
				
			||||||
                            inspector = inspect(conn)
 | 
					                        yield inspector
 | 
				
			||||||
                            yield inspector
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]:
 | 
					    def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]:
 | 
				
			||||||
        yield from super().get_workunits()
 | 
					        yield from super().get_workunits()
 | 
				
			||||||
 | 
				
			|||||||
@ -34,7 +34,7 @@ from datahub.ingestion.api.decorators import (
 | 
				
			|||||||
from datahub.ingestion.api.workunit import MetadataWorkUnit
 | 
					from datahub.ingestion.api.workunit import MetadataWorkUnit
 | 
				
			||||||
from datahub.ingestion.source.aws.s3_util import strip_s3_prefix
 | 
					from datahub.ingestion.source.aws.s3_util import strip_s3_prefix
 | 
				
			||||||
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
 | 
					from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
 | 
				
			||||||
from datahub.ingestion.source.sql.postgres import PostgresConfig
 | 
					from datahub.ingestion.source.sql.postgres import BasePostgresConfig
 | 
				
			||||||
from datahub.ingestion.source.sql.sql_common import (
 | 
					from datahub.ingestion.source.sql.sql_common import (
 | 
				
			||||||
    SQLAlchemySource,
 | 
					    SQLAlchemySource,
 | 
				
			||||||
    SQLSourceReport,
 | 
					    SQLSourceReport,
 | 
				
			||||||
@ -124,7 +124,7 @@ class DatasetS3LineageProviderConfigBase(ConfigModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RedshiftConfig(
 | 
					class RedshiftConfig(
 | 
				
			||||||
    PostgresConfig,
 | 
					    BasePostgresConfig,
 | 
				
			||||||
    BaseTimeWindowConfig,
 | 
					    BaseTimeWindowConfig,
 | 
				
			||||||
    DatasetLineageProviderConfigBase,
 | 
					    DatasetLineageProviderConfigBase,
 | 
				
			||||||
    DatasetS3LineageProviderConfigBase,
 | 
					    DatasetS3LineageProviderConfigBase,
 | 
				
			||||||
 | 
				
			|||||||
@ -125,7 +125,7 @@ class BasicSQLAlchemyConfig(SQLAlchemyConfig):
 | 
				
			|||||||
            self.username,
 | 
					            self.username,
 | 
				
			||||||
            self.password.get_secret_value() if self.password is not None else None,
 | 
					            self.password.get_secret_value() if self.password is not None else None,
 | 
				
			||||||
            self.host_port,
 | 
					            self.host_port,
 | 
				
			||||||
            self.database or database,
 | 
					            database or self.database,
 | 
				
			||||||
            uri_opts=uri_opts,
 | 
					            uri_opts=uri_opts,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -4,6 +4,7 @@ source:
 | 
				
			|||||||
  type: postgres
 | 
					  type: postgres
 | 
				
			||||||
  config:
 | 
					  config:
 | 
				
			||||||
    host_port: 'localhost:5432'
 | 
					    host_port: 'localhost:5432'
 | 
				
			||||||
 | 
					    initial_database: 'postgrestest'
 | 
				
			||||||
    username: 'postgres'
 | 
					    username: 'postgres'
 | 
				
			||||||
    password: 'example'
 | 
					    password: 'example'
 | 
				
			||||||
    include_view_lineage: True
 | 
					    include_view_lineage: True
 | 
				
			||||||
 | 
				
			|||||||
@ -1,4 +1,5 @@
 | 
				
			|||||||
from unittest import mock
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					from unittest.mock import patch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from datahub.ingestion.api.common import PipelineContext
 | 
					from datahub.ingestion.api.common import PipelineContext
 | 
				
			||||||
from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
 | 
					from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
 | 
				
			||||||
@ -8,6 +9,62 @@ def _base_config():
 | 
				
			|||||||
    return {"username": "user", "password": "password", "host_port": "host:1521"}
 | 
					    return {"username": "user", "password": "password", "host_port": "host:1521"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@patch("datahub.ingestion.source.sql.postgres.create_engine")
 | 
				
			||||||
 | 
					def test_initial_database(create_engine_mock):
 | 
				
			||||||
 | 
					    config = PostgresConfig.parse_obj(_base_config())
 | 
				
			||||||
 | 
					    assert config.initial_database == "postgres"
 | 
				
			||||||
 | 
					    source = PostgresSource(config, PipelineContext(run_id="test"))
 | 
				
			||||||
 | 
					    _ = list(source.get_inspectors())
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_count == 1
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_args[0][0].endswith("postgres")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@patch("datahub.ingestion.source.sql.postgres.create_engine")
 | 
				
			||||||
 | 
					def test_get_inspectors_multiple_databases(create_engine_mock):
 | 
				
			||||||
 | 
					    execute_mock = (
 | 
				
			||||||
 | 
					        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"})
 | 
				
			||||||
 | 
					    source = PostgresSource(config, PipelineContext(run_id="test"))
 | 
				
			||||||
 | 
					    _ = list(source.get_inspectors())
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_count == 3
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_args_list[0][0][0].endswith("db0")
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_args_list[1][0][0].endswith("db1")
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_args_list[2][0][0].endswith("db2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@patch("datahub.ingestion.source.sql.postgres.create_engine")
 | 
				
			||||||
 | 
					def tests_get_inspectors_with_database_provided(create_engine_mock):
 | 
				
			||||||
 | 
					    execute_mock = (
 | 
				
			||||||
 | 
					        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"})
 | 
				
			||||||
 | 
					    source = PostgresSource(config, PipelineContext(run_id="test"))
 | 
				
			||||||
 | 
					    _ = list(source.get_inspectors())
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_count == 1
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_args_list[0][0][0].endswith("custom_db")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@patch("datahub.ingestion.source.sql.postgres.create_engine")
 | 
				
			||||||
 | 
					def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
 | 
				
			||||||
 | 
					    execute_mock = (
 | 
				
			||||||
 | 
					        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    config = PostgresConfig.parse_obj(
 | 
				
			||||||
 | 
					        {**_base_config(), "sqlalchemy_uri": "custom_url"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    source = PostgresSource(config, PipelineContext(run_id="test"))
 | 
				
			||||||
 | 
					    _ = list(source.get_inspectors())
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_count == 1
 | 
				
			||||||
 | 
					    assert create_engine_mock.call_args_list[0][0][0] == "custom_url"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_database_alias_takes_precendence():
 | 
					def test_database_alias_takes_precendence():
 | 
				
			||||||
    config = PostgresConfig.parse_obj(
 | 
					    config = PostgresConfig.parse_obj(
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user