diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py
index 98eba17a3d..f9b9fb73de 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py
@@ -8,7 +8,7 @@ from datahub.configuration import ConfigModel
 from datahub.configuration.pydantic_field_deprecation import pydantic_field_deprecated
 from datahub.configuration.source_common import DatasetLineageProviderConfigBase
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
-from datahub.ingestion.source.sql.postgres import PostgresConfig
+from datahub.ingestion.source.sql.postgres import BasePostgresConfig
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulLineageConfigMixin,
     StatefulProfilingConfigMixin,
@@ -60,7 +60,7 @@ class RedshiftUsageConfig(BaseUsageConfig, StatefulUsageConfigMixin):


 class RedshiftConfig(
-    PostgresConfig,
+    BasePostgresConfig,
     DatasetLineageProviderConfigBase,
     S3DatasetLineageProviderConfigBase,
     RedshiftUsageConfig,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py
index a075d93016..c65adc838c 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py
@@ -97,22 +97,34 @@ class ViewLineageEntry(BaseModel):
     dependent_schema: str


-class PostgresConfig(BasicSQLAlchemyConfig):
-    # defaults
+class BasePostgresConfig(BasicSQLAlchemyConfig):
     scheme = Field(default="postgresql+psycopg2", description="database scheme")
     schema_pattern = Field(default=AllowDenyPattern(deny=["information_schema"]))
+
+
+class PostgresConfig(BasePostgresConfig):
     include_view_lineage = Field(
         default=False, description="Include table lineage for views"
     )
     database_pattern: AllowDenyPattern = Field(
         default=AllowDenyPattern.allow_all(),
-        description="Regex patterns for databases to filter in ingestion.",
+        description=(
+            "Regex patterns for databases to filter in ingestion. "
+            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
+        ),
     )
     database: Optional[str] = Field(
         default=None,
         description="database (catalog). If set to Null, all databases will be considered for ingestion.",
     )
+    initial_database: Optional[str] = Field(
+        default="postgres",
+        description=(
+            "Initial database used to query for the list of databases, when ingesting multiple databases. "
+            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
+        ),
+    )


 @platform_name("Postgres")
@@ -144,13 +156,14 @@ class PostgresSource(SQLAlchemySource):
         return cls(config, ctx)

     def get_inspectors(self) -> Iterable[Inspector]:
-        # This method can be overridden in the case that you want to dynamically
-        # run on multiple databases.
-        url = self.config.get_sql_alchemy_url()
+        # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
+        url = self.config.get_sql_alchemy_url(
+            database=self.config.database or self.config.initial_database
+        )
         logger.debug(f"sql_alchemy_url={url}")
         engine = create_engine(url, **self.config.options)
         with engine.connect() as conn:
-            if self.config.database:
+            if self.config.database or self.config.sqlalchemy_uri:
                 inspector = inspect(conn)
                 yield inspector
             else:
@@ -160,13 +173,12 @@ class PostgresSource(SQLAlchemySource):
                     "SELECT datname from pg_database where datname not in ('template0', 'template1')"
                 )
                 for db in databases:
-                    if self.config.database_pattern.allowed(db["datname"]):
-                        url = self.config.get_sql_alchemy_url(database=db["datname"])
-                        with create_engine(
-                            url, **self.config.options
-                        ).connect() as conn:
-                            inspector = inspect(conn)
-                            yield inspector
+                    if not self.config.database_pattern.allowed(db["datname"]):
+                        continue
+                    url = self.config.get_sql_alchemy_url(database=db["datname"])
+                    with create_engine(url, **self.config.options).connect() as conn:
+                        inspector = inspect(conn)
+                        yield inspector

     def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]:
         yield from super().get_workunits()
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/sql/redshift.py
index 9e3fe26795..21c4bb68fb 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/redshift.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/redshift.py
@@ -34,7 +34,7 @@ from datahub.ingestion.api.decorators import (
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.aws.s3_util import strip_s3_prefix
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
-from datahub.ingestion.source.sql.postgres import PostgresConfig
+from datahub.ingestion.source.sql.postgres import BasePostgresConfig
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemySource,
     SQLSourceReport,
@@ -124,7 +124,7 @@ class DatasetS3LineageProviderConfigBase(ConfigModel):


 class RedshiftConfig(
-    PostgresConfig,
+    BasePostgresConfig,
     BaseTimeWindowConfig,
     DatasetLineageProviderConfigBase,
     DatasetS3LineageProviderConfigBase,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py
index 10508ab0f1..bad59c1f71 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py
@@ -125,7 +125,7 @@ class BasicSQLAlchemyConfig(SQLAlchemyConfig):
             self.username,
             self.password.get_secret_value() if self.password is not None else None,
             self.host_port,
-            self.database or database,
+            database or self.database,
             uri_opts=uri_opts,
         )

diff --git a/metadata-ingestion/tests/integration/postgres/postgres_all_db_to_file_with_db_estimate_row_count.yml b/metadata-ingestion/tests/integration/postgres/postgres_all_db_to_file_with_db_estimate_row_count.yml
index 04ec6f01d3..69c20352f6 100644
--- a/metadata-ingestion/tests/integration/postgres/postgres_all_db_to_file_with_db_estimate_row_count.yml
+++ b/metadata-ingestion/tests/integration/postgres/postgres_all_db_to_file_with_db_estimate_row_count.yml
@@ -4,6 +4,7 @@ source:
   type: postgres
   config:
     host_port: 'localhost:5432'
+    initial_database: 'postgrestest'
     username: 'postgres'
     password: 'example'
     include_view_lineage: True
diff --git a/metadata-ingestion/tests/unit/test_postgres_source.py b/metadata-ingestion/tests/unit/test_postgres_source.py
index 3db8d27bbc..fac491cbae 100644
--- a/metadata-ingestion/tests/unit/test_postgres_source.py
+++ b/metadata-ingestion/tests/unit/test_postgres_source.py
@@ -1,4 +1,5 @@
 from unittest import mock
+from unittest.mock import patch

 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
@@ -8,6 +9,62 @@ def _base_config():
     return {"username": "user", "password": "password", "host_port": "host:1521"}


+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def test_initial_database(create_engine_mock):
+    config = PostgresConfig.parse_obj(_base_config())
+    assert config.initial_database == "postgres"
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+    assert create_engine_mock.call_count == 1
+    assert create_engine_mock.call_args[0][0].endswith("postgres")
+
+
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def test_get_inspectors_multiple_databases(create_engine_mock):
+    execute_mock = (
+        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
+    )
+    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
+
+    config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"})
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+    assert create_engine_mock.call_count == 3
+    assert create_engine_mock.call_args_list[0][0][0].endswith("db0")
+    assert create_engine_mock.call_args_list[1][0][0].endswith("db1")
+    assert create_engine_mock.call_args_list[2][0][0].endswith("db2")
+
+
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def test_get_inspectors_with_database_provided(create_engine_mock):
+    execute_mock = (
+        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
+    )
+    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
+
+    config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"})
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+    assert create_engine_mock.call_count == 1
+    assert create_engine_mock.call_args_list[0][0][0].endswith("custom_db")
+
+
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def test_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
+    execute_mock = (
+        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
+    )
+    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
+
+    config = PostgresConfig.parse_obj(
+        {**_base_config(), "sqlalchemy_uri": "custom_url"}
+    )
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+    assert create_engine_mock.call_count == 1
+    assert create_engine_mock.call_args_list[0][0][0] == "custom_url"
+
+
 def test_database_alias_takes_precendence():
     config = PostgresConfig.parse_obj(
         {
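
For reference, here is a minimal sketch of the database-selection behavior this patch introduces. The helper function below is hypothetical, written only to illustrate the precedence; the real logic lives in PostgresSource.get_inspectors and BasicSQLAlchemyConfig.get_sql_alchemy_url.

    from typing import Optional


    def pick_bootstrap_database(
        sqlalchemy_uri: Optional[str],
        database: Optional[str],
        initial_database: Optional[str] = "postgres",
    ) -> Optional[str]:
        """Choose the database for the first connection, mirroring the patch.

        - A user-supplied sqlalchemy_uri wins outright: get_sql_alchemy_url
          returns it verbatim, so no database override applies.
        - Otherwise an explicit `database` beats `initial_database`, per the
          `database or initial_database` expression in get_inspectors.
        """
        if sqlalchemy_uri is not None:
            return None  # the URI already encodes the target database
        return database or initial_database


    # Expected outcomes, matching the unit tests above:
    assert pick_bootstrap_database(None, None) == "postgres"
    assert pick_bootstrap_database(None, "custom_db") == "custom_db"
    assert pick_bootstrap_database("custom_url", "custom_db") is None

Note also the one-line flip in sql_config.py from `self.database or database` to `database or self.database`: an argument passed explicitly to get_sql_alchemy_url now takes precedence over the configured field, which is what lets the multi-database loop in get_inspectors connect to each datname it discovers rather than reconnecting to the configured database every time.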