# datahub/metadata-ingestion/tests/unit/snowflake/test_snowflake_summary.py
# (74 lines, 2.4 KiB, Python)

from datetime import datetime, timezone
from unittest import mock
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.snowflake.snowflake_summary import (
SnowflakeSummaryConfig,
SnowflakeSummarySource,
)
def _build_summary_source(mock_sf_connect: mock.MagicMock) -> SnowflakeSummarySource:
    """Create a SnowflakeSummarySource backed by a mocked Snowflake connection.

    ``mock_sf_connect`` is the patched ``snowflake.connector.connect``. It is
    configured to return a MagicMock connection whose ``query`` method
    returns an empty list, so no real Snowflake account is contacted.
    """
    mock_conn = mock.MagicMock()
    mock_conn.query.return_value = []
    mock_sf_connect.return_value = mock_conn

    # Minimal config: credentials plus a one-day UTC time window.
    config = SnowflakeSummaryConfig(
        account_id="test_account",
        username="test_user",
        password="test_password",
        start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
        end_time=datetime(2024, 1, 2, tzinfo=timezone.utc),
    )
    ctx = PipelineContext(run_id="test")
    return SnowflakeSummarySource(ctx, config)


def _run_and_get_report(mock_sf_connect: mock.MagicMock):
    """Build the source, drain all of its workunits, and return its report.

    Draining the workunits is what triggers the source's initialization. Any
    exception raised while constructing its internal components (for example
    SnowflakeSchemaGenerator being built with missing arguments) therefore
    fails the calling test.
    """
    source = _build_summary_source(mock_sf_connect)
    list(source.get_workunits_internal())
    report = source.get_report()
    assert isinstance(report, type(source.report))
    return report


@mock.patch("snowflake.connector.connect")
def test_snowflake_summary_source_initialization(mock_sf_connect):
    """Run the source against an empty mocked Snowflake and check its counters.

    Regression guard: the run must complete without raising, which shows the
    source's components (e.g. SnowflakeSchemaGenerator) were initialized with
    every required parameter. The report must also expose the counter fields.
    """
    report = _run_and_get_report(mock_sf_connect)
    for attr in (
        "schema_counters",
        "object_counters",
        "num_snowflake_queries",
        "num_snowflake_mutations",
    ):
        assert hasattr(report, attr), f"report is missing {attr!r}"


@mock.patch("snowflake.connector.connect")
def test_snowflake_summary_source_missing_filters(mock_sf_connect):
    """Run the source with a config that sets no filter options.

    The run must complete without raising, and the resulting report must
    still expose the ``filtered`` field.
    """
    report = _run_and_get_report(mock_sf_connect)
    # NOTE(review): ``filtered`` is presumably contributed by SnowflakeFilter;
    # confirm against the report class definition.
    assert hasattr(report, "filtered"), "report is missing 'filtered'"