mirror of
https://github.com/datahub-project/datahub.git
synced 2025-07-04 15:50:14 +00:00
83 lines
3.5 KiB
Python
83 lines
3.5 KiB
Python
from unittest import mock
|
|
from unittest.mock import patch
|
|
|
|
from datahub.ingestion.api.common import PipelineContext
|
|
from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
|
|
|
|
|
|
def _base_config():
|
|
return {"username": "user", "password": "password", "host_port": "host:1521"}
|
|
|
|
|
|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_initial_database(create_engine_mock):
    """Check the default ``initial_database`` and the URL given to ``create_engine``.

    With only the base config, ``initial_database`` must be ``"postgres"`` and
    the single ``create_engine`` call must receive a URL ending in ``postgres``.
    """
    parsed_config = PostgresConfig.parse_obj(_base_config())
    assert parsed_config.initial_database == "postgres"

    postgres_source = PostgresSource(parsed_config, PipelineContext(run_id="test"))
    # Exhaust the generator so every engine creation is recorded on the mock.
    for _inspector in postgres_source.get_inspectors():
        pass

    assert create_engine_mock.call_count == 1
    positional_args = create_engine_mock.call_args[0]
    assert positional_args[0].endswith("postgres")
|
|
|
|
|
|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_get_inspectors_multiple_databases(create_engine_mock):
    """Check the engine URLs created when the mocked query lists two databases.

    The mocked ``execute`` returns rows for ``db1`` and ``db2`` and the config
    sets ``initial_database`` to ``db0``; ``create_engine`` must then be called
    three times with URLs ending in ``db0``, ``db1`` and ``db2``, in that order.
    """
    # Walk the mock chain step by step: engine -> connection context -> execute.
    engine_mock = create_engine_mock.return_value
    connection_mock = engine_mock.connect.return_value.__enter__.return_value
    connection_mock.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    parsed_config = PostgresConfig.parse_obj(
        dict(_base_config(), initial_database="db0")
    )
    postgres_source = PostgresSource(parsed_config, PipelineContext(run_id="test"))
    # Exhaust the generator so every engine creation is recorded on the mock.
    for _inspector in postgres_source.get_inspectors():
        pass

    assert create_engine_mock.call_count == 3
    expected_suffixes = ("db0", "db1", "db2")
    for recorded_call, suffix in zip(
        create_engine_mock.call_args_list, expected_suffixes
    ):
        assert recorded_call[0][0].endswith(suffix)
|
|
|
|
|
|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_database_provided(create_engine_mock):
    """Check that an explicit ``database`` yields a single engine for that database.

    Although the mocked ``execute`` returns rows for ``db1`` and ``db2``,
    ``create_engine`` must be called exactly once, with a URL ending in
    ``custom_db``.
    """
    # Walk the mock chain step by step: engine -> connection context -> execute.
    engine_mock = create_engine_mock.return_value
    connection_mock = engine_mock.connect.return_value.__enter__.return_value
    connection_mock.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    parsed_config = PostgresConfig.parse_obj(
        dict(_base_config(), database="custom_db")
    )
    postgres_source = PostgresSource(parsed_config, PipelineContext(run_id="test"))
    # Exhaust the generator so every engine creation is recorded on the mock.
    for _inspector in postgres_source.get_inspectors():
        pass

    assert create_engine_mock.call_count == 1
    first_call = create_engine_mock.call_args_list[0]
    assert first_call[0][0].endswith("custom_db")
|
|
|
|
|
|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
    """Check that an explicit ``sqlalchemy_uri`` is passed to ``create_engine`` as is.

    Although the mocked ``execute`` returns rows for ``db1`` and ``db2``,
    ``create_engine`` must be called exactly once, with ``"custom_url"`` as its
    first positional argument.
    """
    # Walk the mock chain step by step: engine -> connection context -> execute.
    engine_mock = create_engine_mock.return_value
    connection_mock = engine_mock.connect.return_value.__enter__.return_value
    connection_mock.execute.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    parsed_config = PostgresConfig.parse_obj(
        dict(_base_config(), sqlalchemy_uri="custom_url")
    )
    postgres_source = PostgresSource(parsed_config, PipelineContext(run_id="test"))
    # Exhaust the generator so every engine creation is recorded on the mock.
    for _inspector in postgres_source.get_inspectors():
        pass

    assert create_engine_mock.call_count == 1
    first_call = create_engine_mock.call_args_list[0]
    assert first_call[0][0] == "custom_url"
|
|
|
|
|
|
def test_database_in_identifier():
    """Check that a configured ``database`` becomes the identifier prefix.

    With ``database`` set to ``"postgres"``, ``get_identifier`` for schema
    ``superset`` and entity ``logs`` must return ``"postgres.superset.logs"``.
    """
    parsed_config = PostgresConfig.parse_obj(
        dict(_base_config(), database="postgres")
    )
    inspector_stub = mock.MagicMock()
    postgres_source = PostgresSource(parsed_config, PipelineContext(run_id="test"))

    identifier = postgres_source.get_identifier(
        schema="superset", entity="logs", inspector=inspector_stub
    )

    assert identifier == "postgres.superset.logs"
|
|
|
|
|
|
def test_current_sqlalchemy_database_in_identifier():
    """Check that the inspector's engine database becomes the identifier prefix.

    With no ``database`` in the config and the inspector's
    ``engine.url.database`` set to ``"current_db"``, ``get_identifier`` for
    schema ``superset`` and entity ``logs`` must return
    ``"current_db.superset.logs"``.
    """
    parsed_config = PostgresConfig.parse_obj(_base_config())
    inspector_stub = mock.MagicMock()
    inspector_stub.engine.url.database = "current_db"
    postgres_source = PostgresSource(parsed_config, PipelineContext(run_id="test"))

    identifier = postgres_source.get_identifier(
        schema="superset", entity="logs", inspector=inspector_stub
    )

    assert identifier == "current_db.superset.logs"
|