From d3944ded93e1696e24d72a6ebf3c5c190f70352c Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 16 May 2025 21:27:13 -0700 Subject: [PATCH] feat(ingest/snowflake): generate lineage through temp views (#13517) --- .../source/snowflake/snowflake_queries.py | 65 +++++++++++++------ .../source/snowflake/snowflake_query.py | 7 -- .../sql_parsing/sql_parsing_aggregator.py | 2 +- .../snowflake/test_snowflake_queries.py | 30 ++++++++- 4 files changed, 74 insertions(+), 30 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index ceb3d2a80f..0d6c070281 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -127,6 +127,8 @@ class SnowflakeQueriesExtractorReport(Report): sql_aggregator: Optional[SqlAggregatorReport] = None num_ddl_queries_dropped: int = 0 + num_stream_queries_observed: int = 0 + num_create_temp_view_queries_observed: int = 0 num_users: int = 0 @@ -373,6 +375,13 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): if entry: yield entry + @classmethod + def _has_temp_keyword(cls, query_text: str) -> bool: + return ( + re.search(r"\bTEMP\b", query_text, re.IGNORECASE) is not None + or re.search(r"\bTEMPORARY\b", query_text, re.IGNORECASE) is not None + ) + def _parse_audit_log_row( self, row: Dict[str, Any], users: UsersMapping ) -> Optional[Union[TableRename, TableSwap, PreparsedQuery, ObservedQuery]]: @@ -389,6 +398,15 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): key = key.lower() res[key] = value + timestamp: datetime = res["query_start_time"] + timestamp = timestamp.astimezone(timezone.utc) + + # TODO need to map snowflake query types to ours + query_text: str = res["query_text"] + query_type: QueryType = SNOWFLAKE_QUERY_TYPE_MAPPING.get( + res["query_type"], QueryType.UNKNOWN + ) + direct_objects_accessed = res["direct_objects_accessed"] objects_modified = res["objects_modified"] object_modified_by_ddl = res["object_modified_by_ddl"] @@ -399,9 +417,9 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): "Error fetching ddl lineage from Snowflake" ): known_ddl_entry = self.parse_ddl_query( - res["query_text"], + query_text, res["session_id"], - res["query_start_time"], + timestamp, object_modified_by_ddl, res["query_type"], ) @@ -419,24 +437,38 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): ) ) - # Use direct_objects_accessed instead objects_modified - # objects_modified returns $SYS_VIEW_X with no mapping + # There are a couple cases when we'd want to prefer our own SQL parsing + # over Snowflake's metadata. + # 1. For queries that use a stream, objects_modified returns $SYS_VIEW_X with no mapping. + # We can check direct_objects_accessed to see if there is a stream used, and if so, + # prefer doing SQL parsing over Snowflake's metadata. + # 2. For queries that create a view, objects_modified is empty and object_modified_by_ddl + # contains the view name and columns. Because `object_modified_by_ddl` doesn't contain + # source columns e.g. lineage information, we must do our own SQL parsing. We're mainly + # focused on temporary views. It's fine if we parse a couple extra views, but in general + # we want view definitions to come from Snowflake's schema metadata and not from query logs. + has_stream_objects = any( obj.get("objectDomain") == "Stream" for obj in direct_objects_accessed ) + is_create_view = query_type == QueryType.CREATE_VIEW + is_create_temp_view = is_create_view and self._has_temp_keyword(query_text) + + if has_stream_objects or is_create_temp_view: + if has_stream_objects: + self.report.num_stream_queries_observed += 1 + elif is_create_temp_view: + self.report.num_create_temp_view_queries_observed += 1 - # If a stream is used, default to query parsing. - if has_stream_objects: - logger.debug("Found matching stream object") return ObservedQuery( - query=res["query_text"], + query=query_text, session_id=res["session_id"], - timestamp=res["query_start_time"].astimezone(timezone.utc), + timestamp=timestamp, user=user, default_db=res["default_db"], default_schema=res["default_schema"], query_hash=get_query_fingerprint( - res["query_text"], self.identifiers.platform, fast=True + query_text, self.identifiers.platform, fast=True ), ) @@ -502,25 +534,17 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): ) ) - timestamp: datetime = res["query_start_time"] - timestamp = timestamp.astimezone(timezone.utc) - - # TODO need to map snowflake query types to ours - query_type = SNOWFLAKE_QUERY_TYPE_MAPPING.get( - res["query_type"], QueryType.UNKNOWN - ) - entry = PreparsedQuery( # Despite having Snowflake's fingerprints available, our own fingerprinting logic does a better # job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint # here query_id=get_query_fingerprint( - res["query_text"], + query_text, self.identifiers.platform, fast=True, secondary_id=res["query_secondary_fingerprint"], ), - query_text=res["query_text"], + query_text=query_text, upstreams=upstreams, downstream=downstream, column_lineage=column_lineage, @@ -543,7 +567,6 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): object_modified_by_ddl: dict, query_type: str, ) -> Optional[Union[TableRename, TableSwap]]: - timestamp = timestamp.astimezone(timezone.utc) if ( object_modified_by_ddl["operationType"] == "ALTER" and query_type == "RENAME_TABLE" diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py index e7338a02bd..c023ed9c75 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py @@ -43,13 +43,6 @@ class SnowflakeQuery: ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER = "({})".format( ",".join(f"'{domain}'" for domain in ACCESS_HISTORY_TABLE_VIEW_DOMAINS) ) - ACCESS_HISTORY_TABLE_DOMAINS_FILTER = ( - "(" - f"'{SnowflakeObjectDomain.TABLE.capitalize()}'," - f"'{SnowflakeObjectDomain.VIEW.capitalize()}'," - f"'{SnowflakeObjectDomain.STREAM.capitalize()}'," - ")" - ) @staticmethod def current_account() -> str: diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index c9a203c495..80c38d9103 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -109,7 +109,7 @@ class ObservedQuery: query_hash: Optional[str] = None usage_multiplier: int = 1 - # Use this to store addtitional key-value information about query for debugging + # Use this to store additional key-value information about the query for debugging. extra_info: Optional[dict] = None diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py index ae0f23d932..f419a1a201 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py @@ -2,7 +2,10 @@ import os from unittest.mock import patch from datahub.ingestion.api.common import PipelineContext -from datahub.ingestion.source.snowflake.snowflake_queries import SnowflakeQueriesSource +from datahub.ingestion.source.snowflake.snowflake_queries import ( + SnowflakeQueriesExtractor, + SnowflakeQueriesSource, +) @patch("snowflake.connector.connect") @@ -77,3 +80,28 @@ def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path): == "username" ) assert source.identifiers.get_user_identifier("username", None) == "username" + + +def test_snowflake_has_temp_keyword(): + cases = [ + ("CREATE TEMP VIEW my_table__dbt_tmp ...", True), + ("CREATE TEMPORARY VIEW my_table__dbt_tmp ...", True), + ("CREATE VIEW my_table__dbt_tmp ...", False), + # Test case sensitivity + ("create TEMP view test", True), + ("CREATE temporary VIEW test", True), + ("create temp view test", True), + # Test with whitespace variations + ("CREATE\nTEMP\tVIEW test", True), + ("CREATE TEMPORARY VIEW test", True), + # Test with partial matches that should be false + ("SELECT * FROM my_template_table", False), + ("CREATE TEMPERATURE VIEW test", False), + ("SELECT * FROM TEMPDB.table", False), + ("CREATE VIEW temporary_table", False), + # Note that this method has some edge cases that don't quite work. + # But it's good enough for our purposes. + # ("SELECT 'TEMPORARY' FROM table", False), + ] + for query, expected in cases: + assert SnowflakeQueriesExtractor._has_temp_keyword(query) == expected