datahub/metadata-ingestion/tests/unit/test_postgres_source.py
Andrew Sikowitz 44406f7adf
fix(ingest/postgres): Allow specification of initial engine database; set default database to postgres (#7915)
Co-authored-by: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com>
2023-05-09 11:11:43 -07:00

106 lines
4.0 KiB
Python

from unittest import mock
from unittest.mock import patch
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
def _base_config():
return {"username": "user", "password": "password", "host_port": "host:1521"}
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_initial_database(create_engine_mock):
config = PostgresConfig.parse_obj(_base_config())
assert config.initial_database == "postgres"
source = PostgresSource(config, PipelineContext(run_id="test"))
_ = list(source.get_inspectors())
assert create_engine_mock.call_count == 1
assert create_engine_mock.call_args[0][0].endswith("postgres")
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_get_inspectors_multiple_databases(create_engine_mock):
execute_mock = (
create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
)
execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"})
source = PostgresSource(config, PipelineContext(run_id="test"))
_ = list(source.get_inspectors())
assert create_engine_mock.call_count == 3
assert create_engine_mock.call_args_list[0][0][0].endswith("db0")
assert create_engine_mock.call_args_list[1][0][0].endswith("db1")
assert create_engine_mock.call_args_list[2][0][0].endswith("db2")
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_database_provided(create_engine_mock):
execute_mock = (
create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
)
execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"})
source = PostgresSource(config, PipelineContext(run_id="test"))
_ = list(source.get_inspectors())
assert create_engine_mock.call_count == 1
assert create_engine_mock.call_args_list[0][0][0].endswith("custom_db")
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
execute_mock = (
create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
)
execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]
config = PostgresConfig.parse_obj(
{**_base_config(), "sqlalchemy_uri": "custom_url"}
)
source = PostgresSource(config, PipelineContext(run_id="test"))
_ = list(source.get_inspectors())
assert create_engine_mock.call_count == 1
assert create_engine_mock.call_args_list[0][0][0] == "custom_url"
def test_database_alias_takes_precedence():
    """`database_alias` overrides `database` in the generated identifier.

    NOTE: renamed from ``test_database_alias_takes_precendence`` to fix the
    spelling of "precedence"; behavior is unchanged.
    """
    config = PostgresConfig.parse_obj(
        {
            **_base_config(),
            "database_alias": "ops_database",
            "database": "postgres",
        }
    )
    mock_inspector = mock.MagicMock()

    # The alias, not the real database name, should lead the identifier.
    assert (
        PostgresSource(config, PipelineContext(run_id="test")).get_identifier(
            schema="superset", entity="logs", inspector=mock_inspector
        )
        == "ops_database.superset.logs"
    )
def test_database_in_identifier():
    """An explicitly configured `database` becomes the identifier's first part."""
    config = PostgresConfig.parse_obj({**_base_config(), "database": "postgres"})
    inspector = mock.MagicMock()

    source = PostgresSource(config, PipelineContext(run_id="test"))
    identifier = source.get_identifier(
        schema="superset", entity="logs", inspector=inspector
    )
    assert identifier == "postgres.superset.logs"
def test_current_sqlalchemy_database_in_identifier():
    """With no configured database, the identifier falls back to the database
    of the inspector's current engine URL."""
    config = PostgresConfig.parse_obj(_base_config())

    # The inspector reports which database its engine is connected to.
    inspector = mock.MagicMock()
    inspector.engine.url.database = "current_db"

    source = PostgresSource(config, PipelineContext(run_id="test"))
    identifier = source.get_identifier(
        schema="superset", entity="logs", inspector=inspector
    )
    assert identifier == "current_db.superset.logs"