feat(ingest/snowflake): generate lineage through temp views (#13517)

This commit is contained in:
Harshal Sheth 2025-05-16 21:27:13 -07:00 committed by GitHub
parent 0f227a364a
commit d3944ded93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 30 deletions

View File

@ -127,6 +127,8 @@ class SnowflakeQueriesExtractorReport(Report):
sql_aggregator: Optional[SqlAggregatorReport] = None
num_ddl_queries_dropped: int = 0
num_stream_queries_observed: int = 0
num_create_temp_view_queries_observed: int = 0
num_users: int = 0
@ -373,6 +375,13 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
if entry:
yield entry
@classmethod
def _has_temp_keyword(cls, query_text: str) -> bool:
return (
re.search(r"\bTEMP\b", query_text, re.IGNORECASE) is not None
or re.search(r"\bTEMPORARY\b", query_text, re.IGNORECASE) is not None
)
def _parse_audit_log_row(
self, row: Dict[str, Any], users: UsersMapping
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery, ObservedQuery]]:
@ -389,6 +398,15 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
key = key.lower()
res[key] = value
timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)
# TODO need to map snowflake query types to ours
query_text: str = res["query_text"]
query_type: QueryType = SNOWFLAKE_QUERY_TYPE_MAPPING.get(
res["query_type"], QueryType.UNKNOWN
)
direct_objects_accessed = res["direct_objects_accessed"]
objects_modified = res["objects_modified"]
object_modified_by_ddl = res["object_modified_by_ddl"]
@ -399,9 +417,9 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
"Error fetching ddl lineage from Snowflake"
):
known_ddl_entry = self.parse_ddl_query(
res["query_text"],
query_text,
res["session_id"],
res["query_start_time"],
timestamp,
object_modified_by_ddl,
res["query_type"],
)
@ -419,24 +437,38 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
)
)
# Use direct_objects_accessed instead objects_modified
# objects_modified returns $SYS_VIEW_X with no mapping
# There are a couple cases when we'd want to prefer our own SQL parsing
# over Snowflake's metadata.
# 1. For queries that use a stream, objects_modified returns $SYS_VIEW_X with no mapping.
# We can check direct_objects_accessed to see if there is a stream used, and if so,
# prefer doing SQL parsing over Snowflake's metadata.
# 2. For queries that create a view, objects_modified is empty and object_modified_by_ddl
# contains the view name and columns. Because `object_modified_by_ddl` doesn't contain
# source columns e.g. lineage information, we must do our own SQL parsing. We're mainly
# focused on temporary views. It's fine if we parse a couple extra views, but in general
# we want view definitions to come from Snowflake's schema metadata and not from query logs.
has_stream_objects = any(
obj.get("objectDomain") == "Stream" for obj in direct_objects_accessed
)
is_create_view = query_type == QueryType.CREATE_VIEW
is_create_temp_view = is_create_view and self._has_temp_keyword(query_text)
if has_stream_objects or is_create_temp_view:
if has_stream_objects:
self.report.num_stream_queries_observed += 1
elif is_create_temp_view:
self.report.num_create_temp_view_queries_observed += 1
# If a stream is used, default to query parsing.
if has_stream_objects:
logger.debug("Found matching stream object")
return ObservedQuery(
query=res["query_text"],
query=query_text,
session_id=res["session_id"],
timestamp=res["query_start_time"].astimezone(timezone.utc),
timestamp=timestamp,
user=user,
default_db=res["default_db"],
default_schema=res["default_schema"],
query_hash=get_query_fingerprint(
res["query_text"], self.identifiers.platform, fast=True
query_text, self.identifiers.platform, fast=True
),
)
@ -502,25 +534,17 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
)
)
timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)
# TODO need to map snowflake query types to ours
query_type = SNOWFLAKE_QUERY_TYPE_MAPPING.get(
res["query_type"], QueryType.UNKNOWN
)
entry = PreparsedQuery(
# Despite having Snowflake's fingerprints available, our own fingerprinting logic does a better
# job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint
# here
query_id=get_query_fingerprint(
res["query_text"],
query_text,
self.identifiers.platform,
fast=True,
secondary_id=res["query_secondary_fingerprint"],
),
query_text=res["query_text"],
query_text=query_text,
upstreams=upstreams,
downstream=downstream,
column_lineage=column_lineage,
@ -543,7 +567,6 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
object_modified_by_ddl: dict,
query_type: str,
) -> Optional[Union[TableRename, TableSwap]]:
timestamp = timestamp.astimezone(timezone.utc)
if (
object_modified_by_ddl["operationType"] == "ALTER"
and query_type == "RENAME_TABLE"

View File

@ -43,13 +43,6 @@ class SnowflakeQuery:
ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER = "({})".format(
",".join(f"'{domain}'" for domain in ACCESS_HISTORY_TABLE_VIEW_DOMAINS)
)
ACCESS_HISTORY_TABLE_DOMAINS_FILTER = (
"("
f"'{SnowflakeObjectDomain.TABLE.capitalize()}',"
f"'{SnowflakeObjectDomain.VIEW.capitalize()}',"
f"'{SnowflakeObjectDomain.STREAM.capitalize()}',"
")"
)
@staticmethod
def current_account() -> str:

View File

@ -109,7 +109,7 @@ class ObservedQuery:
query_hash: Optional[str] = None
usage_multiplier: int = 1
# Use this to store addtitional key-value information about query for debugging
# Use this to store additional key-value information about the query for debugging.
extra_info: Optional[dict] = None

View File

@ -2,7 +2,10 @@ import os
from unittest.mock import patch
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.snowflake.snowflake_queries import SnowflakeQueriesSource
from datahub.ingestion.source.snowflake.snowflake_queries import (
SnowflakeQueriesExtractor,
SnowflakeQueriesSource,
)
@patch("snowflake.connector.connect")
@ -77,3 +80,28 @@ def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path):
== "username"
)
assert source.identifiers.get_user_identifier("username", None) == "username"
def test_snowflake_has_temp_keyword():
    """Verify _has_temp_keyword detects TEMP/TEMPORARY as whole words only."""
    cases = [
        ("CREATE TEMP VIEW my_table__dbt_tmp ...", True),
        ("CREATE TEMPORARY VIEW my_table__dbt_tmp ...", True),
        ("CREATE VIEW my_table__dbt_tmp ...", False),
        # Test case sensitivity
        ("create TEMP view test", True),
        ("CREATE temporary VIEW test", True),
        ("create temp view test", True),
        # Test with whitespace variations
        ("CREATE\nTEMP\tVIEW test", True),
        ("CREATE TEMPORARY VIEW test", True),
        # Test with partial matches that should be false
        ("SELECT * FROM my_template_table", False),
        ("CREATE TEMPERATURE VIEW test", False),
        ("SELECT * FROM TEMPDB.table", False),
        ("CREATE VIEW temporary_table", False),
        # Note that this method has some edge cases that don't quite work.
        # But it's good enough for our purposes.
        # ("SELECT 'TEMPORARY' FROM table", False),
    ]
    for query, expected in cases:
        # Include the query in the assertion message so a failure
        # identifies which case broke, not just that one did.
        assert SnowflakeQueriesExtractor._has_temp_keyword(query) == expected, (
            f"unexpected _has_temp_keyword result for query: {query!r}"
        )