diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py b/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py
index fddcb47685..b33f6d378a 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py
@@ -1,6 +1,7 @@
 import json
 import logging
 from collections import defaultdict
+from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import pydantic
@@ -24,12 +25,14 @@ import datahub.emitter.mce_builder as builder
 from datahub.configuration.common import AllowDenyPattern
 from datahub.configuration.time_window_config import BaseTimeWindowConfig
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.aws.s3_util import make_s3_urn
 from datahub.ingestion.source.sql.sql_common import (
     RecordTypeClass,
     SQLAlchemyConfig,
     SQLAlchemySource,
+    SQLSourceReport,
     SqlWorkUnit,
     TimeTypeClass,
     make_sqlalchemy_uri,
@@ -56,6 +59,15 @@ APPLICATION_NAME = "acryl_datahub"
 snowdialect.ischema_names["GEOGRAPHY"] = sqltypes.NullType
 
 
+@dataclass
+class SnowflakeReport(SQLSourceReport):
+    num_table_to_table_edges_scanned: int = 0
+    num_table_to_view_edges_scanned: int = 0
+    num_view_to_table_edges_scanned: int = 0
+    num_external_table_edges_scanned: int = 0
+    upstream_lineage: Dict[str, List[str]] = field(default_factory=dict)
+
+
 class BaseSnowflakeConfig(BaseTimeWindowConfig):
     # Note: this config model is also used by the snowflake-usage source.
 
@@ -179,13 +191,12 @@ class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
 
 
 class SnowflakeSource(SQLAlchemySource):
-    config: SnowflakeConfig
-    current_database: str
-
-    def __init__(self, config, ctx):
+    def __init__(self, config: SnowflakeConfig, ctx: PipelineContext):
         super().__init__(config, ctx, "snowflake")
         self._lineage_map: Optional[Dict[str, List[Tuple[str, str, str]]]] = None
         self._external_lineage_map: Optional[Dict[str, Set[str]]] = None
+        self.report: SnowflakeReport = SnowflakeReport()
+        self.config: SnowflakeConfig = config
 
     @classmethod
     def create(cls, config_dict, ctx):
@@ -275,6 +286,7 @@ WHERE
         )
 
         assert self._lineage_map is not None
+        num_edges: int = 0
 
         try:
             for db_row in engine.execute(view_upstream_lineage_query):
@@ -287,14 +299,17 @@ WHERE
                     # (<upstream_table_name>, <json_list_of_upstream_columns>, <json_list_of_downstream_columns>)
                     (upstream_table, db_row[4], db_row[1])
                 )
+                num_edges += 1
                 logger.debug(
-                    f"Table->View: Lineage[{view_name}]:{self._lineage_map[view_name]}, upstream_domain={db_row[3]}"
+                    f"Table->View: Lineage[View(Down)={view_name}]:Table(Up)={self._lineage_map[view_name]}, upstream_domain={db_row[3]}"
                 )
         except Exception as e:
             logger.warning(
                 f"Extracting the upstream view lineage from Snowflake failed."
                 f"Please check your permissions. Continuing...\nError was {e}."
             )
+        logger.info(f"A total of {num_edges} Table->View edges found.")
+        self.report.num_table_to_view_edges_scanned = num_edges
 
     def _populate_view_downstream_lineage(
         self, engine: sqlalchemy.engine.Engine
@@ -363,6 +378,8 @@ WHERE
         )
 
         assert self._lineage_map is not None
+        num_edges: int = 0
+        num_false_edges: int = 0
 
         try:
             for db_row in engine.execute(view_lineage_query):
@@ -374,11 +391,21 @@ WHERE
                 upstream_table: str = db_row[2].lower().replace('"', "")
                 downstream_table: str = db_row[5].lower().replace('"', "")
                 # (1) Delete false direct edge between upstream_table and downstream_table
+                prior_edges: List[Tuple[str, str, str]] = self._lineage_map[
+                    downstream_table
+                ]
                 self._lineage_map[downstream_table] = [
                     entry
                     for entry in self._lineage_map[downstream_table]
                     if entry[0] != upstream_table
                 ]
+                for false_edge in set(prior_edges) - set(
+                    self._lineage_map[downstream_table]
+                ):
+                    logger.debug(
+                        f"False Table->Table edge removed: Lineage[Table(Down)={downstream_table}]:Table(Up)={false_edge}."
+                    )
+                    num_false_edges += 1
 
                 # (2) Add view->downstream table lineage.
                 self._lineage_map[downstream_table].append(
@@ -386,14 +413,19 @@ WHERE
                     (view_name, db_row[1], db_row[7])
                 )
                 logger.debug(
-                    f"View->Table: Lineage[{downstream_table}]:{self._lineage_map[downstream_table]}, downstream_domain={db_row[6]}"
+                    f"View->Table: Lineage[Table(Down)={downstream_table}]:View(Up)={self._lineage_map[downstream_table]}, downstream_domain={db_row[6]}"
                 )
+                num_edges += 1
         except Exception as e:
             logger.warning(
                 f"Extracting the view lineage from Snowflake failed."
                 f"Please check your permissions. Continuing...\nError was {e}."
             )
+        logger.info(
+            f"Found {num_edges} View->Table edges. Removed {num_false_edges} false Table->Table edges."
+        )
+        self.report.num_view_to_table_edges_scanned = num_edges
 
     def _populate_view_lineage(self) -> None:
         if not self.config.include_view_lineage:
@@ -434,6 +466,7 @@ WHERE
             end_time_millis=int(self.config.end_time.timestamp() * 1000),
         )
 
+        num_edges: int = 0
         self._external_lineage_map = defaultdict(set)
         try:
             for db_row in engine.execute(query):
@@ -441,7 +474,7 @@ WHERE
                 key: str = db_row[1].lower().replace('"', "")
                 self._external_lineage_map[key] |= {*json.loads(db_row[0])}
                 logger.debug(
-                    f"ExternalLineage[{key}]:{self._external_lineage_map[key]}"
+                    f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]}"
                 )
         except Exception as e:
             logger.warning(
@@ -458,13 +491,16 @@ WHERE
                 )
                 self._external_lineage_map[key].add(db_row.location)
                 logger.debug(
-                    f"ExternalLineage[{key}]:{self._external_lineage_map[key]}"
+                    f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]}"
                 )
+                num_edges += 1
         except Exception as e:
             logger.warning(
                 f"Populating external table lineage from Snowflake failed."
                 f"Please check your premissions. Continuing...\nError was {e}."
             )
+        logger.info(f"Found {num_edges} external lineage edges.")
+        self.report.num_external_table_edges_scanned = num_edges
 
     def _populate_lineage(self) -> None:
         url = self.config.get_sql_alchemy_url()
@@ -486,7 +522,7 @@ WITH table_lineage_history AS (
         t.query_start_time AS query_start_time
     FROM
         (SELECT * from snowflake.account_usage.access_history) t,
-        lateral flatten(input => t.BASE_OBJECTS_ACCESSED) r,
+        lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) r,
         lateral flatten(input => t.OBJECTS_MODIFIED) w
     WHERE r.value:"objectId" IS NOT NULL
     AND w.value:"objectId" IS NOT NULL
@@ -501,6 +537,7 @@ QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_na
             start_time_millis=int(self.config.start_time.timestamp() * 1000),
             end_time_millis=int(self.config.end_time.timestamp() * 1000),
         )
+        num_edges: int = 0
         self._lineage_map = defaultdict(list)
         try:
             for db_row in engine.execute(query):
@@ -510,12 +547,20 @@ QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_na
                     # (<upstream_table_name>, <json_list_of_upstream_table_columns>, <json_list_of_downstream_table_columns>)
                     (db_row[0].lower().replace('"', ""), db_row[2], db_row[3])
                 )
-                logger.debug(f"Lineage[{key}]:{self._lineage_map[key]}")
+                num_edges += 1
+                logger.debug(
+                    f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}"
+                )
        except Exception as e:
             logger.warning(
                 f"Extracting lineage from Snowflake failed."
                 f"Please check your premissions. Continuing...\nError was {e}."
             )
+        logger.info(
+            f"A total of {num_edges} Table->Table edges found"
+            f" for {len(self._lineage_map)} downstream tables.",
+        )
+        self.report.num_table_to_table_edges_scanned = num_edges
 
     def _get_upstream_lineage_info(
         self, dataset_urn: str
@@ -586,6 +631,12 @@ QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_na
                 upstream_tables.append(external_upstream_table)
 
         if upstream_tables:
+            logger.debug(
+                f"Upstream lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}"
+            )
+            self.report.upstream_lineage[dataset_name] = [
+                u.dataset for u in upstream_tables
+            ]
             return UpstreamLineage(upstreams=upstream_tables), column_lineage
         return None
 
@@ -594,7 +645,7 @@ QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_na
         for wu in super().get_workunits():
             if (
                 self.config.include_table_lineage
-                and isinstance(wu, SqlWorkUnit)
+                and isinstance(wu, MetadataWorkUnit)
                 and isinstance(wu.metadata, MetadataChangeEvent)
                 and isinstance(wu.metadata.proposedSnapshot, DatasetSnapshot)
            ):
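Side note on the false-edge pruning added in _populate_view_downstream_lineage: the patch snapshots the prior edge list, filters out the direct Table->Table edge that is actually mediated by a view, and counts the removals by set difference. A minimal standalone sketch of that pattern follows; the table names and the lineage_map stand-in are illustrative, not the DataHub API.

from collections import defaultdict
from typing import Dict, List, Tuple

# Stand-in for self._lineage_map: downstream name ->
# [(upstream_name, upstream_columns_json, downstream_columns_json)].
lineage_map: Dict[str, List[Tuple[str, str, str]]] = defaultdict(list)
lineage_map["db.schema.downstream"] = [
    ("db.schema.upstream", "[]", "[]"),  # false direct Table->Table edge
    ("db.schema.unrelated", "[]", "[]"),
]

upstream_table = "db.schema.upstream"
downstream_table = "db.schema.downstream"

# (1) Snapshot the prior edges, then drop the false direct edge.
prior_edges = lineage_map[downstream_table]
lineage_map[downstream_table] = [
    entry for entry in prior_edges if entry[0] != upstream_table
]
num_false_edges = len(set(prior_edges) - set(lineage_map[downstream_table]))

# (2) Re-route the lineage through the view instead.
lineage_map[downstream_table].append(("db.schema.the_view", "[]", "[]"))

assert num_false_edges == 1
assert lineage_map[downstream_table][-1][0] == "db.schema.the_view"

Counting via set difference reports each distinct removed edge once, so duplicate entries for the same upstream would still increment num_false_edges a single time, matching how the patch iterates over the difference rather than comparing list lengths.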