mirror of
https://github.com/datahub-project/datahub.git
synced 2025-11-03 04:10:43 +00:00
fix(ingest/postgres): Allow specification of initial engine database; set default database to postgres (#7915)
Co-authored-by: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com>
This commit is contained in:
parent
94e7e51175
commit
44406f7adf
@ -8,7 +8,7 @@ from datahub.configuration import ConfigModel
|
||||
from datahub.configuration.pydantic_field_deprecation import pydantic_field_deprecated
|
||||
from datahub.configuration.source_common import DatasetLineageProviderConfigBase
|
||||
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
|
||||
from datahub.ingestion.source.sql.postgres import PostgresConfig
|
||||
from datahub.ingestion.source.sql.postgres import BasePostgresConfig
|
||||
from datahub.ingestion.source.state.stateful_ingestion_base import (
|
||||
StatefulLineageConfigMixin,
|
||||
StatefulProfilingConfigMixin,
|
||||
@ -60,7 +60,7 @@ class RedshiftUsageConfig(BaseUsageConfig, StatefulUsageConfigMixin):
|
||||
|
||||
|
||||
class RedshiftConfig(
|
||||
PostgresConfig,
|
||||
BasePostgresConfig,
|
||||
DatasetLineageProviderConfigBase,
|
||||
S3DatasetLineageProviderConfigBase,
|
||||
RedshiftUsageConfig,
|
||||
|
||||
@ -97,22 +97,34 @@ class ViewLineageEntry(BaseModel):
|
||||
dependent_schema: str
|
||||
|
||||
|
||||
class PostgresConfig(BasicSQLAlchemyConfig):
|
||||
# defaults
|
||||
class BasePostgresConfig(BasicSQLAlchemyConfig):
|
||||
scheme = Field(default="postgresql+psycopg2", description="database scheme")
|
||||
schema_pattern = Field(default=AllowDenyPattern(deny=["information_schema"]))
|
||||
|
||||
|
||||
class PostgresConfig(BasePostgresConfig):
|
||||
include_view_lineage = Field(
|
||||
default=False, description="Include table lineage for views"
|
||||
)
|
||||
|
||||
database_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="Regex patterns for databases to filter in ingestion.",
|
||||
description=(
|
||||
"Regex patterns for databases to filter in ingestion. "
|
||||
"Note: this is not used if `database` or `sqlalchemy_uri` are provided."
|
||||
),
|
||||
)
|
||||
database: Optional[str] = Field(
|
||||
default=None,
|
||||
description="database (catalog). If set to Null, all databases will be considered for ingestion.",
|
||||
)
|
||||
initial_database: Optional[str] = Field(
|
||||
default="postgres",
|
||||
description=(
|
||||
"Initial database used to query for the list of databases, when ingesting multiple databases. "
|
||||
"Note: this is not used if `database` or `sqlalchemy_uri` are provided."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@platform_name("Postgres")
|
||||
@ -144,13 +156,14 @@ class PostgresSource(SQLAlchemySource):
|
||||
return cls(config, ctx)
|
||||
|
||||
def get_inspectors(self) -> Iterable[Inspector]:
|
||||
# This method can be overridden in the case that you want to dynamically
|
||||
# run on multiple databases.
|
||||
url = self.config.get_sql_alchemy_url()
|
||||
# Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
|
||||
url = self.config.get_sql_alchemy_url(
|
||||
database=self.config.database or self.config.initial_database
|
||||
)
|
||||
logger.debug(f"sql_alchemy_url={url}")
|
||||
engine = create_engine(url, **self.config.options)
|
||||
with engine.connect() as conn:
|
||||
if self.config.database:
|
||||
if self.config.database or self.config.sqlalchemy_uri:
|
||||
inspector = inspect(conn)
|
||||
yield inspector
|
||||
else:
|
||||
@ -160,13 +173,12 @@ class PostgresSource(SQLAlchemySource):
|
||||
"SELECT datname from pg_database where datname not in ('template0', 'template1')"
|
||||
)
|
||||
for db in databases:
|
||||
if self.config.database_pattern.allowed(db["datname"]):
|
||||
url = self.config.get_sql_alchemy_url(database=db["datname"])
|
||||
with create_engine(
|
||||
url, **self.config.options
|
||||
).connect() as conn:
|
||||
inspector = inspect(conn)
|
||||
yield inspector
|
||||
if not self.config.database_pattern.allowed(db["datname"]):
|
||||
continue
|
||||
url = self.config.get_sql_alchemy_url(database=db["datname"])
|
||||
with create_engine(url, **self.config.options).connect() as conn:
|
||||
inspector = inspect(conn)
|
||||
yield inspector
|
||||
|
||||
def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]:
|
||||
yield from super().get_workunits()
|
||||
|
||||
@ -34,7 +34,7 @@ from datahub.ingestion.api.decorators import (
|
||||
from datahub.ingestion.api.workunit import MetadataWorkUnit
|
||||
from datahub.ingestion.source.aws.s3_util import strip_s3_prefix
|
||||
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
|
||||
from datahub.ingestion.source.sql.postgres import PostgresConfig
|
||||
from datahub.ingestion.source.sql.postgres import BasePostgresConfig
|
||||
from datahub.ingestion.source.sql.sql_common import (
|
||||
SQLAlchemySource,
|
||||
SQLSourceReport,
|
||||
@ -124,7 +124,7 @@ class DatasetS3LineageProviderConfigBase(ConfigModel):
|
||||
|
||||
|
||||
class RedshiftConfig(
|
||||
PostgresConfig,
|
||||
BasePostgresConfig,
|
||||
BaseTimeWindowConfig,
|
||||
DatasetLineageProviderConfigBase,
|
||||
DatasetS3LineageProviderConfigBase,
|
||||
|
||||
@ -125,7 +125,7 @@ class BasicSQLAlchemyConfig(SQLAlchemyConfig):
|
||||
self.username,
|
||||
self.password.get_secret_value() if self.password is not None else None,
|
||||
self.host_port,
|
||||
self.database or database,
|
||||
database or self.database,
|
||||
uri_opts=uri_opts,
|
||||
)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ source:
|
||||
type: postgres
|
||||
config:
|
||||
host_port: 'localhost:5432'
|
||||
initial_database: 'postgrestest'
|
||||
username: 'postgres'
|
||||
password: 'example'
|
||||
include_view_lineage: True
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from datahub.ingestion.api.common import PipelineContext
|
||||
from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
|
||||
@ -8,6 +9,62 @@ def _base_config():
|
||||
return {"username": "user", "password": "password", "host_port": "host:1521"}
|
||||
|
||||
|
||||
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_initial_database(create_engine_mock):
    """Check the default `initial_database` and that it drives the engine URL.

    With only the base connection settings supplied, the config must report
    "postgres" as its initial database, and draining `get_inspectors()` must
    create exactly one engine whose URL ends in that database name.
    """
    postgres_config = PostgresConfig.parse_obj(_base_config())
    assert postgres_config.initial_database == "postgres"

    postgres_source = PostgresSource(postgres_config, PipelineContext(run_id="test"))
    # get_inspectors() is a generator; consume it so the engine is actually built.
    list(postgres_source.get_inspectors())

    assert create_engine_mock.call_count == 1
    positional_args = create_engine_mock.call_args[0]
    engine_url = positional_args[0]
    assert engine_url.endswith("postgres")
|
||||
|
||||
|
||||
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_get_inspectors_multiple_databases(create_engine_mock):
    """Check that one engine is created per listed database, after the initial one.

    The mocked connection reports two databases (db1, db2). With
    `initial_database` set to "db0" and no explicit `database`, three engines
    are expected, in order: db0 (used for listing), then db1, then db2.
    """
    # The object bound by `with engine.connect() as conn` in the source.
    mocked_connection = (
        create_engine_mock.return_value.connect.return_value.__enter__.return_value
    )
    mocked_connection.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    raw_config = dict(_base_config(), initial_database="db0")
    postgres_source = PostgresSource(
        PostgresConfig.parse_obj(raw_config), PipelineContext(run_id="test")
    )
    # get_inspectors() is a generator; consume it so every engine is built.
    list(postgres_source.get_inspectors())

    assert create_engine_mock.call_count == 3
    recorded_calls = create_engine_mock.call_args_list
    assert recorded_calls[0][0][0].endswith("db0")
    assert recorded_calls[1][0][0].endswith("db1")
    assert recorded_calls[2][0][0].endswith("db2")
|
||||
|
||||
|
||||
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_database_provided(create_engine_mock):
    """Check that an explicit `database` yields a single engine for that database.

    The mocked connection is still primed with two database rows; if the
    source enumerated them, additional engines would be created and the
    call-count assertion below would fail.
    """
    # The object bound by `with engine.connect() as conn` in the source.
    mocked_connection = (
        create_engine_mock.return_value.connect.return_value.__enter__.return_value
    )
    mocked_connection.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    raw_config = dict(_base_config(), database="custom_db")
    postgres_source = PostgresSource(
        PostgresConfig.parse_obj(raw_config), PipelineContext(run_id="test")
    )
    # get_inspectors() is a generator; consume it so the engine is actually built.
    list(postgres_source.get_inspectors())

    assert create_engine_mock.call_count == 1
    first_call = create_engine_mock.call_args_list[0]
    assert first_call[0][0].endswith("custom_db")
|
||||
|
||||
|
||||
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
    """Check that an explicit `sqlalchemy_uri` is passed to the engine verbatim.

    The mocked connection is still primed with two database rows; if the
    source enumerated them, additional engines would be created and the
    call-count assertion below would fail.
    """
    # The object bound by `with engine.connect() as conn` in the source.
    mocked_connection = (
        create_engine_mock.return_value.connect.return_value.__enter__.return_value
    )
    mocked_connection.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    raw_config = dict(_base_config(), sqlalchemy_uri="custom_url")
    postgres_source = PostgresSource(
        PostgresConfig.parse_obj(raw_config), PipelineContext(run_id="test")
    )
    # get_inspectors() is a generator; consume it so the engine is actually built.
    list(postgres_source.get_inspectors())

    assert create_engine_mock.call_count == 1
    first_call = create_engine_mock.call_args_list[0]
    assert first_call[0][0] == "custom_url"
|
||||
|
||||
|
||||
def test_database_alias_takes_precendence():
|
||||
config = PostgresConfig.parse_obj(
|
||||
{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user