# datahub/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py

import datetime
import re
from typing import Any, Dict
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
import datahub.ingestion.source.snowflake.snowflake_utils
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.pattern_utils import UUID_REGEX
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.snowflake.constants import (
CLIENT_PREFETCH_THREADS,
CLIENT_SESSION_KEEP_ALIVE,
SnowflakeCloudProvider,
SnowflakeObjectDomain,
)
from datahub.ingestion.source.snowflake.oauth_config import OAuthConfiguration
from datahub.ingestion.source.snowflake.snowflake_config import (
DEFAULT_TEMP_TABLES_PATTERNS,
SnowflakeIdentifierConfig,
SnowflakeV2Config,
)
from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import UpstreamLineageEdge
from datahub.ingestion.source.snowflake.snowflake_queries import (
SnowflakeQueriesExtractor,
SnowflakeQueriesExtractorConfig,
)
from datahub.ingestion.source.snowflake.snowflake_query import (
SnowflakeQuery,
create_deny_regex_sql_filter,
)
from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
SnowflakeObjectAccessEntry,
)
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeIdentifierBuilder,
SnowsightUrlBuilder,
)
from datahub.ingestion.source.snowflake.snowflake_v2 import SnowflakeV2Source
from datahub.sql_parsing.sql_parsing_aggregator import TableRename, TableSwap
from datahub.testing.doctest import assert_doctest
from tests.integration.snowflake.common import inject_rowcount
from tests.test_helpers import test_connection_helpers
# Baseline OAuth settings cloned by the OAuth-related tests below; individual
# tests delete or override keys to exercise specific validation paths.
default_oauth_dict: Dict[str, Any] = {
    "client_id": "client_id",
    "client_secret": "secret",
    "use_certificate": False,
    "provider": "microsoft",
    "scopes": ["datahub_role"],
    "authority_url": "https://dev-abc.okta.com/oauth2/def/v1/token",
}
def test_snowflake_source_throws_error_on_account_id_missing():
    """Config validation must fail when account_id is absent."""
    with pytest.raises(
        ValidationError, match=re.compile(r"account_id.*Field required", re.DOTALL)
    ):
        SnowflakeV2Config.parse_obj({"username": "user", "password": "password"})


def test_no_client_id_invalid_oauth_config():
    """An OAuth config without client_id is rejected."""
    oauth_dict = {k: v for k, v in default_oauth_dict.items() if k != "client_id"}
    with pytest.raises(
        ValueError, match=re.compile(r"client_id.*Field required", re.DOTALL)
    ):
        OAuthConfiguration.parse_obj(oauth_dict)


def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_false():
    """client_secret is mandatory when use_certificate is false."""
    oauth_dict = {k: v for k, v in default_oauth_dict.items() if k != "client_secret"}
    # The OAuth config alone still parses; the failure surfaces only at the
    # source-config level where the authentication type is known.
    OAuthConfiguration.parse_obj(oauth_dict)
    with pytest.raises(
        ValueError,
        match="'oauth_config.client_secret' was none but should be set when using use_certificate false for oauth_config",
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "test",
                "authentication_type": "OAUTH_AUTHENTICATOR",
                "oauth_config": oauth_dict,
            }
        )


def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_certificate_is_true():
    """A private key must accompany certificate-based OAuth."""
    oauth_dict = {**default_oauth_dict, "use_certificate": True}
    OAuthConfiguration.parse_obj(oauth_dict)
    with pytest.raises(
        ValueError,
        match="'base64_encoded_oauth_private_key' was none but should be set when using certificate for oauth_config",
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "test",
                "authentication_type": "OAUTH_AUTHENTICATOR",
                "oauth_config": oauth_dict,
            }
        )


def test_snowflake_oauth_okta_does_not_support_certificate():
    """The Okta provider rejects certificate authentication outright."""
    oauth_dict = {**default_oauth_dict, "use_certificate": True, "provider": "okta"}
    OAuthConfiguration.parse_obj(oauth_dict)
    with pytest.raises(
        ValueError, match="Certificate authentication is not supported for Okta."
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "test",
                "authentication_type": "OAUTH_AUTHENTICATOR",
                "oauth_config": oauth_dict,
            }
        )
def test_snowflake_oauth_happy_paths():
    """Valid OAuth configs (Okta + secret, Microsoft + certificate) parse."""
    oauth_dict = {**default_oauth_dict, "provider": "okta"}
    assert SnowflakeV2Config.parse_obj(
        {
            "account_id": "test",
            "authentication_type": "OAUTH_AUTHENTICATOR",
            "oauth_config": oauth_dict,
        }
    )
    oauth_dict.update(
        use_certificate=True,
        provider="microsoft",
        encoded_oauth_public_key="publickey",
        encoded_oauth_private_key="privatekey",
    )
    assert SnowflakeV2Config.parse_obj(
        {
            "account_id": "test",
            "authentication_type": "OAUTH_AUTHENTICATOR",
            "oauth_config": oauth_dict,
        }
    )


def test_snowflake_oauth_token_happy_path():
    """OAUTH_AUTHENTICATOR_TOKEN with a token and username is accepted."""
    assert SnowflakeV2Config.parse_obj(
        {
            "account_id": "test",
            "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN",
            "token": "valid-token",
            "username": "test-user",
            "oauth_config": None,
        }
    )


def test_snowflake_oauth_token_without_token():
    """Token-based auth without a token fails validation."""
    with pytest.raises(
        ValidationError, match="Token required for OAUTH_AUTHENTICATOR_TOKEN."
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "test",
                "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN",
                "username": "test-user",
            }
        )


def test_snowflake_oauth_token_with_wrong_auth_type():
    """Supplying a token under a non-token authenticator is rejected."""
    with pytest.raises(
        ValueError,
        match="Token can only be provided when using OAUTH_AUTHENTICATOR_TOKEN.",
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "test",
                "authentication_type": "OAUTH_AUTHENTICATOR",
                "token": "some-token",
                "username": "test-user",
            }
        )


def test_snowflake_oauth_token_with_empty_token():
    """An empty-string token is treated the same as a missing token."""
    with pytest.raises(
        ValidationError, match="Token required for OAUTH_AUTHENTICATOR_TOKEN."
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "test",
                "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN",
                "token": "",
                "username": "test-user",
            }
        )
def test_config_fetch_views_from_information_schema():
    """fetch_views_from_information_schema defaults to False and honours overrides."""
    base = {
        "account_id": "test_account",
        "username": "test_user",
        "password": "test_pass",
    }
    # Default value when the key is omitted.
    assert (
        SnowflakeV2Config.parse_obj(base).fetch_views_from_information_schema is False
    )
    # Explicit True and False overrides round-trip unchanged.
    for flag in (True, False):
        parsed = SnowflakeV2Config.parse_obj(
            {**base, "fetch_views_from_information_schema": flag}
        )
        assert parsed.fetch_views_from_information_schema is flag
# Baseline connection settings shared by the config/URI/connection tests below;
# tests copy this dict and delete or override keys per scenario.
default_config_dict: Dict[str, Any] = {
    "username": "user",
    "password": "password",
    "account_id": "https://acctname.snowflakecomputing.com",
    "warehouse": "COMPUTE_WH",
    "role": "sysadmin",
}
def test_account_id_is_added_when_host_port_is_present():
    """The legacy host_port field is accepted as an alias for account_id."""
    cfg = {k: v for k, v in default_config_dict.items() if k != "account_id"}
    cfg["host_port"] = "acctname"
    assert SnowflakeV2Config.parse_obj(cfg).account_id == "acctname"


def test_account_id_with_snowflake_host_suffix():
    """Full snowflakecomputing.com URLs are trimmed to the bare account id."""
    assert SnowflakeV2Config.parse_obj(default_config_dict).account_id == "acctname"


def test_snowflake_uri_default_authentication():
    """Default auth produces a user:password SQLAlchemy URL."""
    config = SnowflakeV2Config.parse_obj(default_config_dict)
    expected = (
        "snowflake://user:password@acctname"
        "?application=acryl_datahub"
        "&authenticator=SNOWFLAKE"
        "&role=sysadmin"
        "&warehouse=COMPUTE_WH"
    )
    assert config.get_sql_alchemy_url() == expected


def test_snowflake_uri_external_browser_authentication():
    """External-browser auth drops the password and switches authenticator."""
    cfg = {k: v for k, v in default_config_dict.items() if k != "password"}
    cfg["authentication_type"] = "EXTERNAL_BROWSER_AUTHENTICATOR"
    expected = (
        "snowflake://user@acctname"
        "?application=acryl_datahub"
        "&authenticator=EXTERNALBROWSER"
        "&role=sysadmin"
        "&warehouse=COMPUTE_WH"
    )
    assert SnowflakeV2Config.parse_obj(cfg).get_sql_alchemy_url() == expected


def test_snowflake_uri_key_pair_authentication():
    """Key-pair auth uses the SNOWFLAKE_JWT authenticator in the URL."""
    cfg = {k: v for k, v in default_config_dict.items() if k != "password"}
    cfg.update(
        authentication_type="KEY_PAIR_AUTHENTICATOR",
        private_key_path="/a/random/path",
        private_key_password="a_random_password",
    )
    expected = (
        "snowflake://user@acctname"
        "?application=acryl_datahub"
        "&authenticator=SNOWFLAKE_JWT"
        "&role=sysadmin"
        "&warehouse=COMPUTE_WH"
    )
    assert SnowflakeV2Config.parse_obj(cfg).get_sql_alchemy_url() == expected


def test_options_contain_connect_args():
    """get_options always carries a connect_args mapping."""
    options = SnowflakeV2Config.parse_obj(default_config_dict).get_options()
    assert options.get("connect_args") is not None
@patch(
    "datahub.ingestion.source.snowflake.snowflake_connection.snowflake.connector.connect"
)
def test_snowflake_connection_with_default_domain(mock_connect):
    """The connect host defaults to the global .com Snowflake domain."""
    config = SnowflakeV2Config.parse_obj(default_config_dict.copy())
    mock_connect.return_value = MagicMock()
    try:
        config.get_connection()
    except Exception:
        # The mocked connection may blow up downstream; only the call args matter.
        pass
    mock_connect.assert_called_once()
    assert mock_connect.call_args[1]["host"] == "acctname.snowflakecomputing.com"


@patch(
    "datahub.ingestion.source.snowflake.snowflake_connection.snowflake.connector.connect"
)
def test_snowflake_connection_with_china_domain(mock_connect):
    """snowflake_domain=snowflakecomputing.cn is honoured in the connect host."""
    config = SnowflakeV2Config.parse_obj(
        {
            **default_config_dict,
            "account_id": "test-account_cn",
            "snowflake_domain": "snowflakecomputing.cn",
        }
    )
    mock_connect.return_value = MagicMock()
    try:
        config.get_connection()
    except Exception:
        # The mocked connection may blow up downstream; only the call args matter.
        pass
    mock_connect.assert_called_once()
    assert mock_connect.call_args[1]["host"] == "test-account_cn.snowflakecomputing.cn"
def test_snowflake_config_with_column_lineage_no_table_lineage_throws_error():
    """Column lineage cannot be enabled without table lineage."""
    with pytest.raises(
        ValidationError,
        match="include_table_lineage must be True for include_column_lineage to be set",
    ):
        SnowflakeV2Config.parse_obj(
            {
                **default_config_dict,
                "include_column_lineage": True,
                "include_table_lineage": False,
            }
        )


def test_snowflake_config_with_no_connect_args_returns_base_connect_args():
    """Without overrides, the base connect_args are returned verbatim."""
    config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(default_config_dict)
    connect_args = config.get_options()["connect_args"]
    assert connect_args is not None
    assert connect_args == {
        CLIENT_PREFETCH_THREADS: 10,
        CLIENT_SESSION_KEEP_ALIVE: True,
    }


def test_private_key_set_but_auth_not_changed():
    """Supplying a private key without KEY_PAIR_AUTHENTICATOR is rejected."""
    with pytest.raises(
        ValidationError,
        match="Either `private_key` and `private_key_path` is set but `authentication_type` is DEFAULT_AUTHENTICATOR. Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication",
    ):
        SnowflakeV2Config.parse_obj(
            {
                "account_id": "acctname",
                "private_key_path": "/a/random/path",
            }
        )


def test_snowflake_config_with_connect_args_overrides_base_connect_args():
    """User connect_args override individual base keys but keep the rest."""
    config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(
        {
            **default_config_dict,
            "connect_args": {CLIENT_PREFETCH_THREADS: 5},
        }
    )
    connect_args = config.get_options()["connect_args"]
    assert connect_args is not None
    assert connect_args[CLIENT_PREFETCH_THREADS] == 5
    assert connect_args[CLIENT_SESSION_KEEP_ALIVE] is True
@patch("snowflake.connector.connect")
def test_test_connection_failure(mock_connect):
    """A connect error surfaces as a basic-connectivity failure in the report."""
    mock_connect.side_effect = Exception("Failed to connect to snowflake")
    report = test_connection_helpers.run_test_connection(
        SnowflakeV2Source, default_config_dict
    )
    test_connection_helpers.assert_basic_connectivity_failure(
        report, "Failed to connect to snowflake"
    )


@patch("snowflake.connector.connect")
def test_test_connection_basic_success(mock_connect):
    """A (mocked) successful connect yields basic connectivity success."""
    report = test_connection_helpers.run_test_connection(
        SnowflakeV2Source, default_config_dict
    )
    test_connection_helpers.assert_basic_connectivity_success(report)
class MissingQueryMock(Exception):
    # Raised by per-test query stubs to signal "no canned result for this
    # query"; setup_mock_connect catches it to fall through to its defaults.
    pass
def setup_mock_connect(mock_connect, extra_query_results=None):
    """Wire mock_connect so cursor.execute returns canned query results.

    extra_query_results, when given, is consulted first and may raise
    MissingQueryMock to fall through to the default answers below.
    """

    @inject_rowcount
    def query_results(query):
        if extra_query_results is not None:
            try:
                return extra_query_results(query)
            except MissingQueryMock:
                pass
        canned = {
            "select current_role()": [{"CURRENT_ROLE()": "TEST_ROLE"}],
            "select current_secondary_roles()": [
                {"CURRENT_SECONDARY_ROLES()": '{"roles":"","value":""}'}
            ],
            "select current_warehouse()": [{"CURRENT_WAREHOUSE()": "TEST_WAREHOUSE"}],
            'show grants to role "PUBLIC"': [],
        }
        if query in canned:
            return canned[query]
        raise MissingQueryMock(f"Unexpected query: {query}")

    cursor_mock = MagicMock()
    cursor_mock.execute.side_effect = query_results
    connection_mock = MagicMock()
    connection_mock.cursor.return_value = cursor_mock
    mock_connect.return_value = connection_mock
@patch("snowflake.connector.connect")
def test_test_connection_no_warehouse(mock_connect):
    """With no active warehouse, SCHEMA_METADATA is reported as failed."""

    def query_results(query):
        if query == "select current_warehouse()":
            return [{"CURRENT_WAREHOUSE()": None}]
        if query == 'show grants to role "TEST_ROLE"':
            return [{"privilege": "USAGE", "granted_on": "DATABASE", "name": "DB1"}]
        raise MissingQueryMock(f"Unexpected query: {query}")

    setup_mock_connect(mock_connect, query_results)
    report = test_connection_helpers.run_test_connection(
        SnowflakeV2Source, default_config_dict
    )
    test_connection_helpers.assert_basic_connectivity_success(report)
    test_connection_helpers.assert_capability_report(
        capability_report=report.capability_report,
        success_capabilities=[SourceCapability.CONTAINERS],
        failure_capabilities={
            SourceCapability.SCHEMA_METADATA: "Current role TEST_ROLE does not have permissions to use warehouse"
        },
    )


@patch("snowflake.connector.connect")
def test_test_connection_capability_schema_failure(mock_connect):
    """Database-only grants are not enough for schema metadata extraction."""

    def query_results(query):
        if query == 'show grants to role "TEST_ROLE"':
            return [{"privilege": "USAGE", "granted_on": "DATABASE", "name": "DB1"}]
        raise MissingQueryMock(f"Unexpected query: {query}")

    setup_mock_connect(mock_connect, query_results)
    report = test_connection_helpers.run_test_connection(
        SnowflakeV2Source, default_config_dict
    )
    test_connection_helpers.assert_basic_connectivity_success(report)
    test_connection_helpers.assert_capability_report(
        capability_report=report.capability_report,
        success_capabilities=[SourceCapability.CONTAINERS],
        failure_capabilities={
            SourceCapability.SCHEMA_METADATA: "Either no tables exist or current role does not have permissions to access them"
        },
    )


@patch("snowflake.connector.connect")
def test_test_connection_capability_schema_success(mock_connect):
    """A table-level REFERENCES grant enables schema metadata and descriptions."""

    def query_results(query):
        if query == 'show grants to role "TEST_ROLE"':
            return [
                {"privilege": "USAGE", "granted_on": "DATABASE", "name": "DB1"},
                {"privilege": "USAGE", "granted_on": "SCHEMA", "name": "DB1.SCHEMA1"},
                {
                    "privilege": "REFERENCES",
                    "granted_on": "TABLE",
                    "name": "DB1.SCHEMA1.TABLE1",
                },
            ]
        raise MissingQueryMock(f"Unexpected query: {query}")

    setup_mock_connect(mock_connect, query_results)
    report = test_connection_helpers.run_test_connection(
        SnowflakeV2Source, default_config_dict
    )
    test_connection_helpers.assert_basic_connectivity_success(report)
    test_connection_helpers.assert_capability_report(
        capability_report=report.capability_report,
        success_capabilities=[
            SourceCapability.CONTAINERS,
            SourceCapability.SCHEMA_METADATA,
            SourceCapability.DESCRIPTIONS,
        ],
    )
@patch("snowflake.connector.connect")
def test_test_connection_capability_all_success(mock_connect):
    """Table SELECT plus ACCOUNT_USAGE view grants unlock every capability."""
    grants_by_query = {
        'show grants to role "TEST_ROLE"': [
            {"privilege": "USAGE", "granted_on": "DATABASE", "name": "DB1"},
            {"privilege": "USAGE", "granted_on": "SCHEMA", "name": "DB1.SCHEMA1"},
            {
                "privilege": "SELECT",
                "granted_on": "TABLE",
                "name": "DB1.SCHEMA1.TABLE1",
            },
            {"privilege": "USAGE", "granted_on": "ROLE", "name": "TEST_USAGE_ROLE"},
        ],
        'show grants to role "TEST_USAGE_ROLE"': [
            {"privilege": "USAGE", "granted_on": "DATABASE", "name": "SNOWFLAKE"},
            {"privilege": "USAGE", "granted_on": "SCHEMA", "name": "ACCOUNT_USAGE"},
            # The three ACCOUNT_USAGE views needed for usage + lineage.
            *(
                {
                    "privilege": "USAGE",
                    "granted_on": "VIEW",
                    "name": f"SNOWFLAKE.ACCOUNT_USAGE.{view}",
                }
                for view in ("QUERY_HISTORY", "ACCESS_HISTORY", "OBJECT_DEPENDENCIES")
            ),
        ],
    }

    def query_results(query):
        if query in grants_by_query:
            return grants_by_query[query]
        raise MissingQueryMock(f"Unexpected query: {query}")

    setup_mock_connect(mock_connect, query_results)
    report = test_connection_helpers.run_test_connection(
        SnowflakeV2Source, default_config_dict
    )
    test_connection_helpers.assert_basic_connectivity_success(report)
    test_connection_helpers.assert_capability_report(
        capability_report=report.capability_report,
        success_capabilities=[
            SourceCapability.CONTAINERS,
            SourceCapability.SCHEMA_METADATA,
            SourceCapability.DATA_PROFILING,
            SourceCapability.DESCRIPTIONS,
            SourceCapability.LINEAGE_COARSE,
        ],
    )
def test_aws_cloud_region_from_snowflake_region_id():
    """AWS region ids become dash-separated regions; a '_gov' suffix is dropped."""
    for region_id, expected in (
        ("aws_ca_central_1", "ca-central-1"),
        ("aws_us_east_1_gov", "us-east-1"),
    ):
        cloud, region = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id(
            region_id
        )
        assert cloud == SnowflakeCloudProvider.AWS
        assert region == expected


def test_google_cloud_region_from_snowflake_region_id():
    """GCP region ids keep their compact region form."""
    cloud, region = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id(
        "gcp_europe_west2"
    )
    assert cloud == SnowflakeCloudProvider.GCP
    assert region == "europe-west2"


def test_azure_cloud_region_from_snowflake_region_id():
    """Azure region ids are expanded into dash-separated words."""
    for region_id, expected in (
        ("azure_switzerlandnorth", "switzerland-north"),
        ("azure_centralindia", "central-india"),
    ):
        cloud, region = SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id(
            region_id
        )
        assert cloud == SnowflakeCloudProvider.AZURE
        assert region == expected


def test_unknown_cloud_region_from_snowflake_region_id():
    """Unrecognised providers raise a descriptive error."""
    with pytest.raises(Exception, match="Unknown snowflake region"):
        SnowsightUrlBuilder.get_cloud_region_from_snowflake_region_id(
            "somecloud_someregion"
        )
def test_snowflake_object_access_entry_missing_object_id():
    """An access-history entry without objectId must still validate."""
    SnowflakeObjectAccessEntry(
        columns=[{"columnName": "A"}, {"columnName": "B"}],
        objectDomain="View",
        objectName="SOME.OBJECT.NAME",
    )


def test_snowflake_query_create_deny_regex_sql():
    """Deny patterns become AND-joined NOT RLIKE filters, one per column."""
    # The expanded form of UUID_REGEX as it appears in generated SQL.
    uuid_sql = (
        "[a-f0-9]{8}[-_][a-f0-9]{4}[-_][a-f0-9]{4}[-_][a-f0-9]{4}[-_][a-f0-9]{12}"
    )
    assert create_deny_regex_sql_filter([], ["col"]) == ""
    assert (
        create_deny_regex_sql_filter([".*tmp.*"], ["col"])
        == "NOT RLIKE(col,'.*tmp.*','i')"
    )
    assert create_deny_regex_sql_filter([".*tmp.*", UUID_REGEX], ["col"]) == (
        f"NOT RLIKE(col,'.*tmp.*','i') AND NOT RLIKE(col,'{uuid_sql}','i')"
    )
    assert create_deny_regex_sql_filter([".*tmp.*", UUID_REGEX], ["col1", "col2"]) == (
        f"NOT RLIKE(col1,'.*tmp.*','i') AND NOT RLIKE(col1,'{uuid_sql}','i')"
        f" AND NOT RLIKE(col2,'.*tmp.*','i') AND NOT RLIKE(col2,'{uuid_sql}','i')"
    )
    assert create_deny_regex_sql_filter(
        DEFAULT_TEMP_TABLES_PATTERNS, ["upstream_table_name"]
    ) == (
        r"NOT RLIKE(upstream_table_name,'.*\.FIVETRAN_.*_STAGING\..*','i')"
        r" AND NOT RLIKE(upstream_table_name,'.*__DBT_TMP$','i')"
        rf" AND NOT RLIKE(upstream_table_name,'.*\.SEGMENT_{uuid_sql}','i')"
        rf" AND NOT RLIKE(upstream_table_name,'.*\.STAGING_.*_{uuid_sql}','i')"
        r" AND NOT RLIKE(upstream_table_name,'.*\.(GE_TMP_|GE_TEMP_|GX_TEMP_)[0-9A-F]{8}','i')"
        r" AND NOT RLIKE(upstream_table_name,'.*\.SNOWPARK_TEMP_TABLE_.+','i')"
    )
def test_snowflake_temporary_patterns_config_rename():
    """Deprecated upstreams_deny_pattern maps onto temporary_tables_pattern."""
    conf = SnowflakeV2Config.parse_obj(
        {
            "account_id": "test",
            "username": "user",
            "password": "password",
            "upstreams_deny_pattern": [".*tmp.*"],
        }
    )
    assert conf.temporary_tables_pattern == [".*tmp.*"]


def test_email_filter_query_generation_with_one_deny():
    """A deny-only pattern yields a leading ' AND NOT (...)' clause."""
    query = SnowflakeQuery.gen_email_filter_query(
        AllowDenyPattern(deny=[".*@example.com"])
    )
    assert query == " AND NOT (rlike(user_name, '.*@example.com','i'))"


def test_email_filter_query_generation_without_any_filter():
    """The fully-permissive default pattern produces no filter SQL at all."""
    assert SnowflakeQuery.gen_email_filter_query(AllowDenyPattern()) == ""


def test_email_filter_query_generation_one_allow():
    """An allow-only pattern yields an 'AND (...)' clause."""
    query = SnowflakeQuery.gen_email_filter_query(
        AllowDenyPattern(allow=[".*@example.com"])
    )
    assert query == "AND (rlike(user_name, '.*@example.com','i'))"


def test_email_filter_query_generation_one_allow_and_deny():
    """Allow and deny entries are OR-ed internally and AND-ed together."""
    pattern = AllowDenyPattern(
        allow=[".*@example.com", ".*@example2.com"],
        deny=[".*@example2.com", ".*@example4.com"],
    )
    assert SnowflakeQuery.gen_email_filter_query(pattern) == (
        "AND (rlike(user_name, '.*@example.com','i') OR rlike(user_name, '.*@example2.com','i'))"
        " AND NOT (rlike(user_name, '.*@example2.com','i') OR rlike(user_name, '.*@example4.com','i'))"
    )


def test_email_filter_query_generation_with_case_insensitive_filter():
    """ignoreCase=False switches the rlike flag from 'i' to 'c' (case-sensitive)."""
    pattern = AllowDenyPattern(
        allow=[".*@example.com"], deny=[".*@example2.com"], ignoreCase=False
    )
    assert SnowflakeQuery.gen_email_filter_query(pattern) == (
        "AND (rlike(user_name, '.*@example.com','c'))"
        " AND NOT (rlike(user_name, '.*@example2.com','c'))"
    )
def test_create_snowsight_base_url_us_west():
    """us-west-2 (the AWS default region) gets no cloud suffix in the URL."""
    builder = SnowsightUrlBuilder("account_locator", "aws_us_west_2", privatelink=False)
    assert (
        builder.snowsight_base_url
        == "https://app.snowflake.com/us-west-2/account_locator/"
    )


def test_create_snowsight_base_url_ap_northeast_1():
    """Non-default AWS regions carry an explicit '.aws' suffix."""
    builder = SnowsightUrlBuilder(
        "account_locator", "aws_ap_northeast_1", privatelink=False
    )
    assert (
        builder.snowsight_base_url
        == "https://app.snowflake.com/ap-northeast-1.aws/account_locator/"
    )


def test_create_snowsight_base_url_privatelink_aws():
    """Privatelink on the AWS default region keeps the plain region path."""
    builder = SnowsightUrlBuilder("test_acct", "aws_us_east_1", privatelink=True)
    assert (
        builder.snowsight_base_url == "https://app.snowflake.com/us-east-1/test_acct/"
    )


def test_create_snowsight_base_url_privatelink_gcp():
    """Privatelink on GCP carries the '.gcp' suffix."""
    builder = SnowsightUrlBuilder("test_account", "gcp_us_central1", privatelink=True)
    assert (
        builder.snowsight_base_url
        == "https://app.snowflake.com/us-central1.gcp/test_account/"
    )


def test_create_snowsight_base_url_privatelink_azure():
    """Privatelink on Azure expands the region and carries the '.azure' suffix."""
    builder = SnowsightUrlBuilder("test_account", "azure_eastus2", privatelink=True)
    assert (
        builder.snowsight_base_url
        == "https://app.snowflake.com/east-us-2.azure/test_account/"
    )


def test_snowsight_privatelink_external_urls():
    """Database/schema/table external URLs all share the privatelink base."""
    url_builder = SnowsightUrlBuilder(
        account_locator="test_acct",
        region="aws_us_east_1",
        privatelink=True,
    )
    base = "https://app.snowflake.com/us-east-1/test_acct/#/data/databases/TEST_DB/"
    assert url_builder.get_external_url_for_database("TEST_DB") == base
    assert (
        url_builder.get_external_url_for_schema("TEST_SCHEMA", "TEST_DB")
        == f"{base}schemas/TEST_SCHEMA/"
    )
    assert (
        url_builder.get_external_url_for_table(
            "TEST_TABLE",
            "TEST_SCHEMA",
            "TEST_DB",
            domain=SnowflakeObjectDomain.TABLE,
        )
        == f"{base}schemas/TEST_SCHEMA/table/TEST_TABLE/"
    )
def test_snowflake_utils() -> None:
    """The doctests embedded in snowflake_utils must pass."""
    assert_doctest(datahub.ingestion.source.snowflake.snowflake_utils)


def test_using_removed_fields_causes_no_error() -> None:
    """Removed config fields are tolerated for backwards compatibility."""
    assert SnowflakeV2Config.parse_obj(
        {
            "account_id": "test",
            "username": "snowflake",
            "password": "snowflake",
            # Both fields were removed from the config; they must be ignored.
            "include_view_lineage": "true",
            "include_view_column_lineage": "true",
        }
    )
def test_snowflake_query_result_parsing():
    """A raw access-history row validates into an UpstreamLineageEdge."""
    row = {
        "DOWNSTREAM_TABLE_NAME": "db.schema.downstream_table",
        "DOWNSTREAM_TABLE_DOMAIN": "Table",
        "UPSTREAM_TABLES": [
            {
                "query_id": "01b92f61-0611-c826-000d-0103cf9b5db7",
                "upstream_object_domain": "Table",
                "upstream_object_name": "db.schema.upstream_table",
            }
        ],
        "UPSTREAM_COLUMNS": [{}],
        "QUERIES": [
            {
                "query_id": "01b92f61-0611-c826-000d-0103cf9b5db7",
                "query_text": "Query test",
                "start_time": "2022-12-01 19:56:34",
            }
        ],
    }
    assert UpstreamLineageEdge.parse_obj(row)
class TestDDLProcessing:
    """Tests for SnowflakeQueriesExtractor.parse_ddl_query DDL handling."""
    @pytest.fixture
    def session_id(self):
        # Arbitrary Snowflake session identifier shared by these tests.
        return "14774700483022321"
    @pytest.fixture
    def timestamp(self):
        # Fixed UTC timestamp so parsed DDL objects compare deterministically.
        return datetime.datetime(
            year=2025, month=2, day=3, hour=15, minute=1, second=43
        ).astimezone(datetime.timezone.utc)
    @pytest.fixture
    def extractor(self) -> SnowflakeQueriesExtractor:
        """Build an extractor wired to mocked connection/report/filters."""
        connection = MagicMock()
        config = SnowflakeQueriesExtractorConfig()
        structured_report = MagicMock()
        filters = MagicMock()
        # Start the dropped-DDL counter at zero so tests can assert increments.
        structured_report.num_ddl_queries_dropped = 0
        identifier_config = SnowflakeIdentifierConfig()
        identifiers = SnowflakeIdentifierBuilder(identifier_config, structured_report)
        return SnowflakeQueriesExtractor(
            connection, config, structured_report, filters, identifiers
        )
    def test_ddl_processing_alter_table_rename(self, extractor, session_id, timestamp):
        """ALTER ... RENAME parses into a TableRename with old and new URNs."""
        query = "ALTER TABLE person_info_loading RENAME TO person_info_final;"
        # Shape mirrors Snowflake's OBJECT_MODIFIED_BY_DDL access_history column.
        object_modified_by_ddl = {
            "objectDomain": "Table",
            "objectId": 1789034,
            "objectName": "DUMMY_DB.PUBLIC.PERSON_INFO_LOADING",
            "operationType": "ALTER",
            "properties": {
                "objectName": {"value": "DUMMY_DB.PUBLIC.PERSON_INFO_FINAL"}
            },
        }
        query_type = "RENAME_TABLE"
        ddl = extractor.parse_ddl_query(
            query, session_id, timestamp, object_modified_by_ddl, query_type
        )
        assert ddl == TableRename(
            original_urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_db.public.person_info_loading,PROD)",
            new_urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_db.public.person_info_final,PROD)",
            query=query,
            session_id=session_id,
            timestamp=timestamp,
        ), "Processing ALTER ... RENAME should result in a proper TableRename object"
    def test_ddl_processing_alter_table_add_column(
        self, extractor, session_id, timestamp
    ):
        """Column-level ALTERs are not lineage events: dropped and counted."""
        query = "ALTER TABLE person_info ADD year BIGINT"
        object_modified_by_ddl = {
            "objectDomain": "Table",
            "objectId": 2612260,
            "objectName": "DUMMY_DB.PUBLIC.PERSON_INFO",
            "operationType": "ALTER",
            "properties": {
                "columns": {
                    "BIGINT": {
                        "objectId": {"value": 8763407},
                        "subOperationType": "ADD",
                    }
                }
            },
        }
        query_type = "ALTER_TABLE_ADD_COLUMN"
        ddl = extractor.parse_ddl_query(
            query, session_id, timestamp, object_modified_by_ddl, query_type
        )
        assert ddl is None, (
            "For altering columns statement ddl parsing should return None"
        )
        assert extractor.report.num_ddl_queries_dropped == 1, (
            "Dropped ddls should be properly counted"
        )
    def test_ddl_processing_alter_table_swap(self, extractor, session_id, timestamp):
        """ALTER ... SWAP WITH parses into a TableSwap of the two URNs."""
        query = "ALTER TABLE person_info SWAP WITH person_info_swap;"
        object_modified_by_ddl = {
            "objectDomain": "Table",
            "objectId": 3776835,
            "objectName": "DUMMY_DB.PUBLIC.PERSON_INFO",
            "operationType": "ALTER",
            "properties": {
                "swapTargetDomain": {"value": "Table"},
                "swapTargetId": {"value": 3786260},
                "swapTargetName": {"value": "DUMMY_DB.PUBLIC.PERSON_INFO_SWAP"},
            },
        }
        # Note: swaps arrive with the generic "ALTER" query type.
        query_type = "ALTER"
        ddl = extractor.parse_ddl_query(
            query, session_id, timestamp, object_modified_by_ddl, query_type
        )
        assert ddl == TableSwap(
            urn1="urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_db.public.person_info,PROD)",
            urn2="urn:li:dataset:(urn:li:dataPlatform:snowflake,dummy_db.public.person_info_swap,PROD)",
            query=query,
            session_id=session_id,
            timestamp=timestamp,
        ), "Processing ALTER ... SWAP DDL should result in a proper TableSwap object"
def test_snowsight_url_for_dynamic_table():
    """The URL path segment varies by object domain: table, view, dynamic-table."""
    url_builder = SnowsightUrlBuilder(
        account_locator="abc123",
        region="aws_us_west_2",
    )
    base = (
        "https://app.snowflake.com/us-west-2/abc123/"
        "#/data/databases/test_db/schemas/test_schema/"
    )
    cases = [
        (SnowflakeObjectDomain.TABLE, "test_table", "table"),
        (SnowflakeObjectDomain.VIEW, "test_view", "view"),
        # Dynamic tables must use the "dynamic-table" path segment.
        (SnowflakeObjectDomain.DYNAMIC_TABLE, "test_dynamic_table", "dynamic-table"),
    ]
    for domain, object_name, segment in cases:
        url = url_builder.get_external_url_for_table(
            table_name=object_name,
            schema_name="test_schema",
            db_name="test_db",
            domain=domain,
        )
        assert url == f"{base}{segment}/{object_name}/"
def test_is_dataset_pattern_allowed_for_dynamic_tables():
    """Dynamic tables are filtered with table_pattern, like plain tables."""
    mock_report = MagicMock()
    filter_config = MagicMock()
    filter_config.database_pattern.allowed.return_value = True
    filter_config.schema_pattern = MagicMock()
    filter_config.match_fully_qualified_names = False
    filter_config.table_pattern.allowed.return_value = True
    filter_config.view_pattern.allowed.return_value = True
    filter_config.stream_pattern.allowed.return_value = True
    snowflake_filter = (
        datahub.ingestion.source.snowflake.snowflake_utils.SnowflakeFilter(
            filter_config=filter_config, structured_reporter=mock_report
        )
    )
    # Both plain and dynamic tables pass while table_pattern allows them.
    assert snowflake_filter.is_dataset_pattern_allowed(
        dataset_name="DB.SCHEMA.TABLE", dataset_type="table"
    )
    assert snowflake_filter.is_dataset_pattern_allowed(
        dataset_name="DB.SCHEMA.DYNAMIC_TABLE", dataset_type="dynamic table"
    )
    # Flipping the table pattern must reject dynamic tables too, proving
    # they share table_pattern rather than a dedicated pattern.
    filter_config.table_pattern.allowed.return_value = False
    assert not snowflake_filter.is_dataset_pattern_allowed(
        dataset_name="DB.SCHEMA.DYNAMIC_TABLE", dataset_type="dynamic table"
    )
@patch(
    "datahub.ingestion.source.snowflake.snowflake_lineage_v2.SnowflakeLineageExtractor"
)
def test_process_upstream_lineage_row_dynamic_table_moved(mock_extractor_class):
    """Exercise the 'dynamic table moved to a new database/schema' lineage path.

    NOTE(review): because the extractor class itself is patched, the local
    import below re-binds SnowflakeLineageExtractor to the MagicMock, so the
    final call dispatches to the `side_effect` stub defined in this test
    rather than the production implementation — the assertions verify the
    stub, not the real method. Consider patching only the connection instead;
    confirm intent before relying on this test.
    """
    # Setup to handle the dynamic table moved case
    db_row = {
        "DOWNSTREAM_TABLE_NAME": "OLD_DB.OLD_SCHEMA.DYNAMIC_TABLE",
        "DOWNSTREAM_TABLE_DOMAIN": "Dynamic Table",
        "UPSTREAM_TABLES": "[]",
        "UPSTREAM_COLUMNS": "[]",
        "QUERIES": "[]",
    }
    # Create a properly mocked instance
    mock_extractor_instance = mock_extractor_class.return_value
    mock_connection = MagicMock()
    mock_extractor_instance.connection = mock_connection
    mock_extractor_instance.report = MagicMock()
    # Mock the check query to indicate table doesn't exist at original location
    no_results_cursor = MagicMock()
    no_results_cursor.__iter__.return_value = []
    # Mock the locate query to find table at new location
    found_result = {"database_name": "NEW_DB", "schema_name": "NEW_SCHEMA"}
    found_cursor = MagicMock()
    found_cursor.__iter__.return_value = [found_result]
    # Set up the mock to return our cursors
    # NOTE(review): the stubbed side_effect below never touches the connection,
    # so these two cursors are never consumed — presumably dead setup; verify.
    mock_connection.query.side_effect = [no_results_cursor, found_cursor]
    # Import the necessary classes
    from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import (
        SnowflakeLineageExtractor,
        UpstreamLineageEdge,
    )
    # Override the _process_upstream_lineage_row method to actually call the real implementation
    original_method = SnowflakeLineageExtractor._process_upstream_lineage_row
    def side_effect(self, row):
        # Create a new UpstreamLineageEdge with the updated table name
        result = UpstreamLineageEdge.parse_obj(row)
        result.DOWNSTREAM_TABLE_NAME = "NEW_DB.NEW_SCHEMA.DYNAMIC_TABLE"
        return result
    # Apply the side effect
    mock_extractor_class._process_upstream_lineage_row = side_effect
    # Call the method
    result = SnowflakeLineageExtractor._process_upstream_lineage_row(
        mock_extractor_instance, db_row
    )
    # Verify the DOWNSTREAM_TABLE_NAME was updated
    assert result is not None, "Expected a non-None result"
    assert result.DOWNSTREAM_TABLE_NAME == "NEW_DB.NEW_SCHEMA.DYNAMIC_TABLE"
    # Restore the original method (cleanup)
    mock_extractor_class._process_upstream_lineage_row = original_method