feat(ingest): snowflake - add snowflake-beta connector (#5517)

Mayuri Nehate 2022-08-16 09:24:02 +05:30 committed by GitHub
parent 337087cac0
commit dc08bedd6e
17 changed files with 3073 additions and 13 deletions

View File

@@ -1 +1,4 @@
To get all metadata from Snowflake you need to use two plugins `snowflake` and `snowflake-usage`. Both of them are described in this page. These will require 2 separate recipes. We understand this is not ideal and we plan to make this easier in the future.
To get all metadata from Snowflake you need to use two plugins `snowflake` and `snowflake-usage`. Both of them are described in this page. These will require 2 separate recipes.
We encourage you to try out the new `snowflake-beta` plugin as an alternative to running both the `snowflake` and `snowflake-usage` plugins, and to share feedback. `snowflake-beta` is much faster than `snowflake` for extracting metadata. Please note that the `snowflake-beta` plugin currently does not support column-level profiling, unlike the `snowflake` plugin.

View File

@@ -0,0 +1,56 @@
### Prerequisites
In order to execute this source, your Snowflake user will need to have specific privileges granted to it for reading metadata
from your warehouse.
A Snowflake system admin can follow this guide to create a DataHub-specific role, assign it the required privileges, and assign it to a new DataHub user, by executing the following Snowflake commands as a user with the `ACCOUNTADMIN` role or the `MANAGE GRANTS` privilege.
```sql
create or replace role datahub_role;
// Grant access to a warehouse to run queries to view metadata
grant operate, usage on warehouse "<your-warehouse>" to role datahub_role;
// Grant access to view database and schema in which your tables/views exist
grant usage on DATABASE "<your-database>" to role datahub_role;
grant usage on all schemas in database "<your-database>" to role datahub_role;
grant usage on future schemas in database "<your-database>" to role datahub_role;
// If you are NOT using Snowflake Profiling feature: Grant references privileges to your tables and views
grant references on all tables in database "<your-database>" to role datahub_role;
grant references on future tables in database "<your-database>" to role datahub_role;
grant references on all external tables in database "<your-database>" to role datahub_role;
grant references on future external tables in database "<your-database>" to role datahub_role;
grant references on all views in database "<your-database>" to role datahub_role;
grant references on future views in database "<your-database>" to role datahub_role;
// If you ARE using Snowflake Profiling feature: Grant select privileges to your tables and views
grant select on all tables in database "<your-database>" to role datahub_role;
grant select on future tables in database "<your-database>" to role datahub_role;
grant select on all external tables in database "<your-database>" to role datahub_role;
grant select on future external tables in database "<your-database>" to role datahub_role;
grant select on all views in database "<your-database>" to role datahub_role;
grant select on future views in database "<your-database>" to role datahub_role;
// Create a new DataHub user and assign the DataHub role to it
create user datahub_user display_name = 'DataHub' password='' default_role = datahub_role default_warehouse = '<your-warehouse>';
// Grant the datahub_role to the new DataHub user.
grant role datahub_role to user datahub_user;
```
The details of each granted privilege can be viewed in the [Snowflake docs](https://docs.snowflake.com/en/user-guide/security-access-control-privileges.html). A summary of each privilege and why it is required for this connector:
- `operate` is required on the warehouse to execute queries
- `usage` is required for us to run queries using the warehouse
- `usage` on `database` and `schema` is required because, without it, the tables and views inside them are not accessible. If an admin grants the required privileges on a `table` but misses the grants on the `schema` or the `database` in which the table/view exists, we will not be able to get metadata for that table/view.
- If metadata is required only for some schemas, you can grant the usage privilege on a particular schema, for example:
```sql
grant usage on schema "<your-database>"."<your-schema>" to role datahub_role;
```
These are the bare minimum privileges required to extract databases, schemas, views, and tables from Snowflake.
If you plan to enable extraction of table lineage (via the `include_table_lineage` config flag) or extraction of usage statistics (via the `include_usage_stats` config), you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables from which the DataHub source extracts this information. This can be done by granting access to the `snowflake` database.
```sql
grant imported privileges on database snowflake to role datahub_role;
```

View File

@@ -0,0 +1,43 @@
source:
  type: snowflake-beta
  config:
    # Recommended for the first run, to ingest all available lineage
    ignore_start_time_lineage: true
    # Alternatively, specify a start_time for lineage
    # if you don't want to look back to the beginning
    start_time: '2022-03-01T00:00:00Z'
    # Coordinates
    account_id: "abc48144"
    warehouse: "COMPUTE_WH"
    # Credentials
    username: "${SNOWFLAKE_USER}"
    password: "${SNOWFLAKE_PASS}"
    role: "datahub_role"
    # Change these as per your database names. Remove to ingest all databases
    database_pattern:
      allow:
        - "^ACCOUNTING_DB$"
        - "^MARKETING_DB$"
    table_pattern:
      allow:
        # For example, to ingest only tables whose names end with revenue or sales
        - ".*revenue"
        - ".*sales"
    profiling:
      # Change to false to disable profiling
      enabled: true
      profile_table_level_only: true
      profile_pattern:
        allow:
          - 'ACCOUNTING_DB.*.*'
          - 'MARKETING_DB.*.*'
sink:
  # sink configs
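For reference, the same recipe can also be run programmatically. This is a hedged sketch, assuming DataHub's `Pipeline` API and a `datahub-rest` sink at a placeholder address; it is not part of this commit.

```python
# Hedged sketch: running the snowflake-beta recipe programmatically.
# The Pipeline API usage and the datahub-rest sink address are assumptions.
import os

from datahub.ingestion.run.pipeline import Pipeline  # assumed API, not part of this diff

pipeline = Pipeline.create(
    {
        "source": {
            "type": "snowflake-beta",
            "config": {
                "account_id": "abc48144",
                "warehouse": "COMPUTE_WH",
                "username": os.environ["SNOWFLAKE_USER"],
                "password": os.environ["SNOWFLAKE_PASS"],
                "role": "datahub_role",
            },
        },
        # Placeholder sink; replace with your own sink settings.
        "sink": {"type": "datahub-rest", "config": {"server": "http://localhost:8080"}},
    }
)
pipeline.run()
pipeline.raise_from_status()
```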

View File

@@ -290,6 +290,7 @@ plugins: Dict[str, Set[str]] = {
| {
"more-itertools>=8.12.0",
},
"snowflake-beta": snowflake_common | usage_common,
"sqlalchemy": sql_common,
"superset": {
"requests",
@@ -499,6 +500,7 @@ entry_points = {
"redshift-usage = datahub.ingestion.source.usage.redshift_usage:RedshiftUsageSource",
"snowflake = datahub.ingestion.source.sql.snowflake:SnowflakeSource",
"snowflake-usage = datahub.ingestion.source.usage.snowflake_usage:SnowflakeUsageSource",
"snowflake-beta = datahub.ingestion.source.snowflake.snowflake_v2:SnowflakeV2Source",
"superset = datahub.ingestion.source.superset:SupersetSource",
"tableau = datahub.ingestion.source.tableau:TableauSource",
"openapi = datahub.ingestion.source.openapi:OpenApiSource",
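The new `snowflake-beta = ...SnowflakeV2Source` entry point is what lets a recipe's `type: snowflake-beta` resolve to the new source class. A minimal sketch of that resolution follows; the entry-point group name is an assumption, since only the registration line itself appears in this diff.

```python
# Hedged sketch of how a recipe's source type resolves to the registered class.
from importlib.metadata import entry_points  # Python 3.10+ signature used below

def load_source_class(source_type: str):
    # The group name below is an assumption; only the "snowflake-beta = ..." line
    # itself appears in this diff.
    for ep in entry_points(group="datahub.ingestion.source.plugins"):
        if ep.name == source_type:
            return ep.load()
    raise KeyError(f"Unknown source type: {source_type}")

# load_source_class("snowflake-beta") would return SnowflakeV2Source once the
# package is installed with the snowflake-beta extra.
```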

View File

@@ -40,26 +40,22 @@ class BaseTimeWindowConfig(ConfigModel):
# `start_time` and `end_time` will be populated by the pre-validators.
# However, we must specify a "default" value here or pydantic will complain
# if those fields are not set by the user.
end_time: datetime = Field(default=None, description="Latest date of usage to consider. Default: Last full day in UTC (or hour, depending on `bucket_duration`)") # type: ignore
end_time: datetime = Field(default=None, description="Latest date of usage to consider. Default: Current time in UTC") # type: ignore
start_time: datetime = Field(default=None, description="Earliest date of usage to consider. Default: Last full day in UTC (or hour, depending on `bucket_duration`)") # type: ignore
@pydantic.validator("end_time", pre=True, always=True)
def default_end_time(
cls, v: Any, *, values: Dict[str, Any], **kwargs: Any
) -> datetime:
return v or get_time_bucket(
datetime.now(tz=timezone.utc)
+ get_bucket_duration_delta(values["bucket_duration"]),
values["bucket_duration"],
)
return v or datetime.now(tz=timezone.utc)
@pydantic.validator("start_time", pre=True, always=True)
def default_start_time(
cls, v: Any, *, values: Dict[str, Any], **kwargs: Any
) -> datetime:
return v or (
values["end_time"]
- get_bucket_duration_delta(values["bucket_duration"]) * 2
return v or get_time_bucket(
values["end_time"] - get_bucket_duration_delta(values["bucket_duration"]),
values["bucket_duration"],
)
@pydantic.validator("start_time", "end_time")

View File

@@ -0,0 +1,79 @@
import logging
from typing import Dict, Optional, cast
from pydantic import Field, root_validator
from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.source.sql.sql_common import SQLAlchemyStatefulIngestionConfig
from datahub.ingestion.source_config.sql.snowflake import (
SnowflakeConfig,
SnowflakeProvisionRoleConfig,
)
from datahub.ingestion.source_config.usage.snowflake_usage import (
SnowflakeStatefulIngestionConfig,
SnowflakeUsageConfig,
)
class SnowflakeV2StatefulIngestionConfig(
SQLAlchemyStatefulIngestionConfig, SnowflakeStatefulIngestionConfig
):
pass
logger = logging.getLogger(__name__)
class SnowflakeV2Config(SnowflakeConfig, SnowflakeUsageConfig):
convert_urns_to_lowercase: bool = Field(
default=True,
)
include_usage_stats: bool = Field(
default=True,
description="If enabled, populates the snowflake usage statistics. Requires appropriate grants given to the role.",
)
check_role_grants: bool = Field(
default=False,
description="Not supported",
)
provision_role: Optional[SnowflakeProvisionRoleConfig] = Field(
default=None, description="Not supported"
)
@root_validator(pre=False)
def validate_unsupported_configs(cls, values: Dict) -> Dict:
value = values.get("provision_role")
if value is not None and value.enabled:
raise ValueError(
"Provision role is currently not supported. Set `provision_role.enabled` to False."
)
value = values.get("profiling")
if value is not None and value.enabled and not value.profile_table_level_only:
raise ValueError(
"Only table level profiling is supported. Set `profiling.profile_table_level_only` to True.",
)
value = values.get("check_role_grants")
if value is not None and value:
raise ValueError(
"Check role grants is not supported. Set `check_role_grants` to False.",
)
value = values.get("include_read_operational_stats")
if value is not None and value:
raise ValueError(
"include_read_operational_stats is not supported. Set `include_read_operational_stats` to False.",
)
# Always exclude reporting metadata for INFORMATION_SCHEMA schema
schema_pattern = values.get("schema_pattern")
if schema_pattern is not None and schema_pattern:
logger.debug("Adding deny for INFORMATION_SCHEMA to schema_pattern.")
cast(AllowDenyPattern, schema_pattern).deny.append(r"^INFORMATION_SCHEMA$")
return values
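The root validator above rejects configurations that the beta source does not yet support. Below is a self-contained sketch of the same pattern, assuming pydantic v1 semantics (a `ValueError` raised inside a root validator surfaces to the caller as a `ValidationError`):

```python
# Standalone sketch of the unsupported-config validation pattern (pydantic v1 assumed).
import pydantic

class ExampleBetaConfig(pydantic.BaseModel):
    check_role_grants: bool = False

    @pydantic.root_validator(pre=False)
    def validate_unsupported_configs(cls, values):
        if values.get("check_role_grants"):
            raise ValueError(
                "Check role grants is not supported. Set `check_role_grants` to False."
            )
        return values

try:
    ExampleBetaConfig(check_role_grants=True)
except pydantic.ValidationError as e:
    print(e)  # reports the unsupported option instead of silently ignoring it
```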

View File

@@ -0,0 +1,325 @@
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple
from snowflake.connector import SnowflakeConnection
import datahub.emitter.mce_builder as builder
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeCommonMixin,
SnowflakeQueryMixin,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage
from datahub.metadata.schema_classes import DatasetLineageTypeClass, UpstreamClass
logger: logging.Logger = logging.getLogger(__name__)
class SnowflakeLineageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin):
def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
self._lineage_map: Optional[Dict[str, List[Tuple[str, str, str]]]] = None
self._external_lineage_map: Optional[Dict[str, Set[str]]] = None
self.config = config
self.platform = "snowflake"
self.report = report
self.logger = logger
# Rewrite implementation for readability, efficiency and extensibility
def _get_upstream_lineage_info(
self, dataset_name: str
) -> Optional[Tuple[UpstreamLineage, Dict[str, str]]]:
if not self.config.include_table_lineage:
return None
if self._lineage_map is None or self._external_lineage_map is None:
conn = self.config.get_connection()
if self._lineage_map is None:
self._populate_lineage(conn)
self._populate_view_lineage(conn)
if self._external_lineage_map is None:
self._populate_external_lineage(conn)
assert self._lineage_map is not None
assert self._external_lineage_map is not None
lineage = self._lineage_map[dataset_name]
external_lineage = self._external_lineage_map[dataset_name]
if not (lineage or external_lineage):
logger.debug(f"No lineage found for {dataset_name}")
return None
upstream_tables: List[UpstreamClass] = []
column_lineage: Dict[str, str] = {}
for lineage_entry in lineage:
# Update the table-lineage
upstream_table_name = lineage_entry[0]
if not self._is_dataset_pattern_allowed(upstream_table_name, "table"):
continue
upstream_table = UpstreamClass(
dataset=builder.make_dataset_urn_with_platform_instance(
self.platform,
upstream_table_name,
self.config.platform_instance,
self.config.env,
),
type=DatasetLineageTypeClass.TRANSFORMED,
)
upstream_tables.append(upstream_table)
# Update column-lineage for each down-stream column.
upstream_columns = [
self.snowflake_identifier(d["columnName"])
for d in json.loads(lineage_entry[1])
]
downstream_columns = [
self.snowflake_identifier(d["columnName"])
for d in json.loads(lineage_entry[2])
]
upstream_column_str = (
f"{upstream_table_name}({', '.join(sorted(upstream_columns))})"
)
downstream_column_str = (
f"{dataset_name}({', '.join(sorted(downstream_columns))})"
)
column_lineage_key = f"column_lineage[{upstream_table_name}]"
column_lineage_value = (
f"{{{upstream_column_str} -> {downstream_column_str}}}"
)
column_lineage[column_lineage_key] = column_lineage_value
logger.debug(f"{column_lineage_key}:{column_lineage_value}")
for external_lineage_entry in external_lineage:
# For now, populate only for S3
if external_lineage_entry.startswith("s3://"):
external_upstream_table = UpstreamClass(
dataset=make_s3_urn(external_lineage_entry, self.config.env),
type=DatasetLineageTypeClass.COPY,
)
upstream_tables.append(external_upstream_table)
if upstream_tables:
logger.debug(
f"Upstream lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}"
)
if self.config.upstream_lineage_in_report:
self.report.upstream_lineage[dataset_name] = [
u.dataset for u in upstream_tables
]
return UpstreamLineage(upstreams=upstream_tables), column_lineage
return None
def _populate_view_lineage(self, conn: SnowflakeConnection) -> None:
if not self.config.include_view_lineage:
return
self._populate_view_upstream_lineage(conn)
self._populate_view_downstream_lineage(conn)
def _populate_external_lineage(self, conn: SnowflakeConnection) -> None:
# Handles the case where a table is populated from an external location via copy.
# Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english' credentials=(aws_key_id='...' aws_secret_key='...') pattern='.*.csv';
query: str = SnowflakeQuery.external_table_lineage_history(
start_time_millis=int(self.config.start_time.timestamp() * 1000)
if not self.config.ignore_start_time_lineage
else 0,
end_time_millis=int(self.config.end_time.timestamp() * 1000),
)
num_edges: int = 0
self._external_lineage_map = defaultdict(set)
try:
for db_row in self.query(conn, query):
# key is the down-stream table name
key: str = self.get_dataset_identifier_from_qualified_name(
db_row["downstream_table_name"]
)
if not self._is_dataset_pattern_allowed(key, "table"):
continue
self._external_lineage_map[key] |= {
*json.loads(db_row["upstream_locations"])
}
logger.debug(
f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via access_history"
)
except Exception as e:
logger.warning(
f"Populating table external lineage from Snowflake failed."
f"Please check your premissions. Continuing...\nError was {e}."
)
# Handles the case for explicitly created external tables.
# NOTE: Snowflake does not log this information to the access_history table.
external_tables_query: str = SnowflakeQuery.show_external_tables()
try:
for db_row in self.query(conn, external_tables_query):
key = self.get_dataset_identifier(
db_row["name"], db_row["schema_name"], db_row["database_name"]
)
if not self._is_dataset_pattern_allowed(key, "table"):
continue
self._external_lineage_map[key].add(db_row["location"])
logger.debug(
f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via show external tables"
)
num_edges += 1
except Exception as e:
self.warn(
"external_lineage",
f"Populating external table lineage from Snowflake failed."
f"Please check your premissions. Continuing...\nError was {e}.",
)
logger.info(f"Found {num_edges} external lineage edges.")
self.report.num_external_table_edges_scanned = num_edges
def _populate_lineage(self, conn: SnowflakeConnection) -> None:
query: str = SnowflakeQuery.table_to_table_lineage_history(
start_time_millis=int(self.config.start_time.timestamp() * 1000)
if not self.config.ignore_start_time_lineage
else 0,
end_time_millis=int(self.config.end_time.timestamp() * 1000),
)
num_edges: int = 0
self._lineage_map = defaultdict(list)
try:
for db_row in self.query(conn, query):
# key is the down-stream table name
key: str = self.get_dataset_identifier_from_qualified_name(
db_row["downstream_table_name"]
)
upstream_table_name = self.get_dataset_identifier_from_qualified_name(
db_row["upstream_table_name"]
)
if not (
self._is_dataset_pattern_allowed(key, "table")
or self._is_dataset_pattern_allowed(upstream_table_name, "table")
):
continue
self._lineage_map[key].append(
# (<upstream_table_name>, <json_list_of_upstream_columns>, <json_list_of_downstream_columns>)
(
upstream_table_name,
db_row["upstream_table_columns"],
db_row["downstream_table_columns"],
)
)
num_edges += 1
logger.debug(
f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}"
)
except Exception as e:
self.warn(
"lineage",
f"Extracting lineage from Snowflake failed."
f"Please check your premissions. Continuing...\nError was {e}.",
)
logger.info(
f"A total of {num_edges} Table->Table edges found"
f" for {len(self._lineage_map)} downstream tables.",
)
self.report.num_table_to_table_edges_scanned = num_edges
def _populate_view_upstream_lineage(self, conn: SnowflakeConnection) -> None:
# NOTE: This query captures only the upstream lineage of a view (with no column lineage).
# For more details see: https://docs.snowflake.com/en/user-guide/object-dependencies.html#object-dependencies
# and also https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views.
view_upstream_lineage_query: str = SnowflakeQuery.view_dependencies()
assert self._lineage_map is not None
num_edges: int = 0
try:
for db_row in self.query(conn, view_upstream_lineage_query):
# Process UpstreamTable/View/ExternalTable/Materialized View->View edge.
view_upstream: str = self.get_dataset_identifier_from_qualified_name(
db_row["view_upstream"]
)
view_name: str = self.get_dataset_identifier_from_qualified_name(
db_row["downstream_view"]
)
if not self._is_dataset_pattern_allowed(
dataset_name=view_name,
dataset_type=db_row["referencing_object_domain"],
):
continue
# key is the downstream view name
self._lineage_map[view_name].append(
# (<upstream_table_name>, <empty_json_list_of_upstream_table_columns>, <empty_json_list_of_downstream_view_columns>)
(view_upstream, "[]", "[]")
)
num_edges += 1
logger.debug(
f"Upstream->View: Lineage[View(Down)={view_name}]:Upstream={view_upstream}"
)
except Exception as e:
self.warn(
"view_upstream_lineage",
"Extracting the upstream view lineage from Snowflake failed."
+ f"Please check your permissions. Continuing...\nError was {e}.",
)
logger.info(f"A total of {num_edges} View upstream edges found.")
self.report.num_table_to_view_edges_scanned = num_edges
def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None:
# This query captures the downstream table lineage for views.
# See https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views.
# Eg: For viewA->viewB->ViewC->TableD, snowflake does not yet log intermediate view logs, resulting in only the viewA->TableD edge.
view_lineage_query: str = SnowflakeQuery.view_lineage_history(
start_time_millis=int(self.config.start_time.timestamp() * 1000)
if not self.config.ignore_start_time_lineage
else 0,
end_time_millis=int(self.config.end_time.timestamp() * 1000),
)
assert self._lineage_map is not None
self.report.num_view_to_table_edges_scanned = 0
try:
db_rows = self.query(conn, view_lineage_query)
except Exception as e:
self.warn(
"view_downstream_lineage",
f"Extracting the view lineage from Snowflake failed."
f"Please check your permissions. Continuing...\nError was {e}.",
)
else:
for db_row in db_rows:
view_name: str = self.get_dataset_identifier_from_qualified_name(
db_row["view_name"]
)
if not self._is_dataset_pattern_allowed(
view_name, db_row["view_domain"]
):
continue
downstream_table: str = self.get_dataset_identifier_from_qualified_name(
db_row["downstream_table_name"]
)
# Capture view->downstream table lineage.
self._lineage_map[downstream_table].append(
# (<upstream_view_name>, <json_list_of_upstream_view_columns>, <json_list_of_downstream_columns>)
(
view_name,
db_row["view_columns"],
db_row["downstream_table_columns"],
)
)
self.report.num_view_to_table_edges_scanned += 1
logger.debug(
f"View->Table: Lineage[Table(Down)={downstream_table}]:View(Up)={self._lineage_map[downstream_table]}"
)
logger.info(
f"Found {self.report.num_view_to_table_edges_scanned} View->Table edges."
)
def warn(self, key: str, reason: str) -> None:
self.report.report_warning(key, reason)
self.logger.warning(f"{key} => {reason}")
def error(self, key: str, reason: str) -> None:
self.report.report_failure(key, reason)
self.logger.error(f"{key} => {reason}")
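The `_lineage_map` entries built above are `(upstream_table_name, upstream_columns_json, downstream_columns_json)` tuples. A small illustration of how one entry is rendered into the `column_lineage[...]` properties, using made-up table and column names:

```python
# Illustration only: made-up tables/columns showing the column_lineage rendering above.
import json

dataset_name = "accounting_db.public.orders_summary"  # hypothetical downstream table
lineage_entry = (
    "accounting_db.public.orders",                     # hypothetical upstream table
    json.dumps([{"columnId": 1, "columnName": "ORDER_ID"},
                {"columnId": 2, "columnName": "AMOUNT"}]),
    json.dumps([{"columnId": 7, "columnName": "TOTAL_AMOUNT"}]),
)

upstream_table_name, upstream_cols_json, downstream_cols_json = lineage_entry
upstream_columns = sorted(d["columnName"].lower() for d in json.loads(upstream_cols_json))
downstream_columns = sorted(d["columnName"].lower() for d in json.loads(downstream_cols_json))

key = f"column_lineage[{upstream_table_name}]"
value = f"{{{upstream_table_name}({', '.join(upstream_columns)}) -> {dataset_name}({', '.join(downstream_columns)})}}"
print(key, "=", value)
```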

View File

@@ -0,0 +1,509 @@
from typing import Optional
class SnowflakeQuery:
@staticmethod
def current_version() -> str:
return "select CURRENT_VERSION()"
@staticmethod
def current_role() -> str:
return "select CURRENT_ROLE()"
@staticmethod
def current_warehouse() -> str:
return "select CURRENT_WAREHOUSE()"
@staticmethod
def current_database() -> str:
return "select CURRENT_DATABASE()"
@staticmethod
def current_schema() -> str:
return "select CURRENT_SCHEMA()"
@staticmethod
def show_databases() -> str:
return "show databases"
@staticmethod
def use_database(db_name: str) -> str:
return f'use database "{db_name}"'
@staticmethod
def schemas_for_database(db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
SELECT schema_name AS "schema_name",
created AS "created",
last_altered AS "last_altered",
comment AS "comment"
from {db_clause}information_schema.schemata
WHERE schema_name != 'INFORMATION_SCHEMA'
order by schema_name"""
@staticmethod
def tables_for_database(db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
SELECT table_catalog AS "table_catalog",
table_schema AS "table_schema",
table_name AS "table_name",
table_type AS "table_type",
created AS "created",
last_altered AS "last_altered" ,
comment AS "comment",
row_count AS "row_count",
bytes AS "bytes",
clustering_key AS "clustering_key",
auto_clustering_on AS "auto_clustering_on"
FROM {db_clause}information_schema.tables t
WHERE table_schema != 'INFORMATION_SCHEMA'
and table_type in ( 'BASE TABLE', 'EXTERNAL TABLE')
order by table_schema, table_name"""
@staticmethod
def tables_for_schema(schema_name: str, db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
SELECT table_catalog AS "table_catalog",
table_schema AS "table_schema",
table_name AS "table_name",
table_type AS "table_type",
created AS "created",
last_altered AS "last_altered" ,
comment AS "comment",
row_count AS "row_count",
bytes AS "bytes",
clustering_key AS "clustering_key",
auto_clustering_on AS "auto_clustering_on"
FROM {db_clause}information_schema.tables t
where schema_name='{schema_name}'
and table_type in ('BASE TABLE', 'EXTERNAL TABLE')
order by table_schema, table_name"""
# The view definition is retrieved by the information_schema query only if the role is the owner of the view. Hence this query is not used.
# https://community.snowflake.com/s/article/Is-it-possible-to-see-the-view-definition-in-information-schema-views-from-a-non-owner-role
@staticmethod
def views_for_database(db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
SELECT table_catalog AS "table_catalog",
table_schema AS "table_schema",
table_name AS "table_name",
created AS "created",
last_altered AS "last_altered",
comment AS "comment",
view_definition AS "view_definition"
FROM {db_clause}information_schema.views t
WHERE table_schema != 'INFORMATION_SCHEMA'
order by table_schema, table_name"""
# The view definition is retrieved by the information_schema query only if the role is the owner of the view. Hence this query is not used.
# https://community.snowflake.com/s/article/Is-it-possible-to-see-the-view-definition-in-information-schema-views-from-a-non-owner-role
@staticmethod
def views_for_schema(schema_name: str, db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
SELECT table_catalog AS "table_catalog",
table_schema AS "table_schema",
table_name AS "table_name",
created AS "created",
last_altered AS "last_altered",
comment AS "comment",
view_definition AS "view_definition"
FROM {db_clause}information_schema.views t
where schema_name='{schema_name}'
order by table_schema, table_name"""
@staticmethod
def show_views_for_database(db_name: str) -> str:
return f"""show views in database "{db_name}";"""
@staticmethod
def show_views_for_schema(schema_name: str, db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""show views in schema {db_clause}"{schema_name}";"""
@staticmethod
def columns_for_schema(schema_name: str, db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
select
table_catalog AS "table_catalog",
table_schema AS "table_schema",
table_name AS "table_name",
column_name AS "column_name",
ordinal_position AS "ordinal_position",
is_nullable AS "is_nullable",
data_type AS "data_type",
comment AS "comment",
character_maximum_length AS "character_maximum_length",
numeric_precision AS "numeric_precision",
numeric_scale AS "numeric_scale",
column_default AS "column_default",
is_identity AS "is_identity"
from {db_clause}information_schema.columns
WHERE table_schema='{schema_name}'
ORDER BY ordinal_position"""
@staticmethod
def columns_for_table(
table_name: str, schema_name: str, db_name: Optional[str]
) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
select
table_catalog AS "table_catalog",
table_schema AS "table_schema",
table_name AS "table_name",
column_name AS "column_name",
ordinal_position AS "ordinal_position",
is_nullable AS "is_nullable",
data_type AS "data_type",
comment AS "comment",
character_maximum_length AS "character_maximum_length",
numeric_precision AS "numeric_precision",
numeric_scale AS "numeric_scale",
column_default AS "column_default",
is_identity AS "is_identity"
from {db_clause}information_schema.columns
WHERE table_schema='{schema_name}' and table_name='{table_name}'
ORDER BY ordinal_position"""
@staticmethod
def show_primary_keys_for_schema(schema_name: str, db_name: str) -> str:
return f"""
show primary keys in schema "{db_name}"."{schema_name}" """
@staticmethod
def show_foreign_keys_for_schema(schema_name: str, db_name: str) -> str:
return f"""
show imported keys in schema "{db_name}"."{schema_name}" """
@staticmethod
def operational_data_for_time_window(
start_time_millis: int, end_time_millis: int
) -> str:
return f"""
SELECT
-- access_history.query_id, -- only for debugging purposes
access_history.query_start_time AS "query_start_time",
query_history.query_text AS "query_text",
query_history.query_type AS "query_type",
query_history.rows_inserted AS "rows_inserted",
query_history.rows_updated AS "rows_updated",
query_history.rows_deleted AS "rows_deleted",
access_history.base_objects_accessed AS "base_objects_accessed",
access_history.direct_objects_accessed AS "direct_objects_accessed", -- when dealing with views, direct objects will show the view while base will show the underlying table
access_history.objects_modified AS "objects_modified",
-- query_history.execution_status, -- not really necessary, but should equal "SUCCESS"
-- query_history.warehouse_name,
access_history.user_name AS "user_name",
users.first_name AS "first_name",
users.last_name AS "last_name",
users.display_name AS "display_name",
users.email AS "email",
query_history.role_name AS "role_name"
FROM
snowflake.account_usage.access_history access_history
LEFT JOIN
snowflake.account_usage.query_history query_history
ON access_history.query_id = query_history.query_id
LEFT JOIN
snowflake.account_usage.users users
ON access_history.user_name = users.name
WHERE query_start_time >= to_timestamp_ltz({start_time_millis}, 3)
AND query_start_time < to_timestamp_ltz({end_time_millis}, 3)
AND query_history.query_type in ('INSERT', 'UPDATE', 'DELETE', 'CREATE', 'CREATE_TABLE', 'CREATE_TABLE_AS_SELECT')
ORDER BY query_start_time DESC
;"""
@staticmethod
def table_to_table_lineage_history(
start_time_millis: int, end_time_millis: int
) -> str:
return f"""
WITH table_lineage_history AS (
SELECT
r.value:"objectName"::varchar AS upstream_table_name,
r.value:"objectDomain"::varchar AS upstream_table_domain,
r.value:"columns" AS upstream_table_columns,
w.value:"objectName"::varchar AS downstream_table_name,
w.value:"objectDomain"::varchar AS downstream_table_domain,
w.value:"columns" AS downstream_table_columns,
t.query_start_time AS query_start_time
FROM
(SELECT * from snowflake.account_usage.access_history) t,
lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) r,
lateral flatten(input => t.OBJECTS_MODIFIED) w
WHERE r.value:"objectId" IS NOT NULL
AND w.value:"objectId" IS NOT NULL
AND w.value:"objectName" NOT LIKE '%.GE_TMP_%'
AND w.value:"objectName" NOT LIKE '%.GE_TEMP_%'
AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3)
AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3))
SELECT
upstream_table_name AS "upstream_table_name",
downstream_table_name AS "downstream_table_name",
upstream_table_columns AS "upstream_table_columns",
downstream_table_columns AS "downstream_table_columns"
FROM table_lineage_history
WHERE upstream_table_domain in ('Table', 'External table') and downstream_table_domain = 'Table'
QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_name ORDER BY query_start_time DESC) = 1"""
@staticmethod
def view_dependencies() -> str:
return """
SELECT
concat(
referenced_database, '.', referenced_schema,
'.', referenced_object_name
) AS "view_upstream",
concat(
referencing_database, '.', referencing_schema,
'.', referencing_object_name
) AS "downstream_view",
referencing_object_domain AS "referencing_object_domain"
FROM
snowflake.account_usage.object_dependencies
WHERE
referencing_object_domain in ('VIEW', 'MATERIALIZED VIEW')
"""
@staticmethod
def view_lineage_history(start_time_millis: int, end_time_millis: int) -> str:
return f"""
WITH view_lineage_history AS (
SELECT
vu.value : "objectName"::varchar AS view_name,
vu.value : "objectDomain"::varchar AS view_domain,
vu.value : "columns" AS view_columns,
w.value : "objectName"::varchar AS downstream_table_name,
w.value : "objectDomain"::varchar AS downstream_table_domain,
w.value : "columns" AS downstream_table_columns,
t.query_start_time AS query_start_time
FROM
(
SELECT
*
FROM
snowflake.account_usage.access_history
) t,
lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) vu,
lateral flatten(input => t.OBJECTS_MODIFIED) w
WHERE
vu.value : "objectId" IS NOT NULL
AND w.value : "objectId" IS NOT NULL
AND w.value : "objectName" NOT LIKE '%.GE_TMP_%'
AND w.value : "objectName" NOT LIKE '%.GE_TEMP_%'
AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3)
AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)
)
SELECT
view_name AS "view_name",
view_domain AS "view_domain",
view_columns AS "view_columns",
downstream_table_name AS "downstream_table_name",
downstream_table_columns AS "downstream_table_columns"
FROM
view_lineage_history
WHERE
view_domain in ('View', 'Materialized view')
QUALIFY ROW_NUMBER() OVER (
PARTITION BY view_name,
downstream_table_name
ORDER BY
query_start_time DESC
) = 1
"""
@staticmethod
def show_external_tables() -> str:
return "show external tables in account"
@staticmethod
def external_table_lineage_history(
start_time_millis: int, end_time_millis: int
) -> str:
return f"""
WITH external_table_lineage_history AS (
SELECT
r.value:"locations" AS upstream_locations,
w.value:"objectName"::varchar AS downstream_table_name,
w.value:"objectDomain"::varchar AS downstream_table_domain,
w.value:"columns" AS downstream_table_columns,
t.query_start_time AS query_start_time
FROM
(SELECT * from snowflake.account_usage.access_history) t,
lateral flatten(input => t.BASE_OBJECTS_ACCESSED) r,
lateral flatten(input => t.OBJECTS_MODIFIED) w
WHERE r.value:"locations" IS NOT NULL
AND w.value:"objectId" IS NOT NULL
AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3)
AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3))
SELECT
upstream_locations AS "upstream_locations",
downstream_table_name AS "downstream_table_name",
downstream_table_columns AS "downstream_table_columns"
FROM external_table_lineage_history
WHERE downstream_table_domain = 'Table'
QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name ORDER BY query_start_time DESC) = 1"""
@staticmethod
def usage_per_object_per_time_bucket_for_time_window(
start_time_millis: int,
end_time_millis: int,
time_bucket_size: str,
use_base_objects: bool,
top_n_queries: int,
include_top_n_queries: bool,
) -> str:
# TODO: Do not query for top n queries if include_top_n_queries = False
# How can we make this pretty
assert time_bucket_size == "DAY" or time_bucket_size == "HOUR"
objects_column = (
"BASE_OBJECTS_ACCESSED" if use_base_objects else "DIRECT_OBJECTS_ACCESSED"
)
return f"""
WITH object_access_history AS
(
select
object.value : "objectName"::varchar AS object_name,
object.value : "objectDomain"::varchar AS object_domain,
object.value : "columns" AS object_columns,
query_start_time,
query_id,
user_name
from
(
select
query_id,
query_start_time,
user_name,
{objects_column}
from
snowflake.account_usage.access_history
WHERE
query_start_time >= to_timestamp_ltz({start_time_millis}, 3)
AND query_start_time < to_timestamp_ltz({end_time_millis}, 3)
)
t,
lateral flatten(input => t.{objects_column}) object
)
,
field_access_history AS
(
select
o.*,
col.value : "columnName"::varchar AS column_name
from
object_access_history o,
lateral flatten(input => o.object_columns) col
)
,
basic_usage_counts AS
(
SELECT
object_name,
ANY_VALUE(object_domain) AS object_domain,
DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time,
count(distinct(query_id)) AS total_queries,
count( distinct(user_name) ) AS total_users
FROM
object_access_history
GROUP BY
bucket_start_time,
object_name
)
,
field_usage_counts AS
(
SELECT
object_name,
column_name,
DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time,
count(distinct(query_id)) AS total_queries
FROM
field_access_history
GROUP BY
bucket_start_time,
object_name,
column_name
)
,
user_usage_counts AS
(
SELECT
object_name,
DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time,
count(distinct(query_id)) AS total_queries,
user_name,
ANY_VALUE(users.email) AS user_email
FROM
object_access_history
LEFT JOIN
snowflake.account_usage.users users
ON user_name = users.name
GROUP BY
bucket_start_time,
object_name,
user_name
)
,
top_queries AS
(
SELECT
object_name,
DATE_TRUNC('{time_bucket_size}', CONVERT_TIMEZONE('UTC', query_start_time)) AS bucket_start_time,
query_history.query_text AS query_text,
count(distinct(access_history.query_id)) AS total_queries
FROM
object_access_history access_history
LEFT JOIN
snowflake.account_usage.query_history query_history
ON access_history.query_id = query_history.query_id
GROUP BY
bucket_start_time,
object_name,
query_text QUALIFY row_number() over ( partition by bucket_start_time, object_name, query_text
order by
total_queries desc ) <= {top_n_queries}
)
select
basic_usage_counts.object_name AS "object_name",
basic_usage_counts.bucket_start_time AS "bucket_start_time",
ANY_VALUE(basic_usage_counts.object_domain) AS "object_domain",
ANY_VALUE(basic_usage_counts.total_queries) AS "total_queries",
ANY_VALUE(basic_usage_counts.total_users) AS "total_users",
ARRAY_AGG( distinct top_queries.query_text) AS "top_sql_queries",
ARRAY_AGG( distinct OBJECT_CONSTRUCT( 'column_name', field_usage_counts.column_name, 'total_queries', field_usage_counts.total_queries ) ) AS "field_counts",
ARRAY_AGG( distinct OBJECT_CONSTRUCT( 'user_name', user_usage_counts.user_name, 'user_email', user_usage_counts.user_email, 'total_queries', user_usage_counts.total_queries ) ) AS "user_counts"
from
basic_usage_counts basic_usage_counts
left join
top_queries top_queries
on basic_usage_counts.bucket_start_time = top_queries.bucket_start_time
and basic_usage_counts.object_name = top_queries.object_name
left join
field_usage_counts field_usage_counts
on basic_usage_counts.bucket_start_time = field_usage_counts.bucket_start_time
and basic_usage_counts.object_name = field_usage_counts.object_name
left join
user_usage_counts user_usage_counts
on basic_usage_counts.bucket_start_time = user_usage_counts.bucket_start_time
and basic_usage_counts.object_name = user_usage_counts.object_name
where
basic_usage_counts.object_domain in
(
'Table',
'View',
'Materialized view',
'External table'
)
group by
basic_usage_counts.object_name,
basic_usage_counts.bucket_start_time
order by
basic_usage_counts.bucket_start_time
"""

View File

@@ -0,0 +1,11 @@
from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport
from datahub.ingestion.source_report.usage.snowflake_usage import SnowflakeUsageReport
class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport):
include_usage_stats: bool = False
include_operational_stats: bool = False
usage_aggregation_query_secs: float = -1
rows_zero_objects_modified: int = 0

View File

@@ -0,0 +1,323 @@
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional
from snowflake.connector import SnowflakeConnection
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class SnowflakePK:
name: str
column_names: List[str]
@dataclass
class SnowflakeFK:
name: str
column_names: List[str]
referred_database: str
referred_schema: str
referred_table: str
referred_column_names: List[str]
@dataclass
class SnowflakeColumn:
name: str
ordinal_position: int
is_nullable: bool
data_type: str
comment: str
@dataclass
class SnowflakeTable:
name: str
created: datetime
last_altered: datetime
size_in_bytes: int
rows_count: int
comment: str
clustering_key: str
pk: Optional[SnowflakePK] = None
columns: List[SnowflakeColumn] = field(default_factory=list)
foreign_keys: List[SnowflakeFK] = field(default_factory=list)
@dataclass
class SnowflakeView:
name: str
created: datetime
comment: Optional[str]
view_definition: str
last_altered: Optional[datetime] = None
columns: List[SnowflakeColumn] = field(default_factory=list)
@dataclass
class SnowflakeSchema:
name: str
created: datetime
last_altered: datetime
comment: str
tables: List[SnowflakeTable] = field(default_factory=list)
views: List[SnowflakeView] = field(default_factory=list)
@dataclass
class SnowflakeDatabase:
name: str
created: datetime
comment: str
schemas: List[SnowflakeSchema] = field(default_factory=list)
class SnowflakeDataDictionary(SnowflakeQueryMixin):
def __init__(self) -> None:
self.logger = logger
def get_databases(self, conn: SnowflakeConnection) -> List[SnowflakeDatabase]:
databases: List[SnowflakeDatabase] = []
cur = self.query(
conn,
SnowflakeQuery.show_databases(),
)
for database in cur:
snowflake_db = SnowflakeDatabase(
name=database["name"],
created=database["created_on"],
comment=database["comment"],
)
databases.append(snowflake_db)
return databases
def get_schemas_for_database(
self, conn: SnowflakeConnection, db_name: str
) -> List[SnowflakeSchema]:
snowflake_schemas = []
cur = self.query(
conn,
SnowflakeQuery.schemas_for_database(db_name),
)
for schema in cur:
snowflake_schema = SnowflakeSchema(
name=schema["schema_name"],
created=schema["created"],
last_altered=schema["last_altered"],
comment=schema["comment"],
)
snowflake_schemas.append(snowflake_schema)
return snowflake_schemas
def get_tables_for_database(
self, conn: SnowflakeConnection, db_name: str
) -> Optional[Dict[str, List[SnowflakeTable]]]:
tables: Dict[str, List[SnowflakeTable]] = {}
try:
cur = self.query(
conn,
SnowflakeQuery.tables_for_database(db_name),
)
except Exception as e:
logger.debug(e)
# Error - Information schema query returned too much data. Please repeat query with more selective predicates.
return None
for table in cur:
if table["table_schema"] not in tables:
tables[table["table_schema"]] = []
tables[table["table_schema"]].append(
SnowflakeTable(
name=table["table_name"],
created=table["created"],
last_altered=table["last_altered"],
size_in_bytes=table["bytes"],
rows_count=table["row_count"],
comment=table["comment"],
clustering_key=table["clustering_key"],
)
)
return tables
def get_tables_for_schema(
self, conn: SnowflakeConnection, schema_name: str, db_name: str
) -> List[SnowflakeTable]:
tables: List[SnowflakeTable] = []
cur = self.query(
conn,
SnowflakeQuery.tables_for_schema(schema_name, db_name),
)
for table in cur:
tables.append(
SnowflakeTable(
name=table["table_name"],
created=table["created"],
last_altered=table["last_altered"],
size_in_bytes=table["bytes"],
rows_count=table["row_count"],
comment=table["comment"],
clustering_key=table["clustering_key"],
)
)
return tables
def get_views_for_database(
self, conn: SnowflakeConnection, db_name: str
) -> Optional[Dict[str, List[SnowflakeView]]]:
views: Dict[str, List[SnowflakeView]] = {}
try:
cur = self.query(conn, SnowflakeQuery.show_views_for_database(db_name))
except Exception as e:
logger.debug(e)
# Error - Information schema query returned too much data. Please repeat query with more selective predicates.
return None
for table in cur:
if table["schema_name"] not in views:
views[table["schema_name"]] = []
views[table["schema_name"]].append(
SnowflakeView(
name=table["name"],
created=table["created_on"],
# last_altered=table["last_altered"],
comment=table["comment"],
view_definition=table["text"],
)
)
return views
def get_views_for_schema(
self, conn: SnowflakeConnection, schema_name: str, db_name: str
) -> List[SnowflakeView]:
views: List[SnowflakeView] = []
cur = self.query(
conn, SnowflakeQuery.show_views_for_schema(schema_name, db_name)
)
for table in cur:
views.append(
SnowflakeView(
name=table["name"],
created=table["created_on"],
# last_altered=table["last_altered"],
comment=table["comment"],
view_definition=table["text"],
)
)
return views
def get_columns_for_schema(
self, conn: SnowflakeConnection, schema_name: str, db_name: str
) -> Optional[Dict[str, List[SnowflakeColumn]]]:
columns: Dict[str, List[SnowflakeColumn]] = {}
try:
cur = self.query(
conn, SnowflakeQuery.columns_for_schema(schema_name, db_name)
)
except Exception as e:
logger.debug(e)
# Error - Information schema query returned too much data.
# Please repeat query with more selective predicates.
return None
for column in cur:
if column["table_name"] not in columns:
columns[column["table_name"]] = []
columns[column["table_name"]].append(
SnowflakeColumn(
name=column["column_name"],
ordinal_position=column["ordinal_position"],
is_nullable=column["is_nullable"] == "YES",
data_type=column["data_type"],
comment=column["comment"],
)
)
return columns
def get_columns_for_table(
self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str
) -> List[SnowflakeColumn]:
columns: List[SnowflakeColumn] = []
cur = self.query(
conn,
SnowflakeQuery.columns_for_table(table_name, schema_name, db_name),
)
for column in cur:
columns.append(
SnowflakeColumn(
name=column["column_name"],
ordinal_position=column["ordinal_position"],
is_nullable=column["is_nullable"] == "YES",
data_type=column["data_type"],
comment=column["comment"],
)
)
return columns
def get_pk_constraints_for_schema(
self, conn: SnowflakeConnection, schema_name: str, db_name: str
) -> Dict[str, SnowflakePK]:
constraints: Dict[str, SnowflakePK] = {}
cur = self.query(
conn,
SnowflakeQuery.show_primary_keys_for_schema(schema_name, db_name),
)
for row in cur:
if row["table_name"] not in constraints:
constraints[row["table_name"]] = SnowflakePK(
name=row["constraint_name"], column_names=[]
)
constraints[row["table_name"]].column_names.append(row["column_name"])
return constraints
def get_fk_constraints_for_schema(
self, conn: SnowflakeConnection, schema_name: str, db_name: str
) -> Dict[str, List[SnowflakeFK]]:
constraints: Dict[str, List[SnowflakeFK]] = {}
fk_constraints_map: Dict[str, SnowflakeFK] = {}
cur = self.query(
conn,
SnowflakeQuery.show_foreign_keys_for_schema(schema_name, db_name),
)
for row in cur:
if row["fk_name"] not in constraints:
fk_constraints_map[row["fk_name"]] = SnowflakeFK(
name=row["fk_name"],
column_names=[],
referred_database=row["pk_database_name"],
referred_schema=row["pk_schema_name"],
referred_table=row["pk_table_name"],
referred_column_names=[],
)
if row["fk_table_name"] not in constraints:
constraints[row["fk_table_name"]] = []
fk_constraints_map[row["fk_name"]].column_names.append(
row["fk_column_name"]
)
fk_constraints_map[row["fk_name"]].referred_column_names.append(
row["pk_column_name"]
)
constraints[row["fk_table_name"]].append(fk_constraints_map[row["fk_name"]])
return constraints
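A hedged walkthrough of the data dictionary above, assuming a live `SnowflakeConnection` (`conn`) obtained elsewhere, for example from the source config:

```python
# Sketch: enumerate databases, schemas, and tables via the data dictionary.
# `conn` is assumed to be an authenticated snowflake.connector.SnowflakeConnection.
data_dictionary = SnowflakeDataDictionary()

for db in data_dictionary.get_databases(conn):
    for schema in data_dictionary.get_schemas_for_database(conn, db.name):
        tables = data_dictionary.get_tables_for_schema(conn, schema.name, db.name)
        print(db.name, schema.name, [t.name for t in tables])
```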

View File

@@ -0,0 +1,410 @@
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Dict, Iterable, List, Optional
import pydantic
from snowflake.connector import SnowflakeConnection
from datahub.emitter.mce_builder import (
make_dataset_urn_with_platform_instance,
make_user_urn,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeCommonMixin,
SnowflakeQueryMixin,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
DatasetFieldUsageCounts,
DatasetUsageStatistics,
DatasetUserUsageCounts,
)
from datahub.metadata.com.linkedin.pegasus2avro.timeseries import TimeWindowSize
from datahub.metadata.schema_classes import (
ChangeTypeClass,
OperationClass,
OperationTypeClass,
)
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.sql_formatter import format_sql_query, trim_query
logger: logging.Logger = logging.getLogger(__name__)
OPERATION_STATEMENT_TYPES = {
"INSERT": OperationTypeClass.INSERT,
"UPDATE": OperationTypeClass.UPDATE,
"DELETE": OperationTypeClass.DELETE,
"CREATE": OperationTypeClass.CREATE,
"CREATE_TABLE": OperationTypeClass.CREATE,
"CREATE_TABLE_AS_SELECT": OperationTypeClass.CREATE,
}
@pydantic.dataclasses.dataclass
class SnowflakeColumnReference:
columnId: int
columnName: str
class PermissiveModel(pydantic.BaseModel):
class Config:
extra = "allow"
class SnowflakeObjectAccessEntry(PermissiveModel):
columns: Optional[List[SnowflakeColumnReference]]
objectDomain: str
objectId: int
objectName: str
stageKind: Optional[str]
class SnowflakeJoinedAccessEvent(PermissiveModel):
query_start_time: datetime
query_text: str
query_type: str
rows_inserted: Optional[int]
rows_updated: Optional[int]
rows_deleted: Optional[int]
base_objects_accessed: List[SnowflakeObjectAccessEntry]
direct_objects_accessed: List[SnowflakeObjectAccessEntry]
objects_modified: List[SnowflakeObjectAccessEntry]
user_name: str
first_name: Optional[str]
last_name: Optional[str]
display_name: Optional[str]
email: Optional[str]
role_name: str
class SnowflakeUsageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin):
def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = report
self.logger = logger
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
conn = self.config.get_connection()
logger.info("Checking usage date ranges")
self._check_usage_date_ranges(conn)
if (
self.report.min_access_history_time is None
or self.report.max_access_history_time is None
):
return
# NOTE: In the earlier `snowflake-usage` connector, users with no email were not included in usage counts or in operation metadata.
# Now we report usage as well as operation metadata even if the user email is absent.
if self.config.include_usage_stats:
yield from self.get_usage_workunits(conn)
if self.config.include_operational_stats:
# Generate the operation workunits.
access_events = self._get_snowflake_history(conn)
for event in access_events:
yield from self._get_operation_aspect_work_unit(event)
def get_usage_workunits(
self, conn: SnowflakeConnection
) -> Iterable[MetadataWorkUnit]:
with PerfTimer() as timer:
logger.info("Getting aggregated usage statistics")
results = self.query(
conn,
SnowflakeQuery.usage_per_object_per_time_bucket_for_time_window(
start_time_millis=int(self.config.start_time.timestamp() * 1000),
end_time_millis=int(self.config.end_time.timestamp() * 1000),
time_bucket_size=self.config.bucket_duration,
use_base_objects=self.config.apply_view_usage_to_tables,
top_n_queries=self.config.top_n_queries,
include_top_n_queries=self.config.include_top_n_queries,
),
)
self.report.usage_aggregation_query_secs = timer.elapsed_seconds()
for row in results:
assert row["object_name"] is not None, "Null objectName not allowed"
if not self._is_dataset_pattern_allowed(
row["object_name"],
row["object_domain"],
):
continue
stats = DatasetUsageStatistics(
timestampMillis=int(row["bucket_start_time"].timestamp() * 1000),
eventGranularity=TimeWindowSize(
unit=self.config.bucket_duration, multiple=1
),
totalSqlQueries=row["total_queries"],
uniqueUserCount=row["total_users"],
topSqlQueries=self._map_top_sql_queries(
json.loads(row["top_sql_queries"])
)
if self.config.include_top_n_queries
else None,
userCounts=self._map_user_counts(json.loads(row["user_counts"])),
fieldCounts=[
DatasetFieldUsageCounts(
fieldPath=self.snowflake_identifier(field_count["column_name"]),
count=field_count["total_queries"],
)
for field_count in json.loads(row["field_counts"])
],
)
dataset_urn = make_dataset_urn_with_platform_instance(
"snowflake",
self.get_dataset_identifier_from_qualified_name(row["object_name"]),
self.config.platform_instance,
self.config.env,
)
yield self.wrap_aspect_as_workunit(
"dataset",
dataset_urn,
"datasetUsageStatistics",
stats,
)
def _map_top_sql_queries(self, top_sql_queries: Dict) -> List[str]:
total_budget_for_query_list: int = 24000
budget_per_query: int = int(
total_budget_for_query_list / self.config.top_n_queries
)
return [
trim_query(format_sql_query(query), budget_per_query)
if self.config.format_sql_queries
else query
for query in top_sql_queries
]
def _map_user_counts(self, user_counts: Dict) -> List[DatasetUserUsageCounts]:
filtered_user_counts = []
for user_count in user_counts:
user_email = user_count.get(
"user_email",
"{0}@{1}".format(
user_count["user_name"], self.config.email_domain
).lower()
if self.config.email_domain
else None,
)
if user_email is None or not self.config.user_email_pattern.allowed(
user_email
):
continue
filtered_user_counts.append(
DatasetUserUsageCounts(
user=make_user_urn(
self.get_user_identifier(user_count["user_name"], user_email)
),
count=user_count["total_queries"],
# NOTE: Generated emails may be incorrect, as email may be different than
# username@email_domain
userEmail=user_email,
)
)
return filtered_user_counts
def _get_snowflake_history(
self, conn: SnowflakeConnection
) -> Iterable[SnowflakeJoinedAccessEvent]:
logger.info("Getting access history")
with PerfTimer() as timer:
query = self._make_operations_query()
results = self.query(conn, query)
self.report.access_history_query_secs = round(timer.elapsed_seconds(), 2)
for row in results:
yield from self._process_snowflake_history_row(row)
def _make_operations_query(self) -> str:
start_time = int(self.config.start_time.timestamp() * 1000)
end_time = int(self.config.end_time.timestamp() * 1000)
return SnowflakeQuery.operational_data_for_time_window(start_time, end_time)
def _check_usage_date_ranges(self, conn: SnowflakeConnection) -> Any:
query = """
select
min(query_start_time) as "min_time",
max(query_start_time) as "max_time"
from snowflake.account_usage.access_history
"""
with PerfTimer() as timer:
try:
results = self.query(conn, query)
except Exception as e:
self.warn(
"check-usage-data",
f"Extracting the date range for usage data from Snowflake failed."
f"Please check your permissions. Continuing...\nError was {e}.",
)
else:
for db_row in results:
if (
len(db_row) < 2
or db_row["min_time"] is None
or db_row["max_time"] is None
):
self.warn(
"check-usage-data",
f"Missing data for access_history {db_row} - Check if using Enterprise edition of Snowflake",
)
continue
self.report.min_access_history_time = db_row["min_time"].astimezone(
tz=timezone.utc
)
self.report.max_access_history_time = db_row["max_time"].astimezone(
tz=timezone.utc
)
self.report.access_history_range_query_secs = round(
timer.elapsed_seconds(), 2
)
def _get_operation_aspect_work_unit(
self, event: SnowflakeJoinedAccessEvent
) -> Iterable[MetadataWorkUnit]:
if event.query_start_time and event.query_type in OPERATION_STATEMENT_TYPES:
start_time = event.query_start_time
query_type = event.query_type
user_email = event.email
user_name = event.user_name
operation_type = OPERATION_STATEMENT_TYPES[query_type]
reported_time: int = int(time.time() * 1000)
last_updated_timestamp: int = int(start_time.timestamp() * 1000)
user_urn = make_user_urn(self.get_user_identifier(user_name, user_email))
# NOTE: In the earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect
for obj in event.objects_modified:
resource = obj.objectName
dataset_urn = make_dataset_urn_with_platform_instance(
"snowflake",
self.get_dataset_identifier_from_qualified_name(resource),
self.config.platform_instance,
self.config.env,
)
operation_aspect = OperationClass(
timestampMillis=reported_time,
lastUpdatedTimestamp=last_updated_timestamp,
actor=user_urn,
operationType=operation_type,
)
mcp = MetadataChangeProposalWrapper(
entityType="dataset",
aspectName="operation",
changeType=ChangeTypeClass.UPSERT,
entityUrn=dataset_urn,
aspect=operation_aspect,
)
wu = MetadataWorkUnit(
id=f"{start_time.isoformat()}-operation-aspect-{resource}",
mcp=mcp,
)
self.report.report_workunit(wu)
yield wu
def _process_snowflake_history_row(
self, row: Any
) -> Iterable[SnowflakeJoinedAccessEvent]:
self.report.rows_processed += 1
# Make some minor type conversions.
if hasattr(row, "_asdict"):
# Compat with SQLAlchemy 1.3 and 1.4
# See https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#rowproxy-is-no-longer-a-proxy-is-now-called-row-and-behaves-like-an-enhanced-named-tuple.
event_dict = row._asdict()
else:
event_dict = dict(row)
# no use processing events that don't have a query text
if not event_dict["query_text"]:
self.report.rows_missing_query_text += 1
return
event_dict["base_objects_accessed"] = [
obj
for obj in json.loads(event_dict["base_objects_accessed"])
if self._is_object_valid(obj)
]
if len(event_dict["base_objects_accessed"]) == 0:
self.report.rows_zero_base_objects_accessed += 1
event_dict["direct_objects_accessed"] = [
obj
for obj in json.loads(event_dict["direct_objects_accessed"])
if self._is_object_valid(obj)
]
if len(event_dict["direct_objects_accessed"]) == 0:
self.report.rows_zero_direct_objects_accessed += 1
event_dict["objects_modified"] = [
obj
for obj in json.loads(event_dict["objects_modified"])
if self._is_object_valid(obj)
]
if len(event_dict["objects_modified"]) == 0:
self.report.rows_zero_objects_modified += 1
event_dict["query_start_time"] = (event_dict["query_start_time"]).astimezone(
tz=timezone.utc
)
if (
not event_dict["email"]
and self.config.email_domain
and event_dict["user_name"]
):
            # NOTE: Generated emails may be incorrect, as the real email may differ from
            # username@email_domain
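            # Illustrative example (assumed values, not taken from real data): user_name
            # 'JOHN_DOE' with email_domain 'example.com' yields 'john_doe@example.com'.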
event_dict[
"email"
] = f'{event_dict["user_name"]}@{self.config.email_domain}'.lower()
if not event_dict["email"]:
self.report.rows_missing_email += 1
try: # big hammer try block to ensure we don't fail on parsing events
event = SnowflakeJoinedAccessEvent(**event_dict)
yield event
except Exception as e:
self.report.rows_parsing_error += 1
self.warn(
"usage",
f"Failed to parse usage line {event_dict}, {e}",
)
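    # Stage accesses and objects that report `locations` (e.g. external stage paths) are
    # treated as unsupported below, presumably because they cannot be resolved to
    # table/view datasets.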
def _is_unsupported_object_accessed(self, obj: Dict[str, Any]) -> bool:
unsupported_keys = ["locations"]
if obj.get("objectDomain") in ["Stage"]:
return True
return any([obj.get(key) is not None for key in unsupported_keys])
def _is_object_valid(self, obj: Dict[str, Any]) -> bool:
if self._is_unsupported_object_accessed(
obj
) or not self._is_dataset_pattern_allowed(
obj.get("objectName"), obj.get("objectDomain")
):
return False
return True
def warn(self, key: str, reason: str) -> None:
self.report.report_warning(key, reason)
self.logger.warning(f"{key} => {reason}")
def error(self, key: str, reason: str) -> None:
self.report.report_failure(key, reason)
self.logger.error(f"{key} => {reason}")

View File

@ -0,0 +1,149 @@
import logging
from typing import Any, Optional, Protocol
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import DictCursor
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.metadata.com.linkedin.pegasus2avro.events.metadata import ChangeType
from datahub.metadata.schema_classes import _Aspect
class SnowflakeLoggingProtocol(Protocol):
@property
def logger(self) -> logging.Logger:
...
class SnowflakeCommonProtocol(Protocol):
@property
def logger(self) -> logging.Logger:
...
@property
def config(self) -> SnowflakeV2Config:
...
@property
def report(self) -> SnowflakeV2Report:
...
def get_dataset_identifier(
self, table_name: str, schema_name: str, db_name: str
) -> str:
...
def snowflake_identifier(self, identifier: str) -> str:
...
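# NOTE: The mixins below annotate `self` with the protocols above so that a static type
# checker can verify that any class mixing them in provides `logger`, `config`, `report`,
# etc., without requiring inheritance from a common base class.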
class SnowflakeQueryMixin:
def query(
self: SnowflakeLoggingProtocol, conn: SnowflakeConnection, query: str
) -> Any:
self.logger.debug("Query : {}".format(query))
resp = conn.cursor(DictCursor).execute(query)
return resp
class SnowflakeCommonMixin:
def _is_dataset_pattern_allowed(
self: SnowflakeCommonProtocol,
dataset_name: Optional[str],
dataset_type: Optional[str],
) -> bool:
if not dataset_type or not dataset_name:
return True
dataset_params = dataset_name.split(".")
if len(dataset_params) != 3:
self.report.report_warning(
"invalid-dataset-pattern",
f"Found {dataset_params} of type {dataset_type}",
)
# NOTE: this case returned `True` earlier when extracting lineage
return False
if not self.config.database_pattern.allowed(
dataset_params[0].strip('"')
) or not self.config.schema_pattern.allowed(dataset_params[1].strip('"')):
return False
if dataset_type.lower() in {"table"} and not self.config.table_pattern.allowed(
dataset_params[2].strip('"')
):
return False
if dataset_type.lower() in {
"view",
"materialized_view",
} and not self.config.view_pattern.allowed(dataset_params[2].strip('"')):
return False
return True
def snowflake_identifier(self: SnowflakeCommonProtocol, identifier: str) -> str:
        # To stay in sync with the older connector, convert the name to lowercase
if self.config.convert_urns_to_lowercase:
return identifier.lower()
return identifier
def get_dataset_identifier(
self: SnowflakeCommonProtocol, table_name: str, schema_name: str, db_name: str
) -> str:
return self.snowflake_identifier(f"{db_name}.{schema_name}.{table_name}")
    # Qualified object names from snowflake audit logs have quotes around snowflake quoted identifiers,
    # e.g. "test-database"."test-schema".test_table, whereas we generate urns without quotes even for
    # quoted identifiers, for backward compatibility and because there is no utility function to
    # determine whether the current table/schema/database name should be quoted in
    # get_dataset_identifier above.
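    # Illustrative example (assuming convert_urns_to_lowercase is enabled):
    #   '"TEST-DB"."TEST-SCHEMA".TEST_TABLE' -> 'test-db.test-schema.test_table'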
def get_dataset_identifier_from_qualified_name(
self: SnowflakeCommonProtocol, qualified_name: str
) -> str:
name_parts = qualified_name.split(".")
if len(name_parts) != 3:
self.report.report_warning(
"invalid-dataset-pattern",
f"Found non-parseable {name_parts} for {qualified_name}",
)
return self.snowflake_identifier(qualified_name.replace('"', ""))
return self.get_dataset_identifier(
name_parts[2].strip('"'), name_parts[1].strip('"'), name_parts[0].strip('"')
)
    # Note - decide how to construct user urns.
    # Historically, urns were created using the part of the user's email before the @.
    # Users without an email were skipped from both user entries and aggregates.
    # However, email is not a mandatory field for a Snowflake user, whereas user_name is always present.
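    # Illustrative examples (assumed values, with convert_urns_to_lowercase enabled):
    #   get_user_identifier(user_name="JDOE", user_email="jdoe@acme.com") -> "jdoe"  (prefix of the email)
    #   get_user_identifier(user_name="JDOE", user_email=None)            -> "jdoe"  (lowercased user_name)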
def get_user_identifier(
self: SnowflakeCommonProtocol, user_name: str, user_email: Optional[str]
) -> str:
if user_email is not None:
return user_email.split("@")[0]
return self.snowflake_identifier(user_name)
def wrap_aspect_as_workunit(
self: SnowflakeCommonProtocol,
entityName: str,
entityUrn: str,
aspectName: str,
aspect: _Aspect,
) -> MetadataWorkUnit:
id = f"{aspectName}-for-{entityUrn}"
if "timestampMillis" in aspect._inner_dict:
id = f"{aspectName}-{aspect.timestampMillis}-for-{entityUrn}" # type: ignore
wu = MetadataWorkUnit(
id=id,
mcp=MetadataChangeProposalWrapper(
entityType=entityName,
entityUrn=entityUrn,
aspectName=aspectName,
aspect=aspect,
changeType=ChangeType.UPSERT,
),
)
self.report.report_workunit(wu)
return wu

File diff suppressed because it is too large

View File

@ -12,3 +12,18 @@ def format_sql_query(query: str, **options: Any) -> str:
except Exception as e:
logger.debug(f"Exception:{e} while formatting query '{query}'.")
return query
def trim_query(
query: str, budget_per_query: int, query_trimmer_string: str = " ..."
) -> str:
trimmed_query = query
if len(query) > budget_per_query:
if budget_per_query - len(query_trimmer_string) > 0:
end_index = budget_per_query - len(query_trimmer_string)
trimmed_query = query[:end_index] + query_trimmer_string
else:
raise Exception(
"Budget per query is too low. Please, decrease the number of top_n_queries."
)
return trimmed_query
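# Illustrative example (not part of the module): with budget_per_query=8 and the default
# trimmer string, trim_query("0123456789", 8) returns "0123 ..." (4 characters of the
# query plus the 4-character " ..." suffix).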

View File

@ -34,7 +34,7 @@ def test_bq_usage_config():
)
assert config.get_allow_pattern_string() == "test-regex|test-regex-1"
assert config.get_deny_pattern_string() == ""
assert (config.end_time - config.start_time) == timedelta(hours=2)
assert (config.end_time - config.start_time) == timedelta(hours=1)
assert config.projects == ["sample-bigquery-project-name-1234"]

View File

@ -106,7 +106,7 @@ protoPayload.serviceData.jobCompletedEvent.job.jobStatistics.referencedTables.ta
AND
timestamp >= "2021-07-18T23:45:00Z"
AND
timestamp < "2021-07-21T00:15:00Z\"""" # noqa: W293
timestamp < "2021-07-20T00:15:00Z\"""" # noqa: W293
source = BigQueryUsageSource.create(config, PipelineContext(run_id="bq-usage-test"))
@ -164,7 +164,7 @@ protoPayload.serviceData.jobCompletedEvent.job.jobStatistics.referencedTables.ta
AND
timestamp >= "2021-07-18T23:45:00Z"
AND
timestamp < "2021-07-21T00:15:00Z\"""" # noqa: W293
timestamp < "2021-07-20T00:15:00Z\"""" # noqa: W293
source = BigQueryUsageSource.create(config, PipelineContext(run_id="bq-usage-test"))
filter: str = source._generate_filter(BQ_AUDIT_V1)
assert filter == expected_filter