From dc08bedd6e96088815802e6f83f06a1e8082e659 Mon Sep 17 00:00:00 2001 From: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com> Date: Tue, 16 Aug 2022 09:24:02 +0530 Subject: [PATCH] feat(ingest): snowflake - add snowflake-beta connector (#5517) --- .../docs/sources/snowflake/README.md | 5 +- .../docs/sources/snowflake/snowflake-beta.md | 56 + .../snowflake/snowflake-beta_recipe.yml | 43 + metadata-ingestion/setup.py | 2 + .../configuration/time_window_config.py | 14 +- .../ingestion/source/snowflake/__init__.py | 0 .../source/snowflake/snowflake_config.py | 79 ++ .../source/snowflake/snowflake_lineage.py | 325 +++++ .../source/snowflake/snowflake_query.py | 509 ++++++++ .../source/snowflake/snowflake_report.py | 11 + .../source/snowflake/snowflake_schema.py | 323 +++++ .../source/snowflake/snowflake_usage_v2.py | 410 ++++++ .../source/snowflake/snowflake_utils.py | 149 +++ .../source/snowflake/snowflake_v2.py | 1139 +++++++++++++++++ .../src/datahub/utilities/sql_formatter.py | 15 + .../bigquery-usage/test_bigquery_usage.py | 2 +- .../tests/unit/test_bigquery_usage_source.py | 4 +- 17 files changed, 3073 insertions(+), 13 deletions(-) create mode 100644 metadata-ingestion/docs/sources/snowflake/snowflake-beta.md create mode 100644 metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py diff --git a/metadata-ingestion/docs/sources/snowflake/README.md b/metadata-ingestion/docs/sources/snowflake/README.md index 62c6c15513..1ca05f5ec4 100644 --- a/metadata-ingestion/docs/sources/snowflake/README.md +++ b/metadata-ingestion/docs/sources/snowflake/README.md @@ -1 +1,4 @@ -To get all metadata from Snowflake you need to use two plugins `snowflake` and `snowflake-usage`. Both of them are described in this page. These will require 2 separate recipes. We understand this is not ideal and we plan to make this easier in the future. +To get all metadata from Snowflake you need to use two plugins: `snowflake` and `snowflake-usage`. Both of them are described on this page. These will require 2 separate recipes. + + +We encourage you to try out the new `snowflake-beta` plugin as an alternative to running both the `snowflake` and `snowflake-usage` plugins, and to share your feedback. `snowflake-beta` is much faster than `snowflake` at extracting metadata. Please note that the `snowflake-beta` plugin currently does not support column-level profiling, unlike the `snowflake` plugin.
\ No newline at end of file diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake-beta.md b/metadata-ingestion/docs/sources/snowflake/snowflake-beta.md new file mode 100644 index 0000000000..9f76554b22 --- /dev/null +++ b/metadata-ingestion/docs/sources/snowflake/snowflake-beta.md @@ -0,0 +1,56 @@ +### Prerequisites + +In order to execute this source, your Snowflake user will need to have specific privileges granted to it for reading metadata +from your warehouse. + +Snowflake system admin can follow this guide to create a DataHub-specific role, assign it the required privileges, and assign it to a new DataHub user by executing the following Snowflake commands from a user with the `ACCOUNTADMIN` role or `MANAGE GRANTS` privilege. + +```sql +create or replace role datahub_role; + +// Grant access to a warehouse to run queries to view metadata +grant operate, usage on warehouse "" to role datahub_role; + +// Grant access to view database and schema in which your tables/views exist +grant usage on DATABASE "" to role datahub_role; +grant usage on all schemas in database "" to role datahub_role; +grant usage on future schemas in database "" to role datahub_role; + +// If you are NOT using Snowflake Profiling feature: Grant references privileges to your tables and views +grant references on all tables in database "" to role datahub_role; +grant references on future tables in database "" to role datahub_role; +grant references on all external tables in database "" to role datahub_role; +grant references on future external tables in database "" to role datahub_role; +grant references on all views in database "" to role datahub_role; +grant references on future views in database "" to role datahub_role; + +// If you ARE using Snowflake Profiling feature: Grant select privileges to your tables and views +grant select on all tables in database "" to role datahub_role; +grant select on future tables in database "" to role datahub_role; +grant select on all external tables in database "" to role datahub_role; +grant select on future external tables in database "" to role datahub_role; +grant select on all views in database "" to role datahub_role; +grant select on future views in database "" to role datahub_role; + +// Create a new DataHub user and assign the DataHub role to it +create user datahub_user display_name = 'DataHub' password='' default_role = datahub_role default_warehouse = ''; + +// Grant the datahub_role to the new DataHub user. +grant role datahub_role to user datahub_user; +``` + +The details of each granted privilege can be viewed in [snowflake docs](https://docs.snowflake.com/en/user-guide/security-access-control-privileges.html). A summarization of each privilege, and why it is required for this connector: +- `operate` is required on warehouse to execute queries +- `usage` is required for us to run queries using the warehouse +- `usage` on `database` and `schema` are required because without it tables and views inside them are not accessible. If an admin does the required grants on `table` but misses the grants on `schema` or the `database` in which the table/view exists then we will not be able to get metadata for the table/view. +- If metadata is required only on some schemas then you can grant the usage privilieges only on a particular schema like +```sql +grant usage on schema ""."" to role datahub_role; +``` + +This represents the bare minimum privileges required to extract databases, schemas, views, tables from Snowflake. 
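+
+If you want to sanity-check the setup before running ingestion, you can inspect the grants from within Snowflake. This check is optional and not something the connector itself requires; `datahub_role` refers to the role created above, and `"<your-database>"` is a placeholder for your database name.
+```sql
+// List all privileges currently granted to the DataHub role
+show grants to role datahub_role;
+
+// List the privileges granted on a specific database
+show grants on database "<your-database>";
+```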
+ +If you plan to enable extraction of table lineage, via the `include_table_lineage` config flag or extraction of usage statistics, via the `include_usage_stats` config, you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables, using which the DataHub source extracts information. This can be done by granting access to the `snowflake` database. +```sql +grant imported privileges on database snowflake to role datahub_role; +``` \ No newline at end of file diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml b/metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml new file mode 100644 index 0000000000..817707eb72 --- /dev/null +++ b/metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml @@ -0,0 +1,43 @@ +source: + type: snowflake-beta + config: + + # This option is recommended to be used for the first time to ingest all lineage + ignore_start_time_lineage: true + # This is an alternative option to specify the start_time for lineage + # if you don't want to look back since beginning + start_time: '2022-03-01T00:00:00Z' + + # Coordinates + account_id: "abc48144" + warehouse: "COMPUTE_WH" + + # Credentials + username: "${SNOWFLAKE_USER}" + password: "${SNOWFLAKE_PASS}" + role: "datahub_role" + + # Change these as per your database names. Remove to get all databases + database_pattern: + allow: + - "^ACCOUNTING_DB$" + - "^MARKETING_DB$" + + table_pattern: + allow: + # If you want to ingest only few tables with name revenue and sales + - ".*revenue" + - ".*sales" + + profiling: + # Change to false to disable profiling + enabled: true + profile_table_level_only: true + profile_pattern: + allow: + - 'ACCOUNTING_DB.*.*' + - 'MARKETING_DB.*.*' + + +sink: +# sink configs \ No newline at end of file diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index baf10921d8..c20cc3a4c6 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -290,6 +290,7 @@ plugins: Dict[str, Set[str]] = { | { "more-itertools>=8.12.0", }, + "snowflake-beta": snowflake_common | usage_common, "sqlalchemy": sql_common, "superset": { "requests", @@ -499,6 +500,7 @@ entry_points = { "redshift-usage = datahub.ingestion.source.usage.redshift_usage:RedshiftUsageSource", "snowflake = datahub.ingestion.source.sql.snowflake:SnowflakeSource", "snowflake-usage = datahub.ingestion.source.usage.snowflake_usage:SnowflakeUsageSource", + "snowflake-beta = datahub.ingestion.source.snowflake.snowflake_v2:SnowflakeV2Source", "superset = datahub.ingestion.source.superset:SupersetSource", "tableau = datahub.ingestion.source.tableau:TableauSource", "openapi = datahub.ingestion.source.openapi:OpenApiSource", diff --git a/metadata-ingestion/src/datahub/configuration/time_window_config.py b/metadata-ingestion/src/datahub/configuration/time_window_config.py index 2ea6663e48..ad7bfafedd 100644 --- a/metadata-ingestion/src/datahub/configuration/time_window_config.py +++ b/metadata-ingestion/src/datahub/configuration/time_window_config.py @@ -40,26 +40,22 @@ class BaseTimeWindowConfig(ConfigModel): # `start_time` and `end_time` will be populated by the pre-validators. # However, we must specify a "default" value here or pydantic will complain # if those fields are not set by the user. - end_time: datetime = Field(default=None, description="Latest date of usage to consider. 
Default: Last full day in UTC (or hour, depending on `bucket_duration`)") # type: ignore + end_time: datetime = Field(default=None, description="Latest date of usage to consider. Default: Current time in UTC") # type: ignore start_time: datetime = Field(default=None, description="Earliest date of usage to consider. Default: Last full day in UTC (or hour, depending on `bucket_duration`)") # type: ignore @pydantic.validator("end_time", pre=True, always=True) def default_end_time( cls, v: Any, *, values: Dict[str, Any], **kwargs: Any ) -> datetime: - return v or get_time_bucket( - datetime.now(tz=timezone.utc) - + get_bucket_duration_delta(values["bucket_duration"]), - values["bucket_duration"], - ) + return v or datetime.now(tz=timezone.utc) @pydantic.validator("start_time", pre=True, always=True) def default_start_time( cls, v: Any, *, values: Dict[str, Any], **kwargs: Any ) -> datetime: - return v or ( - values["end_time"] - - get_bucket_duration_delta(values["bucket_duration"]) * 2 + return v or get_time_bucket( + values["end_time"] - get_bucket_duration_delta(values["bucket_duration"]), + values["bucket_duration"], ) @pydantic.validator("start_time", "end_time") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py new file mode 100644 index 0000000000..142d9689bf --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -0,0 +1,79 @@ +import logging +from typing import Dict, Optional, cast + +from pydantic import Field, root_validator + +from datahub.configuration.common import AllowDenyPattern +from datahub.ingestion.source.sql.sql_common import SQLAlchemyStatefulIngestionConfig +from datahub.ingestion.source_config.sql.snowflake import ( + SnowflakeConfig, + SnowflakeProvisionRoleConfig, +) +from datahub.ingestion.source_config.usage.snowflake_usage import ( + SnowflakeStatefulIngestionConfig, + SnowflakeUsageConfig, +) + + +class SnowflakeV2StatefulIngestionConfig( + SQLAlchemyStatefulIngestionConfig, SnowflakeStatefulIngestionConfig +): + pass + + +logger = logging.Logger(__name__) + + +class SnowflakeV2Config(SnowflakeConfig, SnowflakeUsageConfig): + convert_urns_to_lowercase: bool = Field( + default=True, + ) + + include_usage_stats: bool = Field( + default=True, + description="If enabled, populates the snowflake usage statistics. Requires appropriate grants given to the role.", + ) + + check_role_grants: bool = Field( + default=False, + description="Not supported", + ) + + provision_role: Optional[SnowflakeProvisionRoleConfig] = Field( + default=None, description="Not supported" + ) + + @root_validator(pre=False) + def validate_unsupported_configs(cls, values: Dict) -> Dict: + + value = values.get("provision_role") + if value is not None and value.enabled: + raise ValueError( + "Provision role is currently not supported. Set `provision_role.enabled` to False." + ) + + value = values.get("profiling") + if value is not None and value.enabled and not value.profile_table_level_only: + raise ValueError( + "Only table level profiling is supported. 
Set `profiling.profile_table_level_only` to True.", + ) + + value = values.get("check_role_grants") + if value is not None and value: + raise ValueError( + "Check role grants is not supported. Set `check_role_grants` to False.", + ) + + value = values.get("include_read_operational_stats") + if value is not None and value: + raise ValueError( + "include_read_operational_stats is not supported. Set `include_read_operational_stats` to False.", + ) + + # Always exclude reporting metadata for INFORMATION_SCHEMA schema + schema_pattern = values.get("schema_pattern") + if schema_pattern is not None and schema_pattern: + logger.debug("Adding deny for INFORMATION_SCHEMA to schema_pattern.") + cast(AllowDenyPattern, schema_pattern).deny.append(r"^INFORMATION_SCHEMA$") + + return values diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py new file mode 100644 index 0000000000..d6e46c4f3b --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py @@ -0,0 +1,325 @@ +import json +import logging +from collections import defaultdict +from typing import Dict, List, Optional, Set, Tuple + +from snowflake.connector import SnowflakeConnection + +import datahub.emitter.mce_builder as builder +from datahub.ingestion.source.aws.s3_util import make_s3_urn +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config +from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeQueryMixin, +) +from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage +from datahub.metadata.schema_classes import DatasetLineageTypeClass, UpstreamClass + +logger: logging.Logger = logging.getLogger(__name__) + + +class SnowflakeLineageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin): + def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None: + self._lineage_map: Optional[Dict[str, List[Tuple[str, str, str]]]] = None + self._external_lineage_map: Optional[Dict[str, Set[str]]] = None + self.config = config + self.platform = "snowflake" + self.report = report + self.logger = logger + + # Rewrite implementation for readability, efficiency and extensibility + def _get_upstream_lineage_info( + self, dataset_name: str + ) -> Optional[Tuple[UpstreamLineage, Dict[str, str]]]: + + if not self.config.include_table_lineage: + return None + + if self._lineage_map is None or self._external_lineage_map is None: + conn = self.config.get_connection() + if self._lineage_map is None: + self._populate_lineage(conn) + self._populate_view_lineage(conn) + if self._external_lineage_map is None: + self._populate_external_lineage(conn) + + assert self._lineage_map is not None + assert self._external_lineage_map is not None + + lineage = self._lineage_map[dataset_name] + external_lineage = self._external_lineage_map[dataset_name] + if not (lineage or external_lineage): + logger.debug(f"No lineage found for {dataset_name}") + return None + upstream_tables: List[UpstreamClass] = [] + column_lineage: Dict[str, str] = {} + for lineage_entry in lineage: + # Update the table-lineage + upstream_table_name = lineage_entry[0] + if not self._is_dataset_pattern_allowed(upstream_table_name, "table"): + continue + upstream_table = UpstreamClass( + 
dataset=builder.make_dataset_urn_with_platform_instance( + self.platform, + upstream_table_name, + self.config.platform_instance, + self.config.env, + ), + type=DatasetLineageTypeClass.TRANSFORMED, + ) + upstream_tables.append(upstream_table) + # Update column-lineage for each down-stream column. + upstream_columns = [ + self.snowflake_identifier(d["columnName"]) + for d in json.loads(lineage_entry[1]) + ] + downstream_columns = [ + self.snowflake_identifier(d["columnName"]) + for d in json.loads(lineage_entry[2]) + ] + upstream_column_str = ( + f"{upstream_table_name}({', '.join(sorted(upstream_columns))})" + ) + downstream_column_str = ( + f"{dataset_name}({', '.join(sorted(downstream_columns))})" + ) + column_lineage_key = f"column_lineage[{upstream_table_name}]" + column_lineage_value = ( + f"{{{upstream_column_str} -> {downstream_column_str}}}" + ) + column_lineage[column_lineage_key] = column_lineage_value + logger.debug(f"{column_lineage_key}:{column_lineage_value}") + + for external_lineage_entry in external_lineage: + # For now, populate only for S3 + if external_lineage_entry.startswith("s3://"): + external_upstream_table = UpstreamClass( + dataset=make_s3_urn(external_lineage_entry, self.config.env), + type=DatasetLineageTypeClass.COPY, + ) + upstream_tables.append(external_upstream_table) + + if upstream_tables: + logger.debug( + f"Upstream lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}" + ) + if self.config.upstream_lineage_in_report: + self.report.upstream_lineage[dataset_name] = [ + u.dataset for u in upstream_tables + ] + return UpstreamLineage(upstreams=upstream_tables), column_lineage + return None + + def _populate_view_lineage(self, conn: SnowflakeConnection) -> None: + if not self.config.include_view_lineage: + return + self._populate_view_upstream_lineage(conn) + self._populate_view_downstream_lineage(conn) + + def _populate_external_lineage(self, conn: SnowflakeConnection) -> None: + # Handles the case where a table is populated from an external location via copy. + # Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english'credentials=(aws_key_id='...' aws_secret_key='...') pattern='.*.csv'; + query: str = SnowflakeQuery.external_table_lineage_history( + start_time_millis=int(self.config.start_time.timestamp() * 1000) + if not self.config.ignore_start_time_lineage + else 0, + end_time_millis=int(self.config.end_time.timestamp() * 1000), + ) + + num_edges: int = 0 + self._external_lineage_map = defaultdict(set) + try: + for db_row in self.query(conn, query): + # key is the down-stream table name + key: str = self.get_dataset_identifier_from_qualified_name( + db_row["downstream_table_name"] + ) + if not self._is_dataset_pattern_allowed(key, "table"): + continue + self._external_lineage_map[key] |= { + *json.loads(db_row["upstream_locations"]) + } + logger.debug( + f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via access_history" + ) + except Exception as e: + logger.warning( + f"Populating table external lineage from Snowflake failed." + f"Please check your premissions. Continuing...\nError was {e}." + ) + # Handles the case for explicitly created external tables. + # NOTE: Snowflake does not log this information to the access_history table. 
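+        # Illustration only (hypothetical values): each row returned by
+        # `show external tables in account` is consumed via the keys used below, e.g.
+        #   name="CATEGORY_ENGLISH", schema_name="PUBLIC", database_name="DEMO_DB",
+        #   location="s3://my-bucket/olist_raw_data/category_english"
+        # The `location` value is what later surfaces as the external (S3) upstream.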
+ external_tables_query: str = SnowflakeQuery.show_external_tables() + try: + for db_row in self.query(conn, external_tables_query): + key = self.get_dataset_identifier( + db_row["name"], db_row["schema_name"], db_row["database_name"] + ) + + if not self._is_dataset_pattern_allowed(key, "table"): + continue + self._external_lineage_map[key].add(db_row["location"]) + logger.debug( + f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via show external tables" + ) + num_edges += 1 + except Exception as e: + self.warn( + "external_lineage", + f"Populating external table lineage from Snowflake failed." + f"Please check your premissions. Continuing...\nError was {e}.", + ) + logger.info(f"Found {num_edges} external lineage edges.") + self.report.num_external_table_edges_scanned = num_edges + + def _populate_lineage(self, conn: SnowflakeConnection) -> None: + query: str = SnowflakeQuery.table_to_table_lineage_history( + start_time_millis=int(self.config.start_time.timestamp() * 1000) + if not self.config.ignore_start_time_lineage + else 0, + end_time_millis=int(self.config.end_time.timestamp() * 1000), + ) + num_edges: int = 0 + self._lineage_map = defaultdict(list) + try: + for db_row in self.query(conn, query): + # key is the down-stream table name + key: str = self.get_dataset_identifier_from_qualified_name( + db_row["downstream_table_name"] + ) + upstream_table_name = self.get_dataset_identifier_from_qualified_name( + db_row["upstream_table_name"] + ) + if not ( + self._is_dataset_pattern_allowed(key, "table") + or self._is_dataset_pattern_allowed(upstream_table_name, "table") + ): + continue + self._lineage_map[key].append( + # (, , ) + ( + upstream_table_name, + db_row["upstream_table_columns"], + db_row["downstream_table_columns"], + ) + ) + num_edges += 1 + logger.debug( + f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}" + ) + except Exception as e: + self.warn( + "lineage", + f"Extracting lineage from Snowflake failed." + f"Please check your premissions. Continuing...\nError was {e}.", + ) + logger.info( + f"A total of {num_edges} Table->Table edges found" + f" for {len(self._lineage_map)} downstream tables.", + ) + self.report.num_table_to_table_edges_scanned = num_edges + + def _populate_view_upstream_lineage(self, conn: SnowflakeConnection) -> None: + # NOTE: This query captures only the upstream lineage of a view (with no column lineage). + # For more details see: https://docs.snowflake.com/en/user-guide/object-dependencies.html#object-dependencies + # and also https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views. + view_upstream_lineage_query: str = SnowflakeQuery.view_dependencies() + + assert self._lineage_map is not None + num_edges: int = 0 + + try: + for db_row in self.query(conn, view_upstream_lineage_query): + # Process UpstreamTable/View/ExternalTable/Materialized View->View edge. 
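+                # Rough sketch (hypothetical names, assuming the default convert_urns_to_lowercase=True):
+                # a dependency row for view DEMO_DB.PUBLIC.ORDERS_VIEW built on top of table
+                # DEMO_DB.PUBLIC.RAW_ORDERS ends up recorded as
+                #   self._lineage_map["demo_db.public.orders_view"].append(("demo_db.public.raw_orders", "[]", "[]"))
+                # i.e. the upstream name plus empty column lists, since object_dependencies carries no column info.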
+ view_upstream: str = self.get_dataset_identifier_from_qualified_name( + db_row["view_upstream"] + ) + view_name: str = self.get_dataset_identifier_from_qualified_name( + db_row["downstream_view"] + ) + if not self._is_dataset_pattern_allowed( + dataset_name=view_name, + dataset_type=db_row["referencing_object_domain"], + ): + continue + # key is the downstream view name + self._lineage_map[view_name].append( + # (, , ) + (view_upstream, "[]", "[]") + ) + num_edges += 1 + logger.debug( + f"Upstream->View: Lineage[View(Down)={view_name}]:Upstream={view_upstream}" + ) + except Exception as e: + self.warn( + "view_upstream_lineage", + "Extracting the upstream view lineage from Snowflake failed." + + f"Please check your permissions. Continuing...\nError was {e}.", + ) + logger.info(f"A total of {num_edges} View upstream edges found.") + self.report.num_table_to_view_edges_scanned = num_edges + + def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None: + + # This query captures the downstream table lineage for views. + # See https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views. + # Eg: For viewA->viewB->ViewC->TableD, snowflake does not yet log intermediate view logs, resulting in only the viewA->TableD edge. + view_lineage_query: str = SnowflakeQuery.view_lineage_history( + start_time_millis=int(self.config.start_time.timestamp() * 1000) + if not self.config.ignore_start_time_lineage + else 0, + end_time_millis=int(self.config.end_time.timestamp() * 1000), + ) + + assert self._lineage_map is not None + self.report.num_view_to_table_edges_scanned = 0 + + try: + db_rows = self.query(conn, view_lineage_query) + except Exception as e: + self.warn( + "view_downstream_lineage", + f"Extracting the view lineage from Snowflake failed." + f"Please check your permissions. Continuing...\nError was {e}.", + ) + else: + for db_row in db_rows: + view_name: str = self.get_dataset_identifier_from_qualified_name( + db_row["view_name"] + ) + if not self._is_dataset_pattern_allowed( + view_name, db_row["view_domain"] + ): + continue + downstream_table: str = self.get_dataset_identifier_from_qualified_name( + db_row["downstream_table_name"] + ) + # Capture view->downstream table lineage. + self._lineage_map[downstream_table].append( + # (, , ) + ( + view_name, + db_row["view_columns"], + db_row["downstream_table_columns"], + ) + ) + self.report.num_view_to_table_edges_scanned += 1 + + logger.debug( + f"View->Table: Lineage[Table(Down)={downstream_table}]:View(Up)={self._lineage_map[downstream_table]}" + ) + + logger.info( + f"Found {self.report.num_view_to_table_edges_scanned} View->Table edges." 
+ ) + + def warn(self, key: str, reason: str) -> None: + self.report.report_warning(key, reason) + self.logger.warning(f"{key} => {reason}") + + def error(self, key: str, reason: str) -> None: + self.report.report_failure(key, reason) + self.logger.error(f"{key} => {reason}") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py new file mode 100644 index 0000000000..e7b23dfd82 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py @@ -0,0 +1,509 @@ +from typing import Optional + + +class SnowflakeQuery: + @staticmethod + def current_version() -> str: + return "select CURRENT_VERSION()" + + @staticmethod + def current_role() -> str: + return "select CURRENT_ROLE()" + + @staticmethod + def current_warehouse() -> str: + return "select CURRENT_WAREHOUSE()" + + @staticmethod + def current_database() -> str: + return "select CURRENT_DATABASE()" + + @staticmethod + def current_schema() -> str: + return "select CURRENT_SCHEMA()" + + @staticmethod + def show_databases() -> str: + return "show databases" + + @staticmethod + def use_database(db_name: str) -> str: + return f'use database "{db_name}"' + + @staticmethod + def schemas_for_database(db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT schema_name AS "schema_name", + created AS "created", + last_altered AS "last_altered", + comment AS "comment" + from {db_clause}information_schema.schemata + WHERE schema_name != 'INFORMATION_SCHEMA' + order by schema_name""" + + @staticmethod + def tables_for_database(db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT table_catalog AS "table_catalog", + table_schema AS "table_schema", + table_name AS "table_name", + table_type AS "table_type", + created AS "created", + last_altered AS "last_altered" , + comment AS "comment", + row_count AS "row_count", + bytes AS "bytes", + clustering_key AS "clustering_key", + auto_clustering_on AS "auto_clustering_on" + FROM {db_clause}information_schema.tables t + WHERE table_schema != 'INFORMATION_SCHEMA' + and table_type in ( 'BASE TABLE', 'EXTERNAL TABLE') + order by table_schema, table_name""" + + @staticmethod + def tables_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT table_catalog AS "table_catalog", + table_schema AS "table_schema", + table_name AS "table_name", + table_type AS "table_type", + created AS "created", + last_altered AS "last_altered" , + comment AS "comment", + row_count AS "row_count", + bytes AS "bytes", + clustering_key AS "clustering_key", + auto_clustering_on AS "auto_clustering_on" + FROM {db_clause}information_schema.tables t + where schema_name='{schema_name}' + and table_type in ('BASE TABLE', 'EXTERNAL TABLE') + order by table_schema, table_name""" + + # View definition is retrived in information_schema query only if role is owner of view. Hence this query is not used. + # https://community.snowflake.com/s/article/Is-it-possible-to-see-the-view-definition-in-information-schema-views-from-a-non-owner-role + @staticmethod + def views_for_database(db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' 
if db_name is not None else "" + return f""" + SELECT table_catalog AS "table_catalog", + table_schema AS "table_schema", + table_name AS "table_name", + created AS "created", + last_altered AS "last_altered", + comment AS "comment", + view_definition AS "view_definition" + FROM {db_clause}information_schema.views t + WHERE table_schema != 'INFORMATION_SCHEMA' + order by table_schema, table_name""" + + # View definition is retrived in information_schema query only if role is owner of view. Hence this query is not used. + # https://community.snowflake.com/s/article/Is-it-possible-to-see-the-view-definition-in-information-schema-views-from-a-non-owner-role + @staticmethod + def views_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT table_catalog AS "table_catalog", + table_schema AS "table_schema", + table_name AS "table_name", + created AS "created", + last_altered AS "last_altered", + comment AS "comment", + view_definition AS "view_definition" + FROM {db_clause}information_schema.views t + where schema_name='{schema_name}' + order by table_schema, table_name""" + + @staticmethod + def show_views_for_database(db_name: str) -> str: + return f"""show views in database "{db_name}";""" + + @staticmethod + def show_views_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f"""show views in schema {db_clause}"{schema_name}";""" + + @staticmethod + def columns_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + select + table_catalog AS "table_catalog", + table_schema AS "table_schema", + table_name AS "table_name", + column_name AS "column_name", + ordinal_position AS "ordinal_position", + is_nullable AS "is_nullable", + data_type AS "data_type", + comment AS "comment", + character_maximum_length AS "character_maximum_length", + numeric_precision AS "numeric_precision", + numeric_scale AS "numeric_scale", + column_default AS "column_default", + is_identity AS "is_identity" + from {db_clause}information_schema.columns + WHERE table_schema='{schema_name}' + ORDER BY ordinal_position""" + + @staticmethod + def columns_for_table( + table_name: str, schema_name: str, db_name: Optional[str] + ) -> str: + db_clause = f'"{db_name}".' 
if db_name is not None else "" + return f""" + select + table_catalog AS "table_catalog", + table_schema AS "table_schema", + table_name AS "table_name", + column_name AS "column_name", + ordinal_position AS "ordinal_position", + is_nullable AS "is_nullable", + data_type AS "data_type", + comment AS "comment", + character_maximum_length AS "character_maximum_length", + numeric_precision AS "numeric_precision", + numeric_scale AS "numeric_scale", + column_default AS "column_default", + is_identity AS "is_identity" + from {db_clause}information_schema.columns + WHERE table_schema='{schema_name}' and table_name='{table_name}' + ORDER BY ordinal_position""" + + @staticmethod + def show_primary_keys_for_schema(schema_name: str, db_name: str) -> str: + return f""" + show primary keys in schema "{db_name}"."{schema_name}" """ + + @staticmethod + def show_foreign_keys_for_schema(schema_name: str, db_name: str) -> str: + return f""" + show imported keys in schema "{db_name}"."{schema_name}" """ + + @staticmethod + def operational_data_for_time_window( + start_time_millis: int, end_time_millis: int + ) -> str: + return f""" + SELECT + -- access_history.query_id, -- only for debugging purposes + access_history.query_start_time AS "query_start_time", + query_history.query_text AS "query_text", + query_history.query_type AS "query_type", + query_history.rows_inserted AS "rows_inserted", + query_history.rows_updated AS "rows_updated", + query_history.rows_deleted AS "rows_deleted", + access_history.base_objects_accessed AS "base_objects_accessed", + access_history.direct_objects_accessed AS "direct_objects_accessed", -- when dealing with views, direct objects will show the view while base will show the underlying table + access_history.objects_modified AS "objects_modified", + -- query_history.execution_status, -- not really necessary, but should equal "SUCCESS" + -- query_history.warehouse_name, + access_history.user_name AS "user_name", + users.first_name AS "first_name", + users.last_name AS "last_name", + users.display_name AS "display_name", + users.email AS "email", + query_history.role_name AS "role_name" + FROM + snowflake.account_usage.access_history access_history + LEFT JOIN + snowflake.account_usage.query_history query_history + ON access_history.query_id = query_history.query_id + LEFT JOIN + snowflake.account_usage.users users + ON access_history.user_name = users.name + WHERE query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND query_start_time < to_timestamp_ltz({end_time_millis}, 3) + AND query_history.query_type in ('INSERT', 'UPDATE', 'DELETE', 'CREATE', 'CREATE_TABLE', 'CREATE_TABLE_AS_SELECT') + ORDER BY query_start_time DESC + ;""" + + @staticmethod + def table_to_table_lineage_history( + start_time_millis: int, end_time_millis: int + ) -> str: + return f""" + WITH table_lineage_history AS ( + SELECT + r.value:"objectName"::varchar AS upstream_table_name, + r.value:"objectDomain"::varchar AS upstream_table_domain, + r.value:"columns" AS upstream_table_columns, + w.value:"objectName"::varchar AS downstream_table_name, + w.value:"objectDomain"::varchar AS downstream_table_domain, + w.value:"columns" AS downstream_table_columns, + t.query_start_time AS query_start_time + FROM + (SELECT * from snowflake.account_usage.access_history) t, + lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) r, + lateral flatten(input => t.OBJECTS_MODIFIED) w + WHERE r.value:"objectId" IS NOT NULL + AND w.value:"objectId" IS NOT NULL + AND w.value:"objectName" NOT LIKE '%.GE_TMP_%' + AND 
w.value:"objectName" NOT LIKE '%.GE_TEMP_%' + AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)) + SELECT + upstream_table_name AS "upstream_table_name", + downstream_table_name AS "downstream_table_name", + upstream_table_columns AS "upstream_table_columns", + downstream_table_columns AS "downstream_table_columns" + FROM table_lineage_history + WHERE upstream_table_domain in ('Table', 'External table') and downstream_table_domain = 'Table' + QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_name ORDER BY query_start_time DESC) = 1""" + + @staticmethod + def view_dependencies() -> str: + return """ + SELECT + concat( + referenced_database, '.', referenced_schema, + '.', referenced_object_name + ) AS "view_upstream", + concat( + referencing_database, '.', referencing_schema, + '.', referencing_object_name + ) AS "downstream_view", + referencing_object_domain AS "referencing_object_domain" + FROM + snowflake.account_usage.object_dependencies + WHERE + referencing_object_domain in ('VIEW', 'MATERIALIZED VIEW') + """ + + @staticmethod + def view_lineage_history(start_time_millis: int, end_time_millis: int) -> str: + return f""" + WITH view_lineage_history AS ( + SELECT + vu.value : "objectName"::varchar AS view_name, + vu.value : "objectDomain"::varchar AS view_domain, + vu.value : "columns" AS view_columns, + w.value : "objectName"::varchar AS downstream_table_name, + w.value : "objectDomain"::varchar AS downstream_table_domain, + w.value : "columns" AS downstream_table_columns, + t.query_start_time AS query_start_time + FROM + ( + SELECT + * + FROM + snowflake.account_usage.access_history + ) t, + lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) vu, + lateral flatten(input => t.OBJECTS_MODIFIED) w + WHERE + vu.value : "objectId" IS NOT NULL + AND w.value : "objectId" IS NOT NULL + AND w.value : "objectName" NOT LIKE '%.GE_TMP_%' + AND w.value : "objectName" NOT LIKE '%.GE_TEMP_%' + AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3) + ) + SELECT + view_name AS "view_name", + view_domain AS "view_domain", + view_columns AS "view_columns", + downstream_table_name AS "downstream_table_name", + downstream_table_columns AS "downstream_table_columns" + FROM + view_lineage_history + WHERE + view_domain in ('View', 'Materialized view') + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY view_name, + downstream_table_name + ORDER BY + query_start_time DESC + ) = 1 + """ + + @staticmethod + def show_external_tables() -> str: + return "show external tables in account" + + @staticmethod + def external_table_lineage_history( + start_time_millis: int, end_time_millis: int + ) -> str: + return f""" + WITH external_table_lineage_history AS ( + SELECT + r.value:"locations" AS upstream_locations, + w.value:"objectName"::varchar AS downstream_table_name, + w.value:"objectDomain"::varchar AS downstream_table_domain, + w.value:"columns" AS downstream_table_columns, + t.query_start_time AS query_start_time + FROM + (SELECT * from snowflake.account_usage.access_history) t, + lateral flatten(input => t.BASE_OBJECTS_ACCESSED) r, + lateral flatten(input => t.OBJECTS_MODIFIED) w + WHERE r.value:"locations" IS NOT NULL + AND w.value:"objectId" IS NOT NULL + AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)) + SELECT + upstream_locations AS 
"upstream_locations", + downstream_table_name AS "downstream_table_name", + downstream_table_columns AS "downstream_table_columns" + FROM external_table_lineage_history + WHERE downstream_table_domain = 'Table' + QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name ORDER BY query_start_time DESC) = 1""" + + @staticmethod + def usage_per_object_per_time_bucket_for_time_window( + start_time_millis: int, + end_time_millis: int, + time_bucket_size: str, + use_base_objects: bool, + top_n_queries: int, + include_top_n_queries: bool, + ) -> str: + # TODO: Do not query for top n queries if include_top_n_queries = False + # How can we make this pretty + assert time_bucket_size == "DAY" or time_bucket_size == "HOUR" + objects_column = ( + "BASE_OBJECTS_ACCESSED" if use_base_objects else "DIRECT_OBJECTS_ACCESSED" + ) + return f""" + WITH object_access_history AS + ( + select + object.value : "objectName"::varchar AS object_name, + object.value : "objectDomain"::varchar AS object_domain, + object.value : "columns" AS object_columns, + query_start_time, + query_id, + user_name + from + ( + select + query_id, + query_start_time, + user_name, + {objects_column} + from + snowflake.account_usage.access_history + WHERE + query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND query_start_time < to_timestamp_ltz({end_time_millis}, 3) + ) + t, + lateral flatten(input => t.{objects_column}) object + ) + , + field_access_history AS + ( + select + o.*, + col.value : "columnName"::varchar AS column_name + from + object_access_history o, + lateral flatten(input => o.object_columns) col + ) + , + basic_usage_counts AS + ( + SELECT + object_name, + ANY_VALUE(object_domain) AS object_domain, + DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time, + count(distinct(query_id)) AS total_queries, + count( distinct(user_name) ) AS total_users + FROM + object_access_history + GROUP BY + bucket_start_time, + object_name + ) + , + field_usage_counts AS + ( + SELECT + object_name, + column_name, + DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time, + count(distinct(query_id)) AS total_queries + FROM + field_access_history + GROUP BY + bucket_start_time, + object_name, + column_name + ) + , + user_usage_counts AS + ( + SELECT + object_name, + DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time, + count(distinct(query_id)) AS total_queries, + user_name, + ANY_VALUE(users.email) AS user_email + FROM + object_access_history + LEFT JOIN + snowflake.account_usage.users users + ON user_name = users.name + GROUP BY + bucket_start_time, + object_name, + user_name + ) + , + top_queries AS + ( + SELECT + object_name, + DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time, + query_history.query_text AS query_text, + count(distinct(access_history.query_id)) AS total_queries + FROM + object_access_history access_history + LEFT JOIN + snowflake.account_usage.query_history query_history + ON access_history.query_id = query_history.query_id + GROUP BY + bucket_start_time, + object_name, + query_text QUALIFY row_number() over ( partition by bucket_start_time, object_name, query_text + order by + total_queries desc ) <= {top_n_queries} + ) + select + basic_usage_counts.object_name AS "object_name", + basic_usage_counts.bucket_start_time AS "bucket_start_time", + ANY_VALUE(basic_usage_counts.object_domain) AS "object_domain", + 
ANY_VALUE(basic_usage_counts.total_queries) AS "total_queries", + ANY_VALUE(basic_usage_counts.total_users) AS "total_users", + ARRAY_AGG( distinct top_queries.query_text) AS "top_sql_queries", + ARRAY_AGG( distinct OBJECT_CONSTRUCT( 'column_name', field_usage_counts.column_name, 'total_queries', field_usage_counts.total_queries ) ) AS "field_counts", + ARRAY_AGG( distinct OBJECT_CONSTRUCT( 'user_name', user_usage_counts.user_name, 'user_email', user_usage_counts.user_email, 'total_queries', user_usage_counts.total_queries ) ) AS "user_counts" + from + basic_usage_counts basic_usage_counts + left join + top_queries top_queries + on basic_usage_counts.bucket_start_time = top_queries.bucket_start_time + and basic_usage_counts.object_name = top_queries.object_name + left join + field_usage_counts field_usage_counts + on basic_usage_counts.bucket_start_time = field_usage_counts.bucket_start_time + and basic_usage_counts.object_name = field_usage_counts.object_name + left join + user_usage_counts user_usage_counts + on basic_usage_counts.bucket_start_time = user_usage_counts.bucket_start_time + and basic_usage_counts.object_name = user_usage_counts.object_name + where + basic_usage_counts.object_domain in + ( + 'Table', + 'View', + 'Materialized view', + 'External table' + ) + group by + basic_usage_counts.object_name, + basic_usage_counts.bucket_start_time + order by + basic_usage_counts.bucket_start_time + """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py new file mode 100644 index 0000000000..42cd5249ae --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py @@ -0,0 +1,11 @@ +from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport +from datahub.ingestion.source_report.usage.snowflake_usage import SnowflakeUsageReport + + +class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport): + include_usage_stats: bool = False + include_operational_stats: bool = False + + usage_aggregation_query_secs: float = -1 + + rows_zero_objects_modified: int = 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py new file mode 100644 index 0000000000..96dd4d0c87 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -0,0 +1,323 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional + +from snowflake.connector import SnowflakeConnection + +from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery +from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass +class SnowflakePK: + name: str + column_names: List[str] + + +@dataclass +class SnowflakeFK: + name: str + column_names: List[str] + referred_database: str + referred_schema: str + referred_table: str + referred_column_names: List[str] + + +@dataclass +class SnowflakeColumn: + name: str + ordinal_position: int + is_nullable: bool + data_type: str + comment: str + + +@dataclass +class SnowflakeTable: + name: str + created: datetime + last_altered: datetime + size_in_bytes: int + rows_count: int + comment: str + clustering_key: str + pk: Optional[SnowflakePK] = None + columns: List[SnowflakeColumn] = 
field(default_factory=list) + foreign_keys: List[SnowflakeFK] = field(default_factory=list) + + +@dataclass +class SnowflakeView: + name: str + created: datetime + comment: Optional[str] + view_definition: str + last_altered: Optional[datetime] = None + columns: List[SnowflakeColumn] = field(default_factory=list) + + +@dataclass +class SnowflakeSchema: + name: str + created: datetime + last_altered: datetime + comment: str + tables: List[SnowflakeTable] = field(default_factory=list) + views: List[SnowflakeView] = field(default_factory=list) + + +@dataclass +class SnowflakeDatabase: + name: str + created: datetime + comment: str + schemas: List[SnowflakeSchema] = field(default_factory=list) + + +class SnowflakeDataDictionary(SnowflakeQueryMixin): + def __init__(self) -> None: + self.logger = logger + + def get_databases(self, conn: SnowflakeConnection) -> List[SnowflakeDatabase]: + + databases: List[SnowflakeDatabase] = [] + + cur = self.query( + conn, + SnowflakeQuery.show_databases(), + ) + + for database in cur: + snowflake_db = SnowflakeDatabase( + name=database["name"], + created=database["created_on"], + comment=database["comment"], + ) + databases.append(snowflake_db) + + return databases + + def get_schemas_for_database( + self, conn: SnowflakeConnection, db_name: str + ) -> List[SnowflakeSchema]: + + snowflake_schemas = [] + + cur = self.query( + conn, + SnowflakeQuery.schemas_for_database(db_name), + ) + for schema in cur: + snowflake_schema = SnowflakeSchema( + name=schema["schema_name"], + created=schema["created"], + last_altered=schema["last_altered"], + comment=schema["comment"], + ) + snowflake_schemas.append(snowflake_schema) + return snowflake_schemas + + def get_tables_for_database( + self, conn: SnowflakeConnection, db_name: str + ) -> Optional[Dict[str, List[SnowflakeTable]]]: + tables: Dict[str, List[SnowflakeTable]] = {} + try: + cur = self.query( + conn, + SnowflakeQuery.tables_for_database(db_name), + ) + except Exception as e: + logger.debug(e) + # Error - Information schema query returned too much data. Please repeat query with more selective predicates. + return None + + for table in cur: + if table["table_schema"] not in tables: + tables[table["table_schema"]] = [] + tables[table["table_schema"]].append( + SnowflakeTable( + name=table["table_name"], + created=table["created"], + last_altered=table["last_altered"], + size_in_bytes=table["bytes"], + rows_count=table["row_count"], + comment=table["comment"], + clustering_key=table["clustering_key"], + ) + ) + return tables + + def get_tables_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeTable]: + tables: List[SnowflakeTable] = [] + + cur = self.query( + conn, + SnowflakeQuery.tables_for_schema(schema_name, db_name), + ) + + for table in cur: + tables.append( + SnowflakeTable( + name=table["table_name"], + created=table["created"], + last_altered=table["last_altered"], + size_in_bytes=table["bytes"], + rows_count=table["row_count"], + comment=table["comment"], + clustering_key=table["clustering_key"], + ) + ) + return tables + + def get_views_for_database( + self, conn: SnowflakeConnection, db_name: str + ) -> Optional[Dict[str, List[SnowflakeView]]]: + views: Dict[str, List[SnowflakeView]] = {} + try: + cur = self.query(conn, SnowflakeQuery.show_views_for_database(db_name)) + except Exception as e: + logger.debug(e) + # Error - Information schema query returned too much data. Please repeat query with more selective predicates. 
+ return None + + for table in cur: + if table["schema_name"] not in views: + views[table["schema_name"]] = [] + views[table["schema_name"]].append( + SnowflakeView( + name=table["name"], + created=table["created_on"], + # last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["text"], + ) + ) + return views + + def get_views_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeView]: + views: List[SnowflakeView] = [] + + cur = self.query( + conn, SnowflakeQuery.show_views_for_schema(schema_name, db_name) + ) + for table in cur: + views.append( + SnowflakeView( + name=table["name"], + created=table["created_on"], + # last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["text"], + ) + ) + return views + + def get_columns_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> Optional[Dict[str, List[SnowflakeColumn]]]: + columns: Dict[str, List[SnowflakeColumn]] = {} + try: + cur = self.query( + conn, SnowflakeQuery.columns_for_schema(schema_name, db_name) + ) + except Exception as e: + logger.debug(e) + # Error - Information schema query returned too much data. + # Please repeat query with more selective predicates. + return None + + for column in cur: + if column["table_name"] not in columns: + columns[column["table_name"]] = [] + columns[column["table_name"]].append( + SnowflakeColumn( + name=column["column_name"], + ordinal_position=column["ordinal_position"], + is_nullable=column["is_nullable"] == "YES", + data_type=column["data_type"], + comment=column["comment"], + ) + ) + return columns + + def get_columns_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeColumn]: + columns: List[SnowflakeColumn] = [] + + cur = self.query( + conn, + SnowflakeQuery.columns_for_table(table_name, schema_name, db_name), + ) + + for column in cur: + columns.append( + SnowflakeColumn( + name=column["column_name"], + ordinal_position=column["ordinal_position"], + is_nullable=column["is_nullable"] == "YES", + data_type=column["data_type"], + comment=column["comment"], + ) + ) + return columns + + def get_pk_constraints_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> Dict[str, SnowflakePK]: + constraints: Dict[str, SnowflakePK] = {} + cur = self.query( + conn, + SnowflakeQuery.show_primary_keys_for_schema(schema_name, db_name), + ) + + for row in cur: + if row["table_name"] not in constraints: + constraints[row["table_name"]] = SnowflakePK( + name=row["constraint_name"], column_names=[] + ) + constraints[row["table_name"]].column_names.append(row["column_name"]) + return constraints + + def get_fk_constraints_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> Dict[str, List[SnowflakeFK]]: + constraints: Dict[str, List[SnowflakeFK]] = {} + fk_constraints_map: Dict[str, SnowflakeFK] = {} + + cur = self.query( + conn, + SnowflakeQuery.show_foreign_keys_for_schema(schema_name, db_name), + ) + + for row in cur: + if row["fk_name"] not in constraints: + fk_constraints_map[row["fk_name"]] = SnowflakeFK( + name=row["fk_name"], + column_names=[], + referred_database=row["pk_database_name"], + referred_schema=row["pk_schema_name"], + referred_table=row["pk_table_name"], + referred_column_names=[], + ) + + if row["fk_table_name"] not in constraints: + constraints[row["fk_table_name"]] = [] + + 
fk_constraints_map[row["fk_name"]].column_names.append( + row["fk_column_name"] + ) + fk_constraints_map[row["fk_name"]].referred_column_names.append( + row["pk_column_name"] + ) + constraints[row["fk_table_name"]].append(fk_constraints_map[row["fk_name"]]) + + return constraints diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py new file mode 100644 index 0000000000..93f6318551 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -0,0 +1,410 @@ +import json +import logging +import time +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List, Optional + +import pydantic +from snowflake.connector import SnowflakeConnection + +from datahub.emitter.mce_builder import ( + make_dataset_urn_with_platform_instance, + make_user_urn, +) +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config +from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeQueryMixin, +) +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + DatasetFieldUsageCounts, + DatasetUsageStatistics, + DatasetUserUsageCounts, +) +from datahub.metadata.com.linkedin.pegasus2avro.timeseries import TimeWindowSize +from datahub.metadata.schema_classes import ( + ChangeTypeClass, + OperationClass, + OperationTypeClass, +) +from datahub.utilities.perf_timer import PerfTimer +from datahub.utilities.sql_formatter import format_sql_query, trim_query + +logger: logging.Logger = logging.getLogger(__name__) + +OPERATION_STATEMENT_TYPES = { + "INSERT": OperationTypeClass.INSERT, + "UPDATE": OperationTypeClass.UPDATE, + "DELETE": OperationTypeClass.DELETE, + "CREATE": OperationTypeClass.CREATE, + "CREATE_TABLE": OperationTypeClass.CREATE, + "CREATE_TABLE_AS_SELECT": OperationTypeClass.CREATE, +} + + +@pydantic.dataclasses.dataclass +class SnowflakeColumnReference: + columnId: int + columnName: str + + +class PermissiveModel(pydantic.BaseModel): + class Config: + extra = "allow" + + +class SnowflakeObjectAccessEntry(PermissiveModel): + columns: Optional[List[SnowflakeColumnReference]] + objectDomain: str + objectId: int + objectName: str + stageKind: Optional[str] + + +class SnowflakeJoinedAccessEvent(PermissiveModel): + query_start_time: datetime + query_text: str + query_type: str + rows_inserted: Optional[int] + rows_updated: Optional[int] + rows_deleted: Optional[int] + base_objects_accessed: List[SnowflakeObjectAccessEntry] + direct_objects_accessed: List[SnowflakeObjectAccessEntry] + objects_modified: List[SnowflakeObjectAccessEntry] + + user_name: str + first_name: Optional[str] + last_name: Optional[str] + display_name: Optional[str] + email: Optional[str] + role_name: str + + +class SnowflakeUsageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin): + def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None: + self.config: SnowflakeV2Config = config + self.report: SnowflakeV2Report = report + self.logger = logger + + def get_workunits(self) -> Iterable[MetadataWorkUnit]: + conn = self.config.get_connection() + + logger.info("Checking usage date 
ranges") + self._check_usage_date_ranges(conn) + if ( + self.report.min_access_history_time is None + or self.report.max_access_history_time is None + ): + return + + # NOTE: In earlier `snowflake-usage` connector, users with no email were not considered in usage counts as well as in operation + # Now, we report the usage as well as operation metadata even if user email is absent + + if self.config.include_usage_stats: + yield from self.get_usage_workunits(conn) + + if self.config.include_operational_stats: + # Generate the operation workunits. + access_events = self._get_snowflake_history(conn) + for event in access_events: + yield from self._get_operation_aspect_work_unit(event) + + def get_usage_workunits( + self, conn: SnowflakeConnection + ) -> Iterable[MetadataWorkUnit]: + + with PerfTimer() as timer: + logger.info("Getting aggregated usage statistics") + results = self.query( + conn, + SnowflakeQuery.usage_per_object_per_time_bucket_for_time_window( + start_time_millis=int(self.config.start_time.timestamp() * 1000), + end_time_millis=int(self.config.end_time.timestamp() * 1000), + time_bucket_size=self.config.bucket_duration, + use_base_objects=self.config.apply_view_usage_to_tables, + top_n_queries=self.config.top_n_queries, + include_top_n_queries=self.config.include_top_n_queries, + ), + ) + self.report.usage_aggregation_query_secs = timer.elapsed_seconds() + + for row in results: + assert row["object_name"] is not None, "Null objectName not allowed" + if not self._is_dataset_pattern_allowed( + row["object_name"], + row["object_domain"], + ): + continue + + stats = DatasetUsageStatistics( + timestampMillis=int(row["bucket_start_time"].timestamp() * 1000), + eventGranularity=TimeWindowSize( + unit=self.config.bucket_duration, multiple=1 + ), + totalSqlQueries=row["total_queries"], + uniqueUserCount=row["total_users"], + topSqlQueries=self._map_top_sql_queries( + json.loads(row["top_sql_queries"]) + ) + if self.config.include_top_n_queries + else None, + userCounts=self._map_user_counts(json.loads(row["user_counts"])), + fieldCounts=[ + DatasetFieldUsageCounts( + fieldPath=self.snowflake_identifier(field_count["column_name"]), + count=field_count["total_queries"], + ) + for field_count in json.loads(row["field_counts"]) + ], + ) + dataset_urn = make_dataset_urn_with_platform_instance( + "snowflake", + self.get_dataset_identifier_from_qualified_name(row["object_name"]), + self.config.platform_instance, + self.config.env, + ) + yield self.wrap_aspect_as_workunit( + "dataset", + dataset_urn, + "datasetUsageStatistics", + stats, + ) + + def _map_top_sql_queries(self, top_sql_queries: Dict) -> List[str]: + total_budget_for_query_list: int = 24000 + budget_per_query: int = int( + total_budget_for_query_list / self.config.top_n_queries + ) + return [ + trim_query(format_sql_query(query), budget_per_query) + if self.config.format_sql_queries + else query + for query in top_sql_queries + ] + + def _map_user_counts(self, user_counts: Dict) -> List[DatasetUserUsageCounts]: + filtered_user_counts = [] + for user_count in user_counts: + user_email = user_count.get( + "user_email", + "{0}@{1}".format( + user_count["user_name"], self.config.email_domain + ).lower() + if self.config.email_domain + else None, + ) + if user_email is None or not self.config.user_email_pattern.allowed( + user_email + ): + continue + + filtered_user_counts.append( + DatasetUserUsageCounts( + user=make_user_urn( + self.get_user_identifier(user_count["user_name"], user_email) + ), + count=user_count["total_queries"], + # 
NOTE: Generated emails may be incorrect, as email may be different than + # username@email_domain + userEmail=user_email, + ) + ) + return filtered_user_counts + + def _get_snowflake_history( + self, conn: SnowflakeConnection + ) -> Iterable[SnowflakeJoinedAccessEvent]: + + logger.info("Getting access history") + with PerfTimer() as timer: + query = self._make_operations_query() + results = self.query(conn, query) + self.report.access_history_query_secs = round(timer.elapsed_seconds(), 2) + + for row in results: + yield from self._process_snowflake_history_row(row) + + def _make_operations_query(self) -> str: + start_time = int(self.config.start_time.timestamp() * 1000) + end_time = int(self.config.end_time.timestamp() * 1000) + return SnowflakeQuery.operational_data_for_time_window(start_time, end_time) + + def _check_usage_date_ranges(self, conn: SnowflakeConnection) -> Any: + + query = """ + select + min(query_start_time) as "min_time", + max(query_start_time) as "max_time" + from snowflake.account_usage.access_history + """ + with PerfTimer() as timer: + try: + results = self.query(conn, query) + except Exception as e: + self.warn( + "check-usage-data", + f"Extracting the date range for usage data from Snowflake failed." + f"Please check your permissions. Continuing...\nError was {e}.", + ) + else: + for db_row in results: + if ( + len(db_row) < 2 + or db_row["min_time"] is None + or db_row["max_time"] is None + ): + self.warn( + "check-usage-data", + f"Missing data for access_history {db_row} - Check if using Enterprise edition of Snowflake", + ) + continue + self.report.min_access_history_time = db_row["min_time"].astimezone( + tz=timezone.utc + ) + self.report.max_access_history_time = db_row["max_time"].astimezone( + tz=timezone.utc + ) + self.report.access_history_range_query_secs = round( + timer.elapsed_seconds(), 2 + ) + + def _get_operation_aspect_work_unit( + self, event: SnowflakeJoinedAccessEvent + ) -> Iterable[MetadataWorkUnit]: + if event.query_start_time and event.query_type in OPERATION_STATEMENT_TYPES: + start_time = event.query_start_time + query_type = event.query_type + user_email = event.email + user_name = event.user_name + operation_type = OPERATION_STATEMENT_TYPES[query_type] + reported_time: int = int(time.time() * 1000) + last_updated_timestamp: int = int(start_time.timestamp() * 1000) + user_urn = make_user_urn(self.get_user_identifier(user_name, user_email)) + + # NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect + for obj in event.objects_modified: + + resource = obj.objectName + dataset_urn = make_dataset_urn_with_platform_instance( + "snowflake", + self.get_dataset_identifier_from_qualified_name(resource), + self.config.platform_instance, + self.config.env, + ) + operation_aspect = OperationClass( + timestampMillis=reported_time, + lastUpdatedTimestamp=last_updated_timestamp, + actor=user_urn, + operationType=operation_type, + ) + mcp = MetadataChangeProposalWrapper( + entityType="dataset", + aspectName="operation", + changeType=ChangeTypeClass.UPSERT, + entityUrn=dataset_urn, + aspect=operation_aspect, + ) + wu = MetadataWorkUnit( + id=f"{start_time.isoformat()}-operation-aspect-{resource}", + mcp=mcp, + ) + self.report.report_workunit(wu) + yield wu + + def _process_snowflake_history_row( + self, row: Any + ) -> Iterable[SnowflakeJoinedAccessEvent]: + self.report.rows_processed += 1 + # Make some minor type conversions. 
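+        # Rows are converted to plain dicts, the JSON-encoded object lists are parsed and filtered, and query_start_time is normalized to UTC below.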
+ if hasattr(row, "_asdict"): + # Compat with SQLAlchemy 1.3 and 1.4 + # See https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#rowproxy-is-no-longer-a-proxy-is-now-called-row-and-behaves-like-an-enhanced-named-tuple. + event_dict = row._asdict() + else: + event_dict = dict(row) + + # no use processing events that don't have a query text + if not event_dict["query_text"]: + self.report.rows_missing_query_text += 1 + return + + event_dict["base_objects_accessed"] = [ + obj + for obj in json.loads(event_dict["base_objects_accessed"]) + if self._is_object_valid(obj) + ] + if len(event_dict["base_objects_accessed"]) == 0: + self.report.rows_zero_base_objects_accessed += 1 + + event_dict["direct_objects_accessed"] = [ + obj + for obj in json.loads(event_dict["direct_objects_accessed"]) + if self._is_object_valid(obj) + ] + if len(event_dict["direct_objects_accessed"]) == 0: + self.report.rows_zero_direct_objects_accessed += 1 + + event_dict["objects_modified"] = [ + obj + for obj in json.loads(event_dict["objects_modified"]) + if self._is_object_valid(obj) + ] + if len(event_dict["objects_modified"]) == 0: + self.report.rows_zero_objects_modified += 1 + + event_dict["query_start_time"] = (event_dict["query_start_time"]).astimezone( + tz=timezone.utc + ) + + if ( + not event_dict["email"] + and self.config.email_domain + and event_dict["user_name"] + ): + # NOTE: Generated emails may be incorrect, as email may be different than + # username@email_domain + event_dict[ + "email" + ] = f'{event_dict["user_name"]}@{self.config.email_domain}'.lower() + + if not event_dict["email"]: + self.report.rows_missing_email += 1 + + try: # big hammer try block to ensure we don't fail on parsing events + event = SnowflakeJoinedAccessEvent(**event_dict) + yield event + except Exception as e: + self.report.rows_parsing_error += 1 + self.warn( + "usage", + f"Failed to parse usage line {event_dict}, {e}", + ) + + def _is_unsupported_object_accessed(self, obj: Dict[str, Any]) -> bool: + unsupported_keys = ["locations"] + + if obj.get("objectDomain") in ["Stage"]: + return True + + return any([obj.get(key) is not None for key in unsupported_keys]) + + def _is_object_valid(self, obj: Dict[str, Any]) -> bool: + if self._is_unsupported_object_accessed( + obj + ) or not self._is_dataset_pattern_allowed( + obj.get("objectName"), obj.get("objectDomain") + ): + return False + return True + + def warn(self, key: str, reason: str) -> None: + self.report.report_warning(key, reason) + self.logger.warning(f"{key} => {reason}") + + def error(self, key: str, reason: str) -> None: + self.report.report_failure(key, reason) + self.logger.error(f"{key} => {reason}") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py new file mode 100644 index 0000000000..a49b94d6f2 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -0,0 +1,149 @@ +import logging +from typing import Any, Optional, Protocol + +from snowflake.connector import SnowflakeConnection +from snowflake.connector.cursor import DictCursor + +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.metadata.com.linkedin.pegasus2avro.events.metadata import ChangeType 
+from datahub.metadata.schema_classes import _Aspect + + +class SnowflakeLoggingProtocol(Protocol): + @property + def logger(self) -> logging.Logger: + ... + + +class SnowflakeCommonProtocol(Protocol): + @property + def logger(self) -> logging.Logger: + ... + + @property + def config(self) -> SnowflakeV2Config: + ... + + @property + def report(self) -> SnowflakeV2Report: + ... + + def get_dataset_identifier( + self, table_name: str, schema_name: str, db_name: str + ) -> str: + ... + + def snowflake_identifier(self, identifier: str) -> str: + ... + + +class SnowflakeQueryMixin: + def query( + self: SnowflakeLoggingProtocol, conn: SnowflakeConnection, query: str + ) -> Any: + self.logger.debug("Query : {}".format(query)) + resp = conn.cursor(DictCursor).execute(query) + return resp + + +class SnowflakeCommonMixin: + def _is_dataset_pattern_allowed( + self: SnowflakeCommonProtocol, + dataset_name: Optional[str], + dataset_type: Optional[str], + ) -> bool: + if not dataset_type or not dataset_name: + return True + dataset_params = dataset_name.split(".") + if len(dataset_params) != 3: + self.report.report_warning( + "invalid-dataset-pattern", + f"Found {dataset_params} of type {dataset_type}", + ) + # NOTE: this case returned `True` earlier when extracting lineage + return False + + if not self.config.database_pattern.allowed( + dataset_params[0].strip('"') + ) or not self.config.schema_pattern.allowed(dataset_params[1].strip('"')): + return False + + if dataset_type.lower() in {"table"} and not self.config.table_pattern.allowed( + dataset_params[2].strip('"') + ): + return False + + if dataset_type.lower() in { + "view", + "materialized_view", + } and not self.config.view_pattern.allowed(dataset_params[2].strip('"')): + return False + + return True + + def snowflake_identifier(self: SnowflakeCommonProtocol, identifier: str) -> str: + # to be in in sync with older connector, convert name to lowercase + if self.config.convert_urns_to_lowercase: + return identifier.lower() + return identifier + + def get_dataset_identifier( + self: SnowflakeCommonProtocol, table_name: str, schema_name: str, db_name: str + ) -> str: + return self.snowflake_identifier(f"{db_name}.{schema_name}.{table_name}") + + # Qualified Object names from snowflake audit logs have quotes for for snowflake quoted identifiers, + # For example "test-database"."test-schema".test_table + # whereas we generate urns without quotes even for quoted identifiers for backward compatibility + # and also unavailability of utility function to identify whether current table/schema/database + # name should be quoted in above method get_dataset_identifier + def get_dataset_identifier_from_qualified_name( + self: SnowflakeCommonProtocol, qualified_name: str + ) -> str: + name_parts = qualified_name.split(".") + if len(name_parts) != 3: + self.report.report_warning( + "invalid-dataset-pattern", + f"Found non-parseable {name_parts} for {qualified_name}", + ) + return self.snowflake_identifier(qualified_name.replace('"', "")) + return self.get_dataset_identifier( + name_parts[2].strip('"'), name_parts[1].strip('"'), name_parts[0].strip('"') + ) + + # Note - decide how to construct user urns. + # Historically urns were created using part before @ from user's email. + # Users without email were skipped from both user entries as well as aggregates. + # However email is not mandatory field in snowflake user, user_name is always present. 
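+    # Illustration (values are hypothetical): user_email='jdoe@example.com' yields 'jdoe'; with no email, user_name='JDOE' yields snowflake_identifier('JDOE'), i.e. 'jdoe' when convert_urns_to_lowercase is set.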
+ def get_user_identifier( + self: SnowflakeCommonProtocol, user_name: str, user_email: Optional[str] + ) -> str: + if user_email is not None: + return user_email.split("@")[0] + return self.snowflake_identifier(user_name) + + def wrap_aspect_as_workunit( + self: SnowflakeCommonProtocol, + entityName: str, + entityUrn: str, + aspectName: str, + aspect: _Aspect, + ) -> MetadataWorkUnit: + id = f"{aspectName}-for-{entityUrn}" + if "timestampMillis" in aspect._inner_dict: + id = f"{aspectName}-{aspect.timestampMillis}-for-{entityUrn}" # type: ignore + wu = MetadataWorkUnit( + id=id, + mcp=MetadataChangeProposalWrapper( + entityType=entityName, + entityUrn=entityUrn, + aspectName=aspectName, + aspect=aspect, + changeType=ChangeType.UPSERT, + ), + ) + self.report.report_workunit(wu) + return wu diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py new file mode 100644 index 0000000000..28f17cd2ba --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -0,0 +1,1139 @@ +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Dict, Iterable, List, Optional, Tuple, Union, cast + +import pydantic +from snowflake.connector import SnowflakeConnection + +from datahub.configuration.time_window_config import get_time_bucket +from datahub.emitter.mce_builder import ( + make_data_platform_urn, + make_dataplatform_instance_urn, + make_dataset_urn, + make_dataset_urn_with_platform_instance, + make_domain_urn, + make_schema_field_urn, +) +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.emitter.mcp_builder import ( + DatabaseKey, + PlatformKey, + SchemaKey, + add_dataset_to_container, + add_domain_to_entity_wu, + gen_containers, +) +from datahub.ingestion.api.common import PipelineContext, WorkUnit +from datahub.ingestion.api.decorators import ( + SupportStatus, + capability, + config_class, + platform_name, + support_status, +) +from datahub.ingestion.api.ingestion_job_state_provider import JobId +from datahub.ingestion.api.source import ( + CapabilityReport, + Source, + SourceCapability, + SourceReport, + TestableSource, + TestConnectionReport, +) +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config +from datahub.ingestion.source.snowflake.snowflake_lineage import ( + SnowflakeLineageExtractor, +) +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_schema import ( + SnowflakeColumn, + SnowflakeDatabase, + SnowflakeDataDictionary, + SnowflakeFK, + SnowflakePK, + SnowflakeQuery, + SnowflakeSchema, + SnowflakeTable, + SnowflakeView, +) +from datahub.ingestion.source.snowflake.snowflake_usage_v2 import ( + SnowflakeUsageExtractor, +) +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeQueryMixin, +) +from datahub.ingestion.source.sql.sql_common import SqlContainerSubTypes +from datahub.ingestion.source.state.checkpoint import Checkpoint +from datahub.ingestion.source.state.sql_common_state import ( + BaseSQLAlchemyCheckpointState, +) +from datahub.ingestion.source.state.stateful_ingestion_base import ( + StatefulIngestionSourceBase, +) +from datahub.ingestion.source.state.usage_common_state import BaseUsageCheckpointState +from 
datahub.metadata.com.linkedin.pegasus2avro.common import Status, SubTypes +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + DatasetProfile, + DatasetProperties, + UpstreamLineage, + ViewProperties, +) +from datahub.metadata.com.linkedin.pegasus2avro.schema import ( + ArrayType, + BooleanType, + BytesType, + DateType, + ForeignKeyConstraint, + MySqlDDL, + NullType, + NumberType, + RecordType, + SchemaField, + SchemaFieldDataType, + SchemaMetadata, + StringType, + TimeType, +) +from datahub.metadata.schema_classes import ( + ChangeTypeClass, + DataPlatformInstanceClass, + JobStatusClass, + StatusClass, + TimeWindowSizeClass, +) +from datahub.utilities.registries.domain_registry import DomainRegistry + +logger: logging.Logger = logging.getLogger(__name__) + +# https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html +SNOWFLAKE_FIELD_TYPE_MAPPINGS = { + "DATE": DateType, + "BIGINT": NumberType, + "BINARY": BytesType, + # 'BIT': BIT, + "BOOLEAN": BooleanType, + "CHAR": NullType, + "CHARACTER": NullType, + "DATETIME": TimeType, + "DEC": NumberType, + "DECIMAL": NumberType, + "DOUBLE": NumberType, + "FIXED": NumberType, + "FLOAT": NumberType, + "INT": NumberType, + "INTEGER": NumberType, + "NUMBER": NumberType, + # 'OBJECT': ? + "REAL": NumberType, + "BYTEINT": NumberType, + "SMALLINT": NumberType, + "STRING": StringType, + "TEXT": StringType, + "TIME": TimeType, + "TIMESTAMP": TimeType, + "TIMESTAMP_TZ": TimeType, + "TIMESTAMP_LTZ": TimeType, + "TIMESTAMP_NTZ": TimeType, + "TINYINT": NumberType, + "VARBINARY": BytesType, + "VARCHAR": StringType, + "VARIANT": RecordType, + "OBJECT": NullType, + "ARRAY": ArrayType, + "GEOGRAPHY": NullType, +} + + +@platform_name("Snowflake") +@config_class(SnowflakeV2Config) +@support_status(SupportStatus.INCUBATING) +@capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") +@capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") +@capability(SourceCapability.CONTAINERS, "Enabled by default") +@capability(SourceCapability.SCHEMA_METADATA, "Enabled by default") +@capability( + SourceCapability.DATA_PROFILING, + "Optionally enabled via configuration, only table level profiling is supported", +) +@capability(SourceCapability.DESCRIPTIONS, "Enabled by default") +@capability( + SourceCapability.LINEAGE_COARSE, + "Enabled by default, can be disabled via configuration `include_table_lineage` and `include_view_lineage`", +) +@capability( + SourceCapability.USAGE_STATS, + "Enabled by default, can be disabled via configuration `include_usage_stats", +) +@capability(SourceCapability.DELETION_DETECTION, "Coming soon", supported=False) +class SnowflakeV2Source( + SnowflakeQueryMixin, + SnowflakeCommonMixin, + StatefulIngestionSourceBase, + TestableSource, +): + def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): + super().__init__(config, ctx) + self.config: SnowflakeV2Config = config + self.report: SnowflakeV2Report = SnowflakeV2Report() + self.platform: str = "snowflake" + self.logger = logger + + if self.config.domain: + self.domain_registry = DomainRegistry( + cached_domains=[k for k in self.config.domain], graph=self.ctx.graph + ) + + # For database, schema, tables, views, etc + self.data_dictionary = SnowflakeDataDictionary() + + # For lineage + self.lineage_extractor = SnowflakeLineageExtractor(config, self.report) + + # For usage stats + self.usage_extractor = SnowflakeUsageExtractor(config, self.report) + + # Currently caching using instance variables + # TODO - rewrite cache for 
readability or use out of the box solution + self.db_tables: Dict[str, Optional[Dict[str, List[SnowflakeTable]]]] = {} + self.db_views: Dict[str, Optional[Dict[str, List[SnowflakeView]]]] = {} + + # For column related queries and constraints, we currently query at schema level + # In future, we may consider using queries and caching at database level first + self.schema_columns: Dict[ + Tuple[str, str], Optional[Dict[str, List[SnowflakeColumn]]] + ] = {} + self.schema_pk_constraints: Dict[Tuple[str, str], Dict[str, SnowflakePK]] = {} + self.schema_fk_constraints: Dict[ + Tuple[str, str], Dict[str, List[SnowflakeFK]] + ] = {} + + @classmethod + def create(cls, config_dict: dict, ctx: PipelineContext) -> "Source": + config = SnowflakeV2Config.parse_obj(config_dict) + return cls(ctx, config) + + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + test_report = TestConnectionReport() + + try: + SnowflakeV2Config.Config.extra = ( + pydantic.Extra.allow + ) # we are okay with extra fields during this stage + connection_conf = SnowflakeV2Config.parse_obj(config_dict) + + connection: SnowflakeConnection = connection_conf.get_connection() + assert connection + + test_report.basic_connectivity = CapabilityReport(capable=True) + + test_report.capability_report = SnowflakeV2Source.check_capabilities( + connection, connection_conf + ) + + except Exception as e: + logger.error(f"Failed to test connection due to {e}", exc_info=e) + if test_report.basic_connectivity is None: + test_report.basic_connectivity = CapabilityReport( + capable=False, failure_reason=f"{e}" + ) + else: + test_report.internal_failure = True + test_report.internal_failure_reason = f"{e}" + finally: + SnowflakeV2Config.Config.extra = ( + pydantic.Extra.forbid + ) # set config flexibility back to strict + return test_report + + @staticmethod + def check_capabilities( + conn: SnowflakeConnection, connection_conf: SnowflakeV2Config + ) -> Dict[Union[SourceCapability, str], CapabilityReport]: + + # Currently only overall capabilities are reported. + # Resource level variations in capabilities are not considered. + + @dataclass + class SnowflakePrivilege: + privilege: str + object_name: str + object_type: str + + def query(query): + logger.info("Query : {}".format(query)) + resp = conn.cursor().execute(query) + return resp + + _report: Dict[Union[SourceCapability, str], CapabilityReport] = dict() + privileges: List[SnowflakePrivilege] = [] + capabilities: List[SourceCapability] = [c.capability for c in SnowflakeV2Source.get_capabilities() if c.capability not in (SourceCapability.PLATFORM_INSTANCE, SourceCapability.DOMAINS, SourceCapability.DELETION_DETECTION)] # type: ignore + + cur = query("select current_role()") + current_role = [row[0] for row in cur][0] + + cur = query("select current_secondary_roles()") + secondary_roles_str = json.loads([row[0] for row in cur][0])["roles"] + secondary_roles = ( + [] if secondary_roles_str == "" else secondary_roles_str.split(",") + ) + + roles = [current_role] + secondary_roles + + # PUBLIC role is automatically granted to every role + if "PUBLIC" not in roles: + roles.append("PUBLIC") + i = 0 + + while i < len(roles): + role = roles[i] + i = i + 1 + # for some roles, quoting is necessary. 
for example test-role + cur = query(f'show grants to role "{role}"') + for row in cur: + privilege = SnowflakePrivilege( + privilege=row[1], object_type=row[2], object_name=row[3] + ) + privileges.append(privilege) + + if privilege.object_type in ( + "DATABASE", + "SCHEMA", + ) and privilege.privilege in ("OWNERSHIP", "USAGE"): + _report[SourceCapability.CONTAINERS] = CapabilityReport( + capable=True + ) + elif privilege.object_type in ( + "TABLE", + "VIEW", + "MATERIALIZED_VIEW", + ): + _report[SourceCapability.SCHEMA_METADATA] = CapabilityReport( + capable=True + ) + _report[SourceCapability.DESCRIPTIONS] = CapabilityReport( + capable=True + ) + + # Table level profiling is supported without SELECT access + # if privilege.privilege in ("SELECT", "OWNERSHIP"): + _report[SourceCapability.DATA_PROFILING] = CapabilityReport( + capable=True + ) + + if privilege.object_name.startswith("SNOWFLAKE.ACCOUNT_USAGE."): + # if access to "snowflake" shared database, access to all account_usage views is automatically granted + # Finer access control is not yet supported for shares + # https://community.snowflake.com/s/article/Error-Granting-individual-privileges-on-imported-database-is-not-allowed-Use-GRANT-IMPORTED-PRIVILEGES-instead + _report[SourceCapability.LINEAGE_COARSE] = CapabilityReport( + capable=True + ) + _report[SourceCapability.USAGE_STATS] = CapabilityReport( + capable=True + ) + # If all capabilities supported, no need to continue + if set(capabilities) == set(_report.keys()): + break + + # Due to this, entire role hierarchy is considered + if ( + privilege.object_type == "ROLE" + and privilege.privilege == "USAGE" + and privilege.object_name not in roles + ): + roles.append(privilege.object_name) + + cur = query("select current_warehouse()") + current_warehouse = [row[0] for row in cur][0] + + default_failure_messages = { + SourceCapability.SCHEMA_METADATA: "Either no tables exist or current role does not have permissions to access them", + SourceCapability.DESCRIPTIONS: "Either no tables exist or current role does not have permissions to access them", + SourceCapability.DATA_PROFILING: "Either no tables exist or current role does not have permissions to access them", + SourceCapability.CONTAINERS: "Current role does not have permissions to use any database", + SourceCapability.LINEAGE_COARSE: "Current role does not have permissions to snowflake account usage views", + SourceCapability.USAGE_STATS: "Current role does not have permissions to snowflake account usage views", + } + + for c in capabilities: # type:ignore + + # These capabilities do not work without active warehouse + if current_warehouse is None and c in ( + SourceCapability.SCHEMA_METADATA, + SourceCapability.DESCRIPTIONS, + SourceCapability.DATA_PROFILING, + SourceCapability.LINEAGE_COARSE, + SourceCapability.USAGE_STATS, + ): + failure_message = ( + f"Current role does not have permissions to use warehouse {connection_conf.warehouse}" + if connection_conf.warehouse is not None + else "No default warehouse set for user. 
Either set default warehouse for user or configure warehouse in recipe" + ) + _report[c] = CapabilityReport( + capable=False, + failure_reason=failure_message, + ) + + if c in _report.keys(): + continue + + # If some capabilities are missing, then mark them as not capable + _report[c] = CapabilityReport( + capable=False, + failure_reason=default_failure_messages[c], + ) + + return _report + + def get_workunits(self) -> Iterable[WorkUnit]: + + # TODO: Support column level profiling + + conn: SnowflakeConnection = self.config.get_connection() + self.add_config_to_report() + self.inspect_session_metadata(conn) + + databases: List[SnowflakeDatabase] = self.data_dictionary.get_databases(conn) + for snowflake_db in databases: + if not self.config.database_pattern.allowed(snowflake_db.name): + self.report.report_dropped(snowflake_db.name) + continue + + yield from self._process_database(conn, snowflake_db) + + if self.is_stateful_ingestion_configured(): + # For database, schema, table, view + removed_entity_workunits = self.gen_removed_entity_workunits() + for wu in removed_entity_workunits: + self.report.report_workunit(wu) + yield wu + + if self.config.include_usage_stats or self.config.include_operational_stats: + self.should_skip_usage_run = self._should_skip_usage_run() + if self.should_skip_usage_run: + return + # creating checkpoint for usage ingestion + self.get_current_checkpoint(self.get_usage_ingestion_job_id()) + yield from self.usage_extractor.get_workunits() + + def _process_database( + self, conn: SnowflakeConnection, snowflake_db: SnowflakeDatabase + ) -> Iterable[MetadataWorkUnit]: + db_name = snowflake_db.name + + yield from self.gen_database_containers(snowflake_db) + + # Use database and extract metadata from its information_schema + # If this query fails, it means, user does not have usage access on database + try: + self.query(conn, SnowflakeQuery.use_database(db_name)) + snowflake_db.schemas = self.data_dictionary.get_schemas_for_database( + conn, db_name + ) + except Exception as e: + self.report.report_warning( + db_name, + f"unable to get metadata information for database {db_name} due to an error -> {e}", + ) + self.report.report_dropped(db_name) + return + + for snowflake_schema in snowflake_db.schemas: + + if not self.config.schema_pattern.allowed(snowflake_schema.name): + self.report.report_dropped(f"{snowflake_schema.name}.*") + continue + + yield from self._process_schema(conn, snowflake_schema, db_name) + + def _process_schema( + self, conn: SnowflakeConnection, snowflake_schema: SnowflakeSchema, db_name: str + ) -> Iterable[MetadataWorkUnit]: + schema_name = snowflake_schema.name + yield from self.gen_schema_containers(snowflake_schema, db_name) + + if self.config.include_tables: + snowflake_schema.tables = self.get_tables_for_schema( + conn, schema_name, db_name + ) + + for table in snowflake_schema.tables: + yield from self._process_table(conn, table, schema_name, db_name) + + if self.config.include_views: + snowflake_schema.views = self.get_views_for_schema( + conn, schema_name, db_name + ) + + for view in snowflake_schema.views: + yield from self._process_view(conn, view, schema_name, db_name) + + def _process_table( + self, + conn: SnowflakeConnection, + table: SnowflakeTable, + schema_name: str, + db_name: str, + ) -> Iterable[MetadataWorkUnit]: + table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name) + + self.report.report_entity_scanned(table_identifier) + + if not self.config.table_pattern.allowed(table_identifier): + 
self.report.report_dropped(table_identifier) + return + + table.columns = self.get_columns_for_table( + conn, table.name, schema_name, db_name + ) + table.pk = self.get_pk_constraints_for_table( + conn, table.name, schema_name, db_name + ) + table.foreign_keys = self.get_fk_constraints_for_table( + conn, table.name, schema_name, db_name + ) + dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) + + lineage_info = self.lineage_extractor._get_upstream_lineage_info(dataset_name) + + yield from self.gen_dataset_workunits(table, schema_name, db_name, lineage_info) + + def _process_view( + self, + conn: SnowflakeConnection, + view: SnowflakeView, + schema_name: str, + db_name: str, + ) -> Iterable[MetadataWorkUnit]: + view_name = self.get_dataset_identifier(view.name, schema_name, db_name) + + self.report.report_entity_scanned(view_name, "view") + + if not self.config.view_pattern.allowed(view_name): + self.report.report_dropped(view_name) + return + + view.columns = self.get_columns_for_table(conn, view.name, schema_name, db_name) + lineage_info = self.lineage_extractor._get_upstream_lineage_info(view_name) + yield from self.gen_dataset_workunits(view, schema_name, db_name, lineage_info) + + def gen_dataset_workunits( + self, + table: Union[SnowflakeTable, SnowflakeView], + schema_name: str, + db_name: str, + lineage_info: Optional[Tuple[UpstreamLineage, Dict[str, str]]], + ) -> Iterable[MetadataWorkUnit]: + dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) + dataset_urn = make_dataset_urn_with_platform_instance( + self.platform, + dataset_name, + self.config.platform_instance, + self.config.env, + ) + + if self.is_stateful_ingestion_configured(): + cur_checkpoint = self.get_current_checkpoint( + self.get_default_ingestion_job_id() + ) + if cur_checkpoint is not None: + checkpoint_state = cast( + BaseSQLAlchemyCheckpointState, cur_checkpoint.state + ) + if isinstance(table, SnowflakeTable): + checkpoint_state.add_table_urn(dataset_urn) + else: + checkpoint_state.add_view_urn(dataset_urn) + if lineage_info is not None: + upstream_lineage, upstream_column_props = lineage_info + else: + upstream_column_props = {} + upstream_lineage = None + + status = Status(removed=False) + yield self.wrap_aspect_as_workunit("dataset", dataset_urn, "status", status) + + schema_metadata = self.get_schema_metadata(table, dataset_name, dataset_urn) + yield self.wrap_aspect_as_workunit( + "dataset", dataset_urn, "schemaMetadata", schema_metadata + ) + + dataset_properties = DatasetProperties( + name=table.name, + description=table.comment, + qualifiedName=dataset_name, + customProperties={**upstream_column_props}, + ) + yield self.wrap_aspect_as_workunit( + "dataset", dataset_urn, "datasetProperties", dataset_properties + ) + + yield from self.add_table_to_schema_container( + dataset_urn, + self.snowflake_identifier(db_name), + self.snowflake_identifier(schema_name), + ) + dpi_aspect = self.get_dataplatform_instance_aspect(dataset_urn=dataset_urn) + if dpi_aspect: + yield dpi_aspect + + subTypes = SubTypes( + typeNames=["view"] if isinstance(table, SnowflakeView) else ["table"] + ) + yield self.wrap_aspect_as_workunit("dataset", dataset_urn, "subTypes", subTypes) + + yield from self._get_domain_wu( + dataset_name=dataset_name, + entity_urn=dataset_urn, + entity_type="dataset", + ) + + if upstream_lineage is not None: + # Emit the lineage work unit + yield self.wrap_aspect_as_workunit( + "dataset", dataset_urn, "upstreamLineage", upstream_lineage + ) + + if 
isinstance(table, SnowflakeTable) and self.config.profiling.enabled: + if self.config.profiling.allow_deny_patterns.allowed(dataset_name): + # Emit the profile work unit + dataset_profile = DatasetProfile( + timestampMillis=round(datetime.now().timestamp() * 1000), + columnCount=len(table.columns), + rowCount=table.rows_count, + ) + self.report.report_entity_profiled(dataset_name) + yield self.wrap_aspect_as_workunit( + "dataset", + dataset_urn, + "datasetProfile", + dataset_profile, + ) + + else: + self.report.report_dropped(f"Profile for {dataset_name}") + + if isinstance(table, SnowflakeView): + view = cast(SnowflakeView, table) + view_properties_aspect = ViewProperties( + materialized=False, + viewLanguage="SQL", + viewLogic=view.view_definition, + ) + yield self.wrap_aspect_as_workunit( + "dataset", + dataset_urn, + "viewProperties", + view_properties_aspect, + ) + + def get_schema_metadata( + self, + table: Union[SnowflakeTable, SnowflakeView], + dataset_name: str, + dataset_urn: str, + ) -> SchemaMetadata: + foreign_keys: Optional[List[ForeignKeyConstraint]] = None + if isinstance(table, SnowflakeTable) and len(table.foreign_keys) > 0: + foreign_keys = [] + for fk in table.foreign_keys: + foreign_dataset = make_dataset_urn( + self.platform, + self.get_dataset_identifier( + fk.referred_table, fk.referred_schema, fk.referred_database + ), + self.config.env, + ) + foreign_keys.append( + ForeignKeyConstraint( + name=fk.name, + foreignDataset=foreign_dataset, + foreignFields=[ + make_schema_field_urn( + foreign_dataset, + self.snowflake_identifier(col), + ) + for col in fk.referred_column_names + ], + sourceFields=[ + make_schema_field_urn( + dataset_urn, + self.snowflake_identifier(col), + ) + for col in fk.column_names + ], + ) + ) + + schema_metadata = SchemaMetadata( + schemaName=dataset_name, + platform=make_data_platform_urn(self.platform), + version=0, + hash="", + platformSchema=MySqlDDL(tableSchema=""), + fields=[ + SchemaField( + fieldPath=self.snowflake_identifier(col.name), + type=SchemaFieldDataType( + SNOWFLAKE_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)() + ), + # NOTE: nativeDataType will not be in sync with older connector + nativeDataType=col.data_type, + description=col.comment, + nullable=col.is_nullable, + isPartOfKey=col.name in table.pk.column_names + if isinstance(table, SnowflakeTable) and table.pk is not None + else None, + ) + for col in table.columns + ], + foreignKeys=foreign_keys, + ) + return schema_metadata + + def get_report(self) -> SourceReport: + return self.report + + def get_dataplatform_instance_aspect( + self, dataset_urn: str + ) -> Optional[MetadataWorkUnit]: + # If we are a platform instance based source, emit the instance aspect + if self.config.platform_instance: + mcp = MetadataChangeProposalWrapper( + entityType="dataset", + changeType=ChangeTypeClass.UPSERT, + entityUrn=dataset_urn, + aspectName="dataPlatformInstance", + aspect=DataPlatformInstanceClass( + platform=make_data_platform_urn(self.platform), + instance=make_dataplatform_instance_urn( + self.platform, self.config.platform_instance + ), + ), + ) + wu = MetadataWorkUnit(id=f"{dataset_urn}-dataPlatformInstance", mcp=mcp) + self.report.report_workunit(wu) + return wu + else: + return None + + def _get_domain_wu( + self, + dataset_name: str, + entity_urn: str, + entity_type: str, + ) -> Iterable[MetadataWorkUnit]: + + domain_urn = self._gen_domain_urn(dataset_name) + if domain_urn: + wus = add_domain_to_entity_wu( + entity_type=entity_type, + entity_urn=entity_urn, + 
domain_urn=domain_urn, + ) + for wu in wus: + self.report.report_workunit(wu) + yield wu + + def add_table_to_schema_container( + self, dataset_urn: str, db_name: str, schema: str + ) -> Iterable[MetadataWorkUnit]: + schema_container_key = self.gen_schema_key(db_name, schema) + container_workunits = add_dataset_to_container( + container_key=schema_container_key, + dataset_urn=dataset_urn, + ) + for wu in container_workunits: + self.report.report_workunit(wu) + yield wu + + def gen_schema_key(self, db_name: str, schema: str) -> PlatformKey: + return SchemaKey( + database=db_name, + schema=schema, + platform=self.platform, + instance=self.config.platform_instance + if self.config.platform_instance is not None + else self.config.env, + ) + + def gen_database_key(self, database: str) -> PlatformKey: + return DatabaseKey( + database=database, + platform=self.platform, + instance=self.config.platform_instance + if self.config.platform_instance is not None + else self.config.env, + ) + + def _gen_domain_urn(self, dataset_name: str) -> Optional[str]: + domain_urn: Optional[str] = None + + for domain, pattern in self.config.domain.items(): + if pattern.allowed(dataset_name): + domain_urn = make_domain_urn( + self.domain_registry.get_domain_urn(domain) + ) + + return domain_urn + + def gen_database_containers( + self, database: SnowflakeDatabase + ) -> Iterable[MetadataWorkUnit]: + + domain_urn = self._gen_domain_urn(database.name) + + database_container_key = self.gen_database_key( + self.snowflake_identifier(database.name) + ) + container_workunits = gen_containers( + container_key=database_container_key, + name=database.name, + description=database.comment, + sub_types=[SqlContainerSubTypes.DATABASE], + domain_urn=domain_urn, + ) + + for wu in container_workunits: + self.report.report_workunit(wu) + yield wu + + def gen_schema_containers( + self, schema: SnowflakeSchema, db_name: str + ) -> Iterable[MetadataWorkUnit]: + schema_container_key = self.gen_schema_key( + self.snowflake_identifier(db_name), + self.snowflake_identifier(schema.name), + ) + + database_container_key: Optional[PlatformKey] = None + if db_name is not None: + database_container_key = self.gen_database_key( + database=self.snowflake_identifier(db_name) + ) + + container_workunits = gen_containers( + container_key=schema_container_key, + name=schema.name, + description=schema.comment, + sub_types=[SqlContainerSubTypes.SCHEMA], + parent_container_key=database_container_key, + ) + + for wu in container_workunits: + self.report.report_workunit(wu) + yield wu + + def get_tables_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeTable]: + + if db_name not in self.db_tables.keys(): + tables = self.data_dictionary.get_tables_for_database(conn, db_name) + self.db_tables[db_name] = tables + else: + tables = self.db_tables[db_name] + + # get all tables for database failed, + # falling back to get tables for schema + if tables is None: + return self.data_dictionary.get_tables_for_schema( + conn, schema_name, db_name + ) + + # Some schema may not have any table + return tables.get(schema_name, []) + + def get_views_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeView]: + + if db_name not in self.db_views.keys(): + views = self.data_dictionary.get_views_for_database(conn, db_name) + self.db_views[db_name] = views + else: + views = self.db_views[db_name] + + # get all views for database failed, + # falling back to get views for schema + if views 
is None: + return self.data_dictionary.get_views_for_schema(conn, schema_name, db_name) + + # Some schema may not have any table + return views.get(schema_name, []) + + def get_columns_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeColumn]: + + if (db_name, schema_name) not in self.schema_columns.keys(): + columns = self.data_dictionary.get_columns_for_schema( + conn, schema_name, db_name + ) + self.schema_columns[(db_name, schema_name)] = columns + else: + columns = self.schema_columns[(db_name, schema_name)] + + # get all columns for schema failed, + # falling back to get columns for table + if columns is None: + return self.data_dictionary.get_columns_for_table( + conn, table_name, schema_name, db_name + ) + + # Access to table but none of its columns - is this possible ? + return columns.get(table_name, []) + + def get_pk_constraints_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> Optional[SnowflakePK]: + + if (db_name, schema_name) not in self.schema_pk_constraints.keys(): + constraints = self.data_dictionary.get_pk_constraints_for_schema( + conn, schema_name, db_name + ) + self.schema_pk_constraints[(db_name, schema_name)] = constraints + else: + constraints = self.schema_pk_constraints[(db_name, schema_name)] + + # Access to table but none of its constraints - is this possible ? + return constraints.get(table_name) + + def get_fk_constraints_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeFK]: + + if (db_name, schema_name) not in self.schema_fk_constraints.keys(): + constraints = self.data_dictionary.get_fk_constraints_for_schema( + conn, schema_name, db_name + ) + self.schema_fk_constraints[(db_name, schema_name)] = constraints + else: + constraints = self.schema_fk_constraints[(db_name, schema_name)] + + # Access to table but none of its constraints - is this possible ? 
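+        # Default to an empty list so tables without foreign keys are handled gracefully.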
+ return constraints.get(table_name, []) + + def add_config_to_report(self): + self.report.cleaned_account_id = self.config.get_account() + self.report.ignore_start_time_lineage = self.config.ignore_start_time_lineage + self.report.upstream_lineage_in_report = self.config.upstream_lineage_in_report + if not self.report.ignore_start_time_lineage: + self.report.lineage_start_time = self.config.start_time + self.report.lineage_end_time = self.config.end_time + self.report.check_role_grants = self.config.check_role_grants + self.report.include_usage_stats = self.config.include_usage_stats + self.report.include_operational_stats = self.config.include_operational_stats + if self.report.include_usage_stats or self.config.include_operational_stats: + self.report.start_time = self.config.start_time + self.report.end_time = self.config.end_time + + def inspect_session_metadata(self, conn: SnowflakeConnection) -> None: + try: + logger.info("Checking current version") + for db_row in self.query(conn, SnowflakeQuery.current_version()): + self.report.saas_version = db_row["CURRENT_VERSION()"] + except Exception as e: + self.report.report_failure("version", f"Error: {e}") + try: + logger.info("Checking current role") + for db_row in self.query(conn, SnowflakeQuery.current_role()): + self.report.role = db_row["CURRENT_ROLE()"] + except Exception as e: + self.report.report_failure("version", f"Error: {e}") + try: + logger.info("Checking current warehouse") + for db_row in self.query(conn, SnowflakeQuery.current_warehouse()): + self.report.default_warehouse = db_row["CURRENT_WAREHOUSE()"] + except Exception as e: + self.report.report_failure("current_warehouse", f"Error: {e}") + + def get_default_ingestion_job_id(self) -> JobId: + + # For backward compatibility, keeping job id same as sql common + return JobId("common_ingest_from_sql_source") + + def get_usage_ingestion_job_id(self) -> JobId: + """ + Default ingestion job name for snowflake_usage. + """ + return JobId("snowflake_usage_ingestion") + + # Stateful Ingestion Overrides. + def get_platform_instance_id(self) -> str: + return self.config.get_account() + + # Stateful Ingestion Overrides. + def create_checkpoint(self, job_id: JobId) -> Optional[Checkpoint]: + assert self.ctx.pipeline_name is not None + if job_id == self.get_default_ingestion_job_id(): + return Checkpoint( + job_name=job_id, + pipeline_name=self.ctx.pipeline_name, + platform_instance_id=self.get_platform_instance_id(), + run_id=self.ctx.run_id, + config=self.config, + state=BaseSQLAlchemyCheckpointState(), + ) + elif job_id == self.get_usage_ingestion_job_id(): + return Checkpoint( + job_name=job_id, + pipeline_name=self.ctx.pipeline_name, + platform_instance_id=self.get_platform_instance_id(), + run_id=self.ctx.run_id, + config=self.config, + state=BaseUsageCheckpointState( + begin_timestamp_millis=int( + self.config.start_time.timestamp() * 1000 + ), + end_timestamp_millis=int(self.config.end_time.timestamp() * 1000), + ), + ) + return None + + # Stateful Ingestion Overrides. 
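+    # Checkpointing applies to the default (metadata) job only when stale-metadata removal is configured, and to the usage job whenever stateful ingestion is enabled.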
+ def is_checkpointing_enabled(self, job_id: JobId) -> bool: + if job_id == self.get_default_ingestion_job_id(): + if ( + job_id == self.get_default_ingestion_job_id() + and self.is_stateful_ingestion_configured() + and self.config.stateful_ingestion + and self.config.stateful_ingestion.remove_stale_metadata + ): + return True + elif job_id == self.get_usage_ingestion_job_id(): + assert self.config.stateful_ingestion + return self.config.stateful_ingestion.enabled + return False + + def update_job_run_summary(self): + self.update_default_job_run_summary() + if self.config.include_usage_stats or self.config.include_operational_stats: + self.update_usage_job_run_summary() + + def update_default_job_run_summary(self) -> None: + summary = self.get_job_run_summary(self.get_default_ingestion_job_id()) + if summary is not None: + summary.config = self.config.json() + summary.custom_summary = self.report.as_string() + summary.runStatus = ( + JobStatusClass.FAILED + if self.get_report().failures + else JobStatusClass.COMPLETED + ) + summary.numWarnings = len(self.report.warnings) + summary.numErrors = len(self.report.failures) + summary.numEntities = self.report.workunits_produced + + def update_usage_job_run_summary(self): + summary = self.get_job_run_summary(self.get_usage_ingestion_job_id()) + if summary is not None: + summary.runStatus = ( + JobStatusClass.SKIPPED + if self.should_skip_usage_run + else JobStatusClass.COMPLETED + ) + summary.eventGranularity = TimeWindowSizeClass( + unit=self.config.bucket_duration, multiple=1 + ) + + def close(self): + self.update_job_run_summary() + self.prepare_for_commit() + + def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]: + last_checkpoint = self.get_last_checkpoint( + self.get_default_ingestion_job_id(), BaseSQLAlchemyCheckpointState + ) + cur_checkpoint = self.get_current_checkpoint( + self.get_default_ingestion_job_id() + ) + if ( + self.config.stateful_ingestion + and self.config.stateful_ingestion.remove_stale_metadata + and last_checkpoint is not None + and last_checkpoint.state is not None + and cur_checkpoint is not None + and cur_checkpoint.state is not None + ): + logger.debug("Checking for stale entity removal.") + + def soft_delete_item(urn: str, type: str) -> Iterable[MetadataWorkUnit]: + entity_type: str = "dataset" + + if type == "container": + entity_type = "container" + + logger.info(f"Soft-deleting stale entity of type {type} - {urn}.") + mcp = MetadataChangeProposalWrapper( + entityType=entity_type, + entityUrn=urn, + changeType=ChangeTypeClass.UPSERT, + aspectName="status", + aspect=StatusClass(removed=True), + ) + wu = MetadataWorkUnit(id=f"soft-delete-{type}-{urn}", mcp=mcp) + self.report.report_workunit(wu) + self.report.report_stale_entity_soft_deleted(urn) + yield wu + + last_checkpoint_state = cast( + BaseSQLAlchemyCheckpointState, last_checkpoint.state + ) + cur_checkpoint_state = cast( + BaseSQLAlchemyCheckpointState, cur_checkpoint.state + ) + + for table_urn in last_checkpoint_state.get_table_urns_not_in( + cur_checkpoint_state + ): + yield from soft_delete_item(table_urn, "table") + + for view_urn in last_checkpoint_state.get_view_urns_not_in( + cur_checkpoint_state + ): + yield from soft_delete_item(view_urn, "view") + + for container_urn in last_checkpoint_state.get_container_urns_not_in( + cur_checkpoint_state + ): + yield from soft_delete_item(container_urn, "container") + + def _should_skip_usage_run(self) -> bool: + # Check if forced rerun. 
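+        # ignore_old_state forces a fresh usage extraction even if the last checkpoint already covers this time window.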
+        if (
+            self.config.stateful_ingestion
+            and self.config.stateful_ingestion.ignore_old_state
+        ):
+            return False
+        # Determine from the last checkpoint state
+        last_successful_pipeline_run_end_time_millis: Optional[int] = None
+        last_checkpoint = self.get_last_checkpoint(
+            self.get_usage_ingestion_job_id(), BaseUsageCheckpointState
+        )
+        if last_checkpoint and last_checkpoint.state:
+            state = cast(BaseUsageCheckpointState, last_checkpoint.state)
+            last_successful_pipeline_run_end_time_millis = state.end_timestamp_millis
+
+        if last_successful_pipeline_run_end_time_millis is not None:
+            last_run_bucket_start = get_time_bucket(
+                datetime.fromtimestamp(
+                    last_successful_pipeline_run_end_time_millis / 1000, tz=timezone.utc
+                ),
+                self.config.bucket_duration,
+            )
+            if self.config.start_time < last_run_bucket_start:
+                warn_msg = (
+                    f"Skipping usage run, since the last run's bucket duration start: "
+                    f"{last_run_bucket_start}"
+                    f" is later than the current start_time: {self.config.start_time}"
+                )
+                logger.warning(warn_msg)
+                self.report.report_warning("skip-run", warn_msg)
+                return True
+        return False
diff --git a/metadata-ingestion/src/datahub/utilities/sql_formatter.py b/metadata-ingestion/src/datahub/utilities/sql_formatter.py
index a4542d1a10..96cb71749b 100644
--- a/metadata-ingestion/src/datahub/utilities/sql_formatter.py
+++ b/metadata-ingestion/src/datahub/utilities/sql_formatter.py
@@ -12,3 +12,18 @@ def format_sql_query(query: str, **options: Any) -> str:
     except Exception as e:
         logger.debug(f"Exception:{e} while formatting query '{query}'.")
         return query
+
+
+def trim_query(
+    query: str, budget_per_query: int, query_trimmer_string: str = " ..."
+) -> str:
+    trimmed_query = query
+    if len(query) > budget_per_query:
+        if budget_per_query - len(query_trimmer_string) > 0:
+            end_index = budget_per_query - len(query_trimmer_string)
+            trimmed_query = query[:end_index] + query_trimmer_string
+        else:
+            raise Exception(
+                "Budget per query is too low. Please decrease the number of top_n_queries."
+ ) + return trimmed_query diff --git a/metadata-ingestion/tests/integration/bigquery-usage/test_bigquery_usage.py b/metadata-ingestion/tests/integration/bigquery-usage/test_bigquery_usage.py index 4282545a2f..60c57f0195 100644 --- a/metadata-ingestion/tests/integration/bigquery-usage/test_bigquery_usage.py +++ b/metadata-ingestion/tests/integration/bigquery-usage/test_bigquery_usage.py @@ -34,7 +34,7 @@ def test_bq_usage_config(): ) assert config.get_allow_pattern_string() == "test-regex|test-regex-1" assert config.get_deny_pattern_string() == "" - assert (config.end_time - config.start_time) == timedelta(hours=2) + assert (config.end_time - config.start_time) == timedelta(hours=1) assert config.projects == ["sample-bigquery-project-name-1234"] diff --git a/metadata-ingestion/tests/unit/test_bigquery_usage_source.py b/metadata-ingestion/tests/unit/test_bigquery_usage_source.py index 94e0cc2475..dc28751d09 100644 --- a/metadata-ingestion/tests/unit/test_bigquery_usage_source.py +++ b/metadata-ingestion/tests/unit/test_bigquery_usage_source.py @@ -106,7 +106,7 @@ protoPayload.serviceData.jobCompletedEvent.job.jobStatistics.referencedTables.ta AND timestamp >= "2021-07-18T23:45:00Z" AND -timestamp < "2021-07-21T00:15:00Z\"""" # noqa: W293 +timestamp < "2021-07-20T00:15:00Z\"""" # noqa: W293 source = BigQueryUsageSource.create(config, PipelineContext(run_id="bq-usage-test")) @@ -164,7 +164,7 @@ protoPayload.serviceData.jobCompletedEvent.job.jobStatistics.referencedTables.ta AND timestamp >= "2021-07-18T23:45:00Z" AND -timestamp < "2021-07-21T00:15:00Z\"""" # noqa: W293 +timestamp < "2021-07-20T00:15:00Z\"""" # noqa: W293 source = BigQueryUsageSource.create(config, PipelineContext(run_id="bq-usage-test")) filter: str = source._generate_filter(BQ_AUDIT_V1) assert filter == expected_filter