datahub/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py
Sergio Gómez Villamor b3fafc38be
feat(snowflake): ingest views from information schema (#14444)
Co-authored-by: Claude <noreply@anthropic.com>
2025-08-22 11:13:04 +02:00

392 lines
20 KiB
Python

import datetime
import pytest
import sqlglot
from sqlglot.dialects.snowflake import Snowflake
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.time_window_config import BucketDuration
from datahub.ingestion.source.snowflake.snowflake_config import QueryDedupStrategyType
from datahub.ingestion.source.snowflake.snowflake_queries import QueryLogQueryBuilder
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
class TestBuildAccessHistoryDatabaseFilterCondition:
    """Pin the exact SQL text returned by ``_build_access_history_database_filter_condition``.

    Each case supplies a database ``AllowDenyPattern`` (or ``None``), an
    optional list of additional database names, and the literal SQL
    condition the builder is expected to return for that combination.
    """

    @pytest.mark.parametrize(
        "database_pattern,additional_database_names,expected",
        [
            # --- Nothing to filter on: the condition is the literal TRUE ---
            pytest.param(None, None, "TRUE", id="no_pattern_no_additional_dbs"),
            pytest.param(
                AllowDenyPattern(), None, "TRUE", id="empty_pattern_no_additional_dbs"
            ),
            pytest.param(None, [], "TRUE", id="no_pattern_empty_additional_dbs"),
            # --- Allow / deny patterns without additional database names ---
            pytest.param(
                AllowDenyPattern(allow=["PROD_.*"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*')) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'PROD_.*'))",
                id="allow_pattern_only",
            ),
            pytest.param(
                AllowDenyPattern(deny=[".*_TEMP"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP')) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*_TEMP'))",
                id="deny_pattern_only",
            ),
            pytest.param(
                AllowDenyPattern(allow=["PROD_.*"], deny=[".*_TEMP"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*' AND SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP'))) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*' AND SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP'))) > 0 OR ((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'PROD_.*' AND SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*_TEMP')))",
                id="allow_and_deny_patterns",
            ),
            pytest.param(
                AllowDenyPattern(allow=["PROD_.*", "DEV_.*"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*' OR SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'DEV_.*'))) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*' OR SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'DEV_.*'))) > 0 OR ((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'PROD_.*' OR SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'DEV_.*')))",
                id="multiple_allow_patterns",
            ),
            pytest.param(
                AllowDenyPattern(deny=[".*_TEMP", ".*_STAGING"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> ((SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP' AND SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_STAGING')))) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> ((SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP' AND SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_STAGING')))) > 0 OR (((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*_TEMP' AND SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*_STAGING'))))",
                id="multiple_deny_patterns",
            ),
            # --- Additional database names only (exact-match comparisons) ---
            pytest.param(
                None,
                ["DB1", "DB2"],
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB1' OR SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB2'))) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB1' OR SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB2'))) > 0 OR ((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'DB1' OR SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'DB2')))",
                id="multiple_additional_database_names",
            ),
            pytest.param(
                None,
                ["SPECIAL_DB"],
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) = 'SPECIAL_DB')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) = 'SPECIAL_DB')) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'SPECIAL_DB'))",
                id="single_additional_database_name",
            ),
            # --- Patterns combined with additional database names ---
            pytest.param(
                AllowDenyPattern(allow=["PROD_.*"]),
                ["DB1", "DB2"],
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB1' OR SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB2') OR SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB1' OR SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB2') OR SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*')) > 0 OR ((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'DB1' OR SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'DB2') OR SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'PROD_.*'))",
                id="pattern_with_additional_database_names",
            ),
            pytest.param(
                AllowDenyPattern(allow=["PROD_.*", "DEV_.*"], deny=[".*_TEMP"]),
                ["SPECIAL_DB"],
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) = 'SPECIAL_DB' OR ((SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*' OR SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'DEV_.*') AND SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP'))) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) = 'SPECIAL_DB' OR ((SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD_.*' OR SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'DEV_.*') AND SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP'))) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'SPECIAL_DB' OR ((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'PROD_.*' OR SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'DEV_.*') AND SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*_TEMP')))",
                id="complex_pattern_with_additional_database_names",
            ),
            # --- The catch-all ".*" allow entry contributes no condition ---
            pytest.param(
                AllowDenyPattern(allow=[".*"]),
                None,
                "TRUE",
                id="default_allow_pattern_ignored",
            ),
            pytest.param(
                AllowDenyPattern(allow=[".*"], deny=[".*_TEMP"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*_TEMP')) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*_TEMP'))",
                id="default_allow_pattern_with_deny_pattern",
            ),
            # --- Embedded single quotes come back doubled in the SQL literal ---
            pytest.param(
                AllowDenyPattern(allow=["PROD'_.*"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD''_.*')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) RLIKE 'PROD''_.*')) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) RLIKE 'PROD''_.*'))",
                id="sql_injection_protection_allow_pattern",
            ),
            pytest.param(
                AllowDenyPattern(deny=[".*'_TEMP"]),
                None,
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*''_TEMP')) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> SPLIT_PART(UPPER(o:objectName), '.', 1) NOT RLIKE '.*''_TEMP')) > 0 OR (SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) NOT RLIKE '.*''_TEMP'))",
                id="sql_injection_protection_deny_pattern",
            ),
            pytest.param(
                None,
                ["DB'1", "DB'2"],
                "(ARRAY_SIZE(FILTER(direct_objects_accessed, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB''1' OR SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB''2'))) > 0 OR ARRAY_SIZE(FILTER(objects_modified, o -> (SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB''1' OR SPLIT_PART(UPPER(o:objectName), '.', 1) = 'DB''2'))) > 0 OR ((SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'DB''1' OR SPLIT_PART(UPPER(object_modified_by_ddl:objectName), '.', 1) = 'DB''2')))",
                id="sql_injection_protection_additional_database_names",
            ),
            # --- Explicitly empty allow/deny lists also yield TRUE ---
            pytest.param(
                AllowDenyPattern(allow=[]), None, "TRUE", id="empty_allow_pattern_list"
            ),
            pytest.param(
                AllowDenyPattern(deny=[]), None, "TRUE", id="empty_deny_pattern_list"
            ),
            pytest.param(
                AllowDenyPattern(allow=[], deny=[]),
                None,
                "TRUE",
                id="both_empty_pattern_lists",
            ),
        ],
    )
    def test_build_access_history_database_filter_condition(
        self, database_pattern, additional_database_names, expected
    ):
        """The generated filter condition matches the expected SQL text exactly."""
        # A builder instance is only needed as the receiver of the method under
        # test; the same pattern and names are also passed to the call itself.
        query_builder = QueryLogQueryBuilder(
            start_time=datetime.datetime(2021, 1, 1),
            end_time=datetime.datetime(2021, 1, 2),
            bucket_duration=BucketDuration.HOUR,
            dedup_strategy=QueryDedupStrategyType.STANDARD,
            database_pattern=database_pattern,
            additional_database_names=additional_database_names,
        )

        actual_condition = (
            query_builder._build_access_history_database_filter_condition(
                database_pattern, additional_database_names
            )
        )

        assert actual_condition == expected
class TestQueryLogQueryBuilder:
    """Smoke tests for ``QueryLogQueryBuilder.build_enriched_query_log_query``.

    Apart from the unsupported-strategy case, each test only checks that the
    generated SQL can be parsed by sqlglot using the Snowflake dialect.
    """

    def test_non_implemented_strategy(self):
        """An unrecognised dedup strategy results in ``NotImplementedError``."""
        # Construction and query building both sit inside the ``raises`` block,
        # so the test passes whichever of the two steps raises.
        with pytest.raises(NotImplementedError):
            QueryLogQueryBuilder(
                start_time=datetime.datetime(year=2021, month=1, day=1),
                end_time=datetime.datetime(year=2021, month=1, day=1),
                bucket_duration=BucketDuration.HOUR,
                deny_usernames=None,
                dedup_strategy="DUMMY",  # type: ignore[arg-type]
            ).build_enriched_query_log_query()

    # Parametrized rather than looped inside one test: every strategy is
    # collected separately, a failure names the offending strategy, and one
    # failing strategy no longer prevents the remaining ones from running.
    @pytest.mark.parametrize("strategy", list(QueryDedupStrategyType))
    def test_fetch_query_for_all_strategies(self, strategy):
        """Every defined dedup strategy produces SQL that sqlglot can parse."""
        query = QueryLogQueryBuilder(
            start_time=datetime.datetime(year=2021, month=1, day=1),
            end_time=datetime.datetime(year=2021, month=1, day=1),
            bucket_duration=BucketDuration.HOUR,
            dedup_strategy=strategy,
        ).build_enriched_query_log_query()
        # SQL parsing should succeed
        sqlglot.parse(query, dialect=Snowflake)

    def test_query_with_database_pattern_filtering(self):
        """Test that database pattern filtering generates valid SQL."""
        database_pattern = AllowDenyPattern(allow=["PROD_.*"], deny=[".*_TEMP"])
        query = QueryLogQueryBuilder(
            start_time=datetime.datetime(year=2021, month=1, day=1),
            end_time=datetime.datetime(year=2021, month=1, day=2),
            bucket_duration=BucketDuration.HOUR,
            deny_usernames=None,
            dedup_strategy=QueryDedupStrategyType.STANDARD,
            database_pattern=database_pattern,
        ).build_enriched_query_log_query()
        # SQL parsing should succeed
        sqlglot.parse(query, dialect=Snowflake)

    def test_query_with_additional_database_names(self):
        """Test that additional database names generate valid SQL."""
        additional_database_names = ["SPECIAL_DB", "ANALYTICS_DB"]
        query = QueryLogQueryBuilder(
            start_time=datetime.datetime(year=2021, month=1, day=1),
            end_time=datetime.datetime(year=2021, month=1, day=2),
            bucket_duration=BucketDuration.HOUR,
            dedup_strategy=QueryDedupStrategyType.NONE,
            additional_database_names=additional_database_names,
        ).build_enriched_query_log_query()
        # SQL parsing should succeed
        sqlglot.parse(query, dialect=Snowflake)

    def test_query_with_combined_database_filtering(self):
        """Test that both database patterns and additional database names generate valid SQL."""
        database_pattern = AllowDenyPattern(allow=["PROD_.*"])
        additional_database_names = ["SPECIAL_DB"]
        query = QueryLogQueryBuilder(
            start_time=datetime.datetime(year=2021, month=1, day=1),
            end_time=datetime.datetime(year=2021, month=1, day=2),
            bucket_duration=BucketDuration.HOUR,
            deny_usernames=None,
            dedup_strategy=QueryDedupStrategyType.STANDARD,
            database_pattern=database_pattern,
            additional_database_names=additional_database_names,
        ).build_enriched_query_log_query()
        # SQL parsing should succeed
        sqlglot.parse(query, dialect=Snowflake)
class TestBuildUserFilter:
    """Pin the exact SQL text returned by ``_build_user_filter``.

    Each case supplies deny and allow username lists (either may be ``None``
    or empty) together with the literal SQL condition expected back.
    """

    @pytest.mark.parametrize(
        "deny_usernames,allow_usernames,expected",
        [
            # --- No usable filters: the condition is the literal TRUE ---
            pytest.param(None, None, "TRUE", id="no_filters"),
            pytest.param([], [], "TRUE", id="empty_lists"),
            pytest.param(None, [], "TRUE", id="none_deny_empty_allow"),
            pytest.param([], None, "TRUE", id="empty_deny_none_allow"),
            # --- Deny list only: NOT ILIKE terms joined with AND ---
            pytest.param(
                ["SERVICE_USER"],
                None,
                "(user_name NOT ILIKE 'SERVICE_USER')",
                id="single_deny_exact",
            ),
            pytest.param(
                ["SERVICE_%"],
                None,
                "(user_name NOT ILIKE 'SERVICE_%')",
                id="single_deny_pattern",
            ),
            pytest.param(
                ["SERVICE_%", "ADMIN_%"],
                None,
                "(user_name NOT ILIKE 'SERVICE_%' AND user_name NOT ILIKE 'ADMIN_%')",
                id="multiple_deny_patterns",
            ),
            # --- Allow list only: ILIKE terms joined with OR ---
            pytest.param(
                None,
                ["ANALYST_USER"],
                "(user_name ILIKE 'ANALYST_USER')",
                id="single_allow_exact",
            ),
            pytest.param(
                None,
                ["ANALYST_%"],
                "(user_name ILIKE 'ANALYST_%')",
                id="single_allow_pattern",
            ),
            pytest.param(
                None,
                ["ANALYST_%", "%_USER"],
                "(user_name ILIKE 'ANALYST_%' OR user_name ILIKE '%_USER')",
                id="multiple_allow_patterns",
            ),
            # --- Both lists: deny group AND allow group ---
            pytest.param(
                ["SERVICE_%"],
                ["ANALYST_%"],
                "(user_name NOT ILIKE 'SERVICE_%') AND (user_name ILIKE 'ANALYST_%')",
                id="single_deny_and_single_allow",
            ),
            pytest.param(
                ["SERVICE_%", "ADMIN_%"],
                ["ANALYST_%", "%_USER"],
                "(user_name NOT ILIKE 'SERVICE_%' AND user_name NOT ILIKE 'ADMIN_%') AND (user_name ILIKE 'ANALYST_%' OR user_name ILIKE '%_USER')",
                id="multiple_deny_and_multiple_allow",
            ),
            pytest.param(
                ["TEST_ANALYST_%"],
                ["TEST_%"],
                "(user_name NOT ILIKE 'TEST_ANALYST_%') AND (user_name ILIKE 'TEST_%')",
                id="overlapping_deny_and_allow_patterns",
            ),
            # --- Embedded single quotes come back doubled in the SQL literal ---
            pytest.param(
                ["'SPECIAL_USER'"],
                None,
                "(user_name NOT ILIKE '''SPECIAL_USER''')",
                id="sql_injection_protection_deny",
            ),
            pytest.param(
                None,
                ["'SPECIAL_USER'"],
                "(user_name ILIKE '''SPECIAL_USER''')",
                id="sql_injection_protection_allow",
            ),
            pytest.param(
                ["USER_O'CONNOR"],
                ["ANALYST_O'BRIEN"],
                "(user_name NOT ILIKE 'USER_O''CONNOR') AND (user_name ILIKE 'ANALYST_O''BRIEN')",
                id="sql_injection_protection_both",
            ),
        ],
    )
    def test_build_user_filter(self, deny_usernames, allow_usernames, expected):
        """The generated user filter matches the expected SQL text exactly."""
        # A builder instance is only needed as the receiver of the method under
        # test; the same username lists are also passed to the call itself.
        query_builder = QueryLogQueryBuilder(
            start_time=datetime.datetime(2021, 1, 1),
            end_time=datetime.datetime(2021, 1, 2),
            bucket_duration=BucketDuration.HOUR,
            deny_usernames=deny_usernames,
            allow_usernames=allow_usernames,
            dedup_strategy=QueryDedupStrategyType.STANDARD,
        )

        actual_filter = query_builder._build_user_filter(
            deny_usernames, allow_usernames
        )

        assert actual_filter == expected
class TestSnowflakeViewQueries:
    """Structural checks on the view-listing SQL produced by ``SnowflakeQuery``."""

    @staticmethod
    def _parse_single_select_from(query, expected_table):
        """Parse ``query`` and check it is one SELECT reading from ``expected_table``.

        Returns the parsed statement so callers can make further assertions.
        The leading underscore keeps pytest from collecting this as a test.
        """
        # Should be parseable by sqlglot, and contain exactly one statement
        statements = sqlglot.parse(query, dialect=Snowflake)
        assert len(statements) == 1

        # Validate SQL structure
        select_stmt = statements[0]
        assert select_stmt is not None
        assert select_stmt.find(sqlglot.exp.Select) is not None

        # The FROM target is compared with double quotes stripped
        from_node = select_stmt.find(sqlglot.exp.From)
        assert from_node is not None
        unquoted_table = str(from_node.this).replace('"', "")
        assert unquoted_table == expected_table

        return select_stmt

    def test_get_views_for_database_query_syntax(self):
        """The database-level query selects from that database's information_schema.views."""
        self._parse_single_select_from(
            SnowflakeQuery.get_views_for_database("TEST_DB"),
            "TEST_DB.information_schema.views",
        )

    def test_get_views_for_schema_query_syntax(self):
        """The schema-level query reads the same table and filters on the schema."""
        select_stmt = self._parse_single_select_from(
            SnowflakeQuery.get_views_for_schema("TEST_DB", "PUBLIC"),
            "TEST_DB.information_schema.views",
        )

        # Check that it has a WHERE clause filtering by schema
        where_node = select_stmt.find(sqlglot.exp.Where)
        assert where_node is not None
        where_sql = str(where_node).upper()
        assert "TABLE_SCHEMA" in where_sql
        assert "PUBLIC" in where_sql