Mirror of https://github.com/datahub-project/datahub.git (synced 2025-08-07 00:37:56 +00:00)
74 lines · 2.4 KiB · Python
from datetime import datetime, timezone
|
|
from unittest import mock
|
|
|
|
from datahub.ingestion.api.common import PipelineContext
|
|
from datahub.ingestion.source.snowflake.snowflake_summary import (
|
|
SnowflakeSummaryConfig,
|
|
SnowflakeSummarySource,
|
|
)
|
|
|
|
|
|
@mock.patch("snowflake.connector.connect")
def test_snowflake_summary_source_initialization(mock_sf_connect):
    """Smoke-test that SnowflakeSummarySource can be built and iterated.

    ``snowflake.connector.connect`` is patched so no real connection is made.
    The primary check is that ``get_workunits_internal()`` runs to completion
    without raising. The assertions afterwards only confirm that the object
    returned by ``get_report()`` has the same class as ``source.report`` and
    exposes the counter attributes listed below.
    """
    # Stand-in for the object the patched connect() returns; its ``query``
    # method is stubbed to return no rows.
    # NOTE(review): whether the source calls ``query`` on this object
    # directly is not visible from this file — confirm the stub is effective.
    mock_conn = mock.MagicMock()
    mock_conn.query.return_value = []
    mock_sf_connect.return_value = mock_conn

    # Minimal config covering a fixed one-day UTC window.
    config = SnowflakeSummaryConfig(
        account_id="test_account",
        username="test_user",
        password="test_password",
        start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
        end_time=datetime(2024, 1, 2, tzinfo=timezone.utc),
    )
    ctx = PipelineContext(run_id="test")
    source = SnowflakeSummarySource(ctx, config)

    # Exhaust the generator to trigger the source's internal initialization;
    # any exception raised there fails the test.
    list(source.get_workunits_internal())

    report = source.get_report()
    assert isinstance(report, source.report.__class__)
    # Check each attribute separately so a failure names the missing one.
    for attr in (
        "schema_counters",
        "object_counters",
        "num_snowflake_queries",
        "num_snowflake_mutations",
    ):
        assert hasattr(report, attr), f"report is missing attribute {attr!r}"
|
|
|
|
|
|
@mock.patch("snowflake.connector.connect")
def test_snowflake_summary_source_missing_filters(mock_sf_connect):
    """Check that the source runs and its report has a ``filtered`` attribute.

    ``snowflake.connector.connect`` is patched, so no real Snowflake
    connection is opened.

    NOTE(review): despite the test name, the config used here is the same
    minimal config as in the initialization test, and no filter settings are
    removed or altered. Confirm which "missing filters" scenario this test is
    meant to cover.
    """
    # The patched connect() hands back this stub; ``query`` yields no rows.
    fake_connection = mock.MagicMock()
    fake_connection.query.return_value = []
    mock_sf_connect.return_value = fake_connection

    utc = timezone.utc
    summary_config = SnowflakeSummaryConfig(
        account_id="test_account",
        username="test_user",
        password="test_password",
        start_time=datetime(2024, 1, 1, tzinfo=utc),
        end_time=datetime(2024, 1, 2, tzinfo=utc),
    )
    summary_source = SnowflakeSummarySource(
        PipelineContext(run_id="test"), summary_config
    )

    # Drain the work-unit generator to drive the source's internal setup;
    # any error raised while iterating propagates and fails the test.
    for _ in summary_source.get_workunits_internal():
        pass

    summary_report = summary_source.get_report()
    assert isinstance(summary_report, type(summary_source.report))
    # Per the original note, ``filtered`` is contributed by the Snowflake
    # filter component; only its presence on the report is checked here.
    assert hasattr(summary_report, "filtered")
|