mirror of
				https://github.com/datahub-project/datahub.git
				synced 2025-11-04 12:51:23 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			83 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			83 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from unittest import mock
 | 
						|
from unittest.mock import patch
 | 
						|
 | 
						|
from datahub.ingestion.api.common import PipelineContext
 | 
						|
from datahub.ingestion.source.sql.postgres import PostgresConfig, PostgresSource
 | 
						|
 | 
						|
 | 
						|
def _base_config():
 | 
						|
    return {"username": "user", "password": "password", "host_port": "host:1521"}
 | 
						|
 | 
						|
 | 
						|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_initial_database(create_engine_mock):
    """Without an explicit database, the single engine targets ``postgres``.

    ``create_engine`` is patched in the postgres source module, so no real
    connection is attempted; only the arguments it receives are inspected.
    """
    config = PostgresConfig.parse_obj(_base_config())
    # The base config sets no initial_database, so this checks its default.
    assert config.initial_database == "postgres"

    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so that every create_engine call is actually made.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 1
    # call_args[0] is the positional-args tuple; its first item is presumably
    # the connection URL, which should end with the database name.
    assert create_engine_mock.call_args[0][0].endswith("postgres")
 | 
						|
 | 
						|
 | 
						|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def test_get_inspectors_multiple_databases(create_engine_mock):
    """With only ``initial_database`` set, one engine is created per database.

    The mocked connection's ``execute`` returns two rows (``db1`` and
    ``db2``). The test expects three ``create_engine`` calls in order: the
    initial database ``db0``, then one for each returned row.
    """
    # Follow the mock chain for: with create_engine(...).connect() as conn: conn.execute(...)
    execute_mock = create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"})
    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so that every create_engine call is actually made.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 3
    # Each call_args_list entry is a call object; [0][0] is its first positional argument.
    assert create_engine_mock.call_args_list[0][0][0].endswith("db0")
    assert create_engine_mock.call_args_list[1][0][0].endswith("db1")
    assert create_engine_mock.call_args_list[2][0][0].endswith("db2")
 | 
						|
 | 
						|
 | 
						|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_database_provided(create_engine_mock):
    """With ``database`` set, exactly one engine is created, for that database.

    The mocked ``execute`` still offers ``db1`` and ``db2``. The test expects
    those rows to trigger no further ``create_engine`` calls.
    """
    # Follow the mock chain for: with create_engine(...).connect() as conn: conn.execute(...)
    execute_mock = create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"})
    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so that every create_engine call is actually made.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 1
    assert create_engine_mock.call_args_list[0][0][0].endswith("custom_db")
 | 
						|
 | 
						|
 | 
						|
@patch("datahub.ingestion.source.sql.postgres.create_engine")
def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock):
    """With ``sqlalchemy_uri`` set, it is passed to ``create_engine`` unchanged.

    The mocked ``execute`` still offers ``db1`` and ``db2``. The test expects
    a single ``create_engine`` call whose first argument is exactly the
    configured URI.
    """
    # Follow the mock chain for: with create_engine(...).connect() as conn: conn.execute(...)
    execute_mock = create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute
    execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}]

    config = PostgresConfig.parse_obj(
        {**_base_config(), "sqlalchemy_uri": "custom_url"}
    )
    source = PostgresSource(config, PipelineContext(run_id="test"))
    # Exhaust the iterable so that every create_engine call is actually made.
    _ = list(source.get_inspectors())

    assert create_engine_mock.call_count == 1
    # Equality rather than endswith: the URI must not be rewritten.
    assert create_engine_mock.call_args_list[0][0][0] == "custom_url"
 | 
						|
 | 
						|
 | 
						|
def test_database_in_identifier():
    """With ``database`` configured, the identifier is ``<database>.<schema>.<entity>``."""
    config = PostgresConfig.parse_obj({**_base_config(), "database": "postgres"})
    # A bare MagicMock is enough here: the expected prefix is the configured
    # database, and the inspector has no database name set on it.
    mock_inspector = mock.MagicMock()
    assert (
        PostgresSource(config, PipelineContext(run_id="test")).get_identifier(
            schema="superset", entity="logs", inspector=mock_inspector
        )
        == "postgres.superset.logs"
    )
 | 
						|
 | 
						|
 | 
						|
def test_current_sqlalchemy_database_in_identifier():
    """Without ``database`` configured, the inspector's engine URL database is the prefix."""
    config = PostgresConfig.parse_obj({**_base_config()})
    mock_inspector = mock.MagicMock()
    # With no database in the config, the expected prefix is the one set on
    # the inspector's engine URL.
    mock_inspector.engine.url.database = "current_db"
    assert (
        PostgresSource(config, PipelineContext(run_id="test")).get_identifier(
            schema="superset", entity="logs", inspector=mock_inspector
        )
        == "current_db.superset.logs"
    )
 |