# datahub/metadata-ingestion/tests/unit/test_postgres_source.py
#
# 83 lines
# 3.5 KiB
# Python

from unittest import mock
from unittest.mock import patch
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
def _base_config():
return {"username": "user", "password": "password", "host_port": "host:1521"}
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_initial_database(create_engine_mock):
    """Check the default ``initial_database`` and the single engine built from it.

    With only the base settings, ``initial_database`` is expected to be
    "postgres", and exactly one engine should be created with a URL that
    ends in "postgres".
    """
    config = PostgresConfig.parse_obj(_base_config())
    assert config.initial_database == "postgres"

    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so every create_engine call is recorded on the mock.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 1
    # call_args[0] is the positional-args tuple; its first item is the URL.
    assert create_engine_mock.call_args[0][0].endswith("postgres")
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_get_inspectors_multiple_databases(create_engine_mock):
    """Check that one engine is created per database when none is pinned.

    The mocked connection's ``execute`` returns two ``datname`` rows, and
    ``initial_database`` is set to "db0". Three engines are expected, whose
    URLs end in "db0", "db1" and "db2", in that order.
    """
    # Walk the mock chain for `with engine.connect() as conn: conn.execute(...)`.
    engine_mock = create_engine_mock.return_value
    connection_mock = engine_mock.connect.return_value.__enter__.return_value
    connection_mock.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"})
    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so every create_engine call is recorded on the mock.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 3
    # Each recorded call's [0][0] is the first positional argument (the URL).
    engine_urls = [call[0][0] for call in create_engine_mock.call_args_list]
    for engine_url, database in zip(engine_urls, ("db0", "db1", "db2")):
        assert engine_url.endswith(database)
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_database_provided(create_engine_mock):
    """Check that an explicit ``database`` yields a single engine for it.

    The mocked connection's ``execute`` still returns two ``datname`` rows,
    yet only one engine is expected, with a URL ending in "custom_db".
    """
    # Walk the mock chain for `with engine.connect() as conn: conn.execute(...)`.
    engine_mock = create_engine_mock.return_value
    connection_mock = engine_mock.connect.return_value.__enter__.return_value
    connection_mock.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"})
    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so every create_engine call is recorded on the mock.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 1
    # [0][0][0]: first recorded call -> positional args -> first argument (URL).
    assert create_engine_mock.call_args_list[0][0][0].endswith("custom_db")
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
    """Check that an explicit ``sqlalchemy_uri`` is passed to the engine as-is.

    The mocked connection's ``execute`` still returns two ``datname`` rows,
    yet only one engine is expected, created with exactly "custom_url".
    """
    # Walk the mock chain for `with engine.connect() as conn: conn.execute(...)`.
    engine_mock = create_engine_mock.return_value
    connection_mock = engine_mock.connect.return_value.__enter__.return_value
    connection_mock.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    config = PostgresConfig.parse_obj(
        {**_base_config(), "sqlalchemy_uri": "custom_url"}
    )
    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so every create_engine call is recorded on the mock.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 1
    # [0][0][0]: first recorded call -> positional args -> first argument (URL).
    assert create_engine_mock.call_args_list[0][0][0] == "custom_url"
def test_database_in_identifier():
    """Check that a configured ``database`` prefixes the dataset identifier.

    With ``database`` set to "postgres", ``get_identifier`` for schema
    "superset" and entity "logs" is expected to be "postgres.superset.logs".
    """
    config = PostgresConfig.parse_obj({**_base_config(), "database": "postgres"})
    mock_inspector = mock.MagicMock()
    source = PostgresSource(config, PipelineContext(run_id="test"))

    identifier = source.get_identifier(
        schema="superset", entity="logs", inspector=mock_inspector
    )

    assert identifier == "postgres.superset.logs"
def test_current_sqlalchemy_database_in_identifier():
    """Check that the inspector's database is used when none is configured.

    No ``database`` is set in the config; the mock inspector reports
    "current_db" via ``engine.url.database``. The identifier for schema
    "superset" and entity "logs" is expected to be "current_db.superset.logs".
    """
    # _base_config() already returns a fresh dict, so no extra copy is needed.
    config = PostgresConfig.parse_obj(_base_config())
    mock_inspector = mock.MagicMock()
    mock_inspector.engine.url.database = "current_db"
    source = PostgresSource(config, PipelineContext(run_id="test"))

    identifier = source.get_identifier(
        schema="superset", entity="logs", inspector=mock_inspector
    )

    assert identifier == "current_db.superset.logs"