fix(ingest/postgres): Allow specification of initial engine database; set default database to postgres (#7915)

Co-authored-by: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com>
Author: Andrew Sikowitz, 2023-05-09 14:11:43 -04:00 (committed by GitHub)
Parent: 94e7e51175
Commit: 44406f7adf
6 changed files with 89 additions and 19 deletions
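
In short: when no single `database` is configured, the Postgres source now opens its first engine against a configurable `initial_database` (default: postgres) and uses that connection only to discover which databases to ingest. A minimal sketch of the new defaults, using `parse_obj` the same way this commit's tests do:

    from datahub.ingestion.source.sql.postgres import PostgresConfig

    config = PostgresConfig.parse_obj(
        {"username": "user", "password": "password", "host_port": "localhost:5432"}
    )
    assert config.initial_database == "postgres"  # new default set by this commit
    assert config.database is None  # unset, so all allowed databases are considered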


@@ -8,7 +8,7 @@ from datahub.configuration import ConfigModel
 from datahub.configuration.pydantic_field_deprecation import pydantic_field_deprecated
 from datahub.configuration.source_common import DatasetLineageProviderConfigBase
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
-from datahub.ingestion.source.sql.postgres import PostgresConfig
+from datahub.ingestion.source.sql.postgres import BasePostgresConfig
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulLineageConfigMixin,
     StatefulProfilingConfigMixin,
@@ -60,7 +60,7 @@ class RedshiftUsageConfig(BaseUsageConfig, StatefulUsageConfigMixin):
 class RedshiftConfig(
-    PostgresConfig,
+    BasePostgresConfig,
     DatasetLineageProviderConfigBase,
     S3DatasetLineageProviderConfigBase,
     RedshiftUsageConfig,
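
Both Redshift configs (this file and the legacy one further down) move from `PostgresConfig` to the new `BasePostgresConfig` base, so the Postgres-only fields introduced below (`include_view_lineage`, `database_pattern`, `database`, `initial_database`) no longer leak into Redshift. A quick sketch of the resulting split, assuming pydantic v1's `__fields__` mapping:

    from datahub.ingestion.source.sql.postgres import BasePostgresConfig, PostgresConfig

    assert "initial_database" in PostgresConfig.__fields__
    assert "initial_database" not in BasePostgresConfig.__fields__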


@@ -97,22 +97,34 @@ class ViewLineageEntry(BaseModel):
     dependent_schema: str
 
-class PostgresConfig(BasicSQLAlchemyConfig):
-    # defaults
+class BasePostgresConfig(BasicSQLAlchemyConfig):
     scheme = Field(default="postgresql+psycopg2", description="database scheme")
     schema_pattern = Field(default=AllowDenyPattern(deny=["information_schema"]))
 
+
+class PostgresConfig(BasePostgresConfig):
     include_view_lineage = Field(
         default=False, description="Include table lineage for views"
     )
     database_pattern: AllowDenyPattern = Field(
         default=AllowDenyPattern.allow_all(),
-        description="Regex patterns for databases to filter in ingestion.",
+        description=(
+            "Regex patterns for databases to filter in ingestion. "
+            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
+        ),
     )
     database: Optional[str] = Field(
         default=None,
         description="database (catalog). If set to Null, all databases will be considered for ingestion.",
     )
+    initial_database: Optional[str] = Field(
+        default="postgres",
+        description=(
+            "Initial database used to query for the list of databases, when ingesting multiple databases. "
+            "Note: this is not used if `database` or `sqlalchemy_uri` are provided."
+        ),
+    )
 
 
 @platform_name("Postgres")
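
With the fields above, multi-database ingestion separates two knobs: where the bootstrap connection goes (`initial_database`) and which databases are actually ingested (`database_pattern`). A sketch combining them; the `analytics_.*` pattern is illustrative only:

    config = PostgresConfig.parse_obj(
        {
            "username": "user",
            "password": "password",
            "host_port": "localhost:5432",
            "initial_database": "postgres",  # only used to list databases
            "database_pattern": {"allow": ["analytics_.*"]},  # filters ingestion
        }
    )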
@@ -144,13 +156,14 @@ class PostgresSource(SQLAlchemySource):
         return cls(config, ctx)
 
     def get_inspectors(self) -> Iterable[Inspector]:
-        # This method can be overridden in the case that you want to dynamically
-        # run on multiple databases.
-        url = self.config.get_sql_alchemy_url()
+        # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
+        url = self.config.get_sql_alchemy_url(
+            database=self.config.database or self.config.initial_database
+        )
         logger.debug(f"sql_alchemy_url={url}")
         engine = create_engine(url, **self.config.options)
         with engine.connect() as conn:
-            if self.config.database:
+            if self.config.database or self.config.sqlalchemy_uri:
                 inspector = inspect(conn)
                 yield inspector
             else:
@@ -160,13 +173,12 @@ class PostgresSource(SQLAlchemySource):
                     "SELECT datname from pg_database where datname not in ('template0', 'template1')"
                 )
                 for db in databases:
-                    if self.config.database_pattern.allowed(db["datname"]):
-                        url = self.config.get_sql_alchemy_url(database=db["datname"])
-                        with create_engine(
-                            url, **self.config.options
-                        ).connect() as conn:
-                            inspector = inspect(conn)
-                            yield inspector
+                    if not self.config.database_pattern.allowed(db["datname"]):
+                        continue
+                    url = self.config.get_sql_alchemy_url(database=db["datname"])
+                    with create_engine(url, **self.config.options).connect() as conn:
+                        inspector = inspect(conn)
+                        yield inspector
 
     def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]:
         yield from super().get_workunits()
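
Distilled, the control flow of get_inspectors now comes down to two decisions. This is an illustrative sketch, not the actual implementation (and per the note in the diff, a configured `sqlalchemy_uri` wins inside get_sql_alchemy_url itself):

    def initial_engine_database(config):
        # What the first engine connects to when no full URI is given.
        return config.database or config.initial_database

    def enumerates_databases(config):
        # A single inspector is yielded when the target is already pinned down;
        # otherwise pg_database is queried and each allowed datname is visited.
        return not (config.database or config.sqlalchemy_uri)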


@@ -34,7 +34,7 @@ from datahub.ingestion.api.decorators import (
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.aws.s3_util import strip_s3_prefix
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
-from datahub.ingestion.source.sql.postgres import PostgresConfig
+from datahub.ingestion.source.sql.postgres import BasePostgresConfig
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemySource,
     SQLSourceReport,
@@ -124,7 +124,7 @@ class DatasetS3LineageProviderConfigBase(ConfigModel):
 class RedshiftConfig(
-    PostgresConfig,
+    BasePostgresConfig,
     BaseTimeWindowConfig,
     DatasetLineageProviderConfigBase,
     DatasetS3LineageProviderConfigBase,


@@ -125,7 +125,7 @@ class BasicSQLAlchemyConfig(SQLAlchemyConfig):
             self.username,
             self.password.get_secret_value() if self.password is not None else None,
             self.host_port,
-            self.database or database,
+            database or self.database,
             uri_opts=uri_opts,
         )
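
This flips an argument-precedence rule in BasicSQLAlchemyConfig.get_sql_alchemy_url: a `database` passed by the caller now beats the configured field rather than the other way around. The observable difference, sketched for a config whose `database` field is set to "main" (an illustrative value):

    url = config.get_sql_alchemy_url(database="other_db")
    # before this commit: url ends with "main"      (the config field won)
    # after this commit:  url ends with "other_db"  (the explicit argument wins)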


@@ -4,6 +4,7 @@ source:
   type: postgres
   config:
     host_port: 'localhost:5432'
+    initial_database: 'postgrestest'
     username: 'postgres'
     password: 'example'
     include_view_lineage: True
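
Because this recipe leaves `database` unset, the source connects to postgrestest only to enumerate databases, then ingests every database allowed by `database_pattern` (all of them here). Setting `database`, or supplying a full `sqlalchemy_uri`, would make `initial_database` irrelevant, per the field descriptions above.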


@@ -1,4 +1,5 @@
 from unittest import mock
+from unittest.mock import patch
 
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
@@ -8,6 +9,62 @@ def _base_config():
     return {"username": "user", "password": "password", "host_port": "host:1521"}
 
 
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def test_initial_database(create_engine_mock):
+    config = PostgresConfig.parse_obj(_base_config())
+    assert config.initial_database == "postgres"
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+    assert create_engine_mock.call_count == 1
+    assert create_engine_mock.call_args[0][0].endswith("postgres")
+
+
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def test_get_inspectors_multiple_databases(create_engine_mock):
+    execute_mock = (
+        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
+    )
+    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
+
+    config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"})
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+
+    assert create_engine_mock.call_count == 3
+    assert create_engine_mock.call_args_list[0][0][0].endswith("db0")
+    assert create_engine_mock.call_args_list[1][0][0].endswith("db1")
+    assert create_engine_mock.call_args_list[2][0][0].endswith("db2")
+
+
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def tests_get_inspectors_with_database_provided(create_engine_mock):
+    execute_mock = (
+        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
+    )
+    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
+
+    config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"})
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+
+    assert create_engine_mock.call_count == 1
+    assert create_engine_mock.call_args_list[0][0][0].endswith("custom_db")
+
+
+@patch("datahub.ingestion.source.sql.postgres.create_engine")
+def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
+    execute_mock = (
+        create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
+    )
+    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
+
+    config = PostgresConfig.parse_obj(
+        {**_base_config(), "sqlalchemy_uri": "custom_url"}
+    )
+    source = PostgresSource(config, PipelineContext(run_id="test"))
+    _ = list(source.get_inspectors())
+
+    assert create_engine_mock.call_count == 1
+    assert create_engine_mock.call_args_list[0][0][0] == "custom_url"
+
+
 def test_database_alias_takes_precendence():
     config = PostgresConfig.parse_obj(
         {