feat(ingest): allow extracting snowflake tags (#6500)

This commit is contained in:
Fredrik Sannholm 2023-01-04 23:05:23 +02:00 committed by GitHub
parent 6bc85502ba
commit e0aa812621
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 4933 additions and 1391 deletions

View File

@ -16,7 +16,7 @@ grant usage on DATABASE "<your-database>" to role datahub_role;
grant usage on all schemas in database "<your-database>" to role datahub_role;
grant usage on future schemas in database "<your-database>" to role datahub_role;
// If you are NOT using Snowflake Profiling or Classification feature: Grant references privileges to your tables and views
// If you are NOT using Snowflake Profiling or Classification feature: Grant references privileges to your tables and views
grant references on all tables in database "<your-database>" to role datahub_role;
grant references on future tables in database "<your-database>" to role datahub_role;
grant references on all external tables in database "<your-database>" to role datahub_role;
@ -30,10 +30,10 @@ grant select on future tables in database "<your-database>" to role datahub_role
grant select on all external tables in database "<your-database>" to role datahub_role;
grant select on future external tables in database "<your-database>" to role datahub_role;
// Create a new DataHub user and assign the DataHub role to it
// Create a new DataHub user and assign the DataHub role to it
create user datahub_user display_name = 'DataHub' password='' default_role = datahub_role default_warehouse = '<your-warehouse>';
// Grant the datahub_role to the new DataHub user.
// Grant the datahub_role to the new DataHub user.
grant role datahub_role to user datahub_user;
```
@ -50,7 +50,7 @@ grant usage on schema "<your-database>"."<your-schema>" to role datahub_role;
This represents the bare minimum privileges required to extract databases, schemas, views, tables from Snowflake.
If you plan to enable extraction of table lineage, via the `include_table_lineage` config flag or extraction of usage statistics, via the `include_usage_stats` config, you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables, using which the DataHub source extracts information. This can be done by granting access to the `snowflake` database.
If you plan to enable extraction of table lineage, via the `include_table_lineage` config flag, extraction of usage statistics, via the `include_usage_stats` config, or extraction of tags (without lineage), via the `extract_tags` config, you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables, using which the DataHub source extracts information. This can be done by granting access to the `snowflake` database.
```sql
grant imported privileges on database snowflake to role datahub_role;

View File

@ -36,6 +36,9 @@ class SnowflakeObjectDomain(str, Enum):
EXTERNAL_TABLE = "external table"
VIEW = "view"
MATERIALIZED_VIEW = "materialized view"
DATABASE = "database"
SCHEMA = "schema"
COLUMN = "column"
GENERIC_PERMISSION_ERROR_KEY = "permission-error"

View File

@ -1,4 +1,5 @@
import logging
from enum import Enum
from typing import Dict, Optional, cast
from pydantic import Field, SecretStr, root_validator, validator
@ -19,6 +20,12 @@ from datahub.ingestion.source_config.usage.snowflake_usage import SnowflakeUsage
logger = logging.Logger(__name__)
class TagOption(str, Enum):
with_lineage = "with_lineage"
without_lineage = "without_lineage"
skip = "skip"
class SnowflakeV2Config(
SnowflakeConfig,
SnowflakeUsageConfig,
@ -53,6 +60,14 @@ class SnowflakeV2Config(
default=None, description="Not supported"
)
extract_tags: TagOption = Field(
default=TagOption.skip,
description="""Optional. Allowed values are `without_lineage`, `with_lineage`, and `skip` (default).
`without_lineage` only extracts tags that have been applied directly to the given entity.
`with_lineage` extracts both directly applied and propagated tags, but will be significantly slower.
See the [Snowflake documentation](https://docs.snowflake.com/en/user-guide/object-tagging.html#tag-lineage) for information about tag lineage/propagation. """,
)
classification: Optional[ClassificationConfig] = Field(
default=None,
description="For details, refer [Classification](../../../../metadata-ingestion/docs/dev_guides/classification.md).",
@ -76,6 +91,11 @@ class SnowflakeV2Config(
)
return v
tag_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="List of regex patterns for tags to include in ingestion. Only used if `extract_tags` is enabled.",
)
@root_validator(pre=False)
def validate_unsupported_configs(cls, values: Dict) -> Dict:
value = values.get("provision_role")

View File

@ -105,6 +105,52 @@ class SnowflakeQuery:
and table_type in ('BASE TABLE', 'EXTERNAL TABLE')
order by table_schema, table_name"""
@staticmethod
def get_all_tags_on_object_with_propagation(
db_name: str, quoted_identifier: str, domain: str
) -> str:
# https://docs.snowflake.com/en/sql-reference/functions/tag_references.html
return f"""
SELECT tag_database as "TAG_DATABASE",
tag_schema AS "TAG_SCHEMA",
tag_name AS "TAG_NAME",
tag_value AS "TAG_VALUE"
FROM table("{db_name}".information_schema.tag_references('{quoted_identifier}', '{domain}'));
"""
@staticmethod
def get_all_tags_in_database_without_propagation(db_name: str) -> str:
# https://docs.snowflake.com/en/sql-reference/account-usage/tag_references.html
return f"""
SELECT tag_database as "TAG_DATABASE",
tag_schema AS "TAG_SCHEMA",
tag_name AS "TAG_NAME",
tag_value AS "TAG_VALUE",
object_database as "OBJECT_DATABASE",
object_schema AS "OBJECT_SCHEMA",
object_name AS "OBJECT_NAME",
column_name AS "COLUMN_NAME",
domain as "DOMAIN"
FROM snowflake.account_usage.tag_references
WHERE (object_database = '{db_name}' OR object_name = '{db_name}')
AND domain in ('DATABASE', 'SCHEMA', 'TABLE', 'COLUMN')
AND object_deleted IS NULL;
"""
@staticmethod
def get_tags_on_columns_with_propagation(
db_name: str, quoted_table_identifier: str
) -> str:
# https://docs.snowflake.com/en/sql-reference/functions/tag_references_all_columns.html
return f"""
SELECT tag_database as "TAG_DATABASE",
tag_schema AS "TAG_SCHEMA",
tag_name AS "TAG_NAME",
tag_value AS "TAG_VALUE",
column_name AS "COLUMN_NAME"
FROM table("{db_name}".information_schema.tag_references_all_columns('{quoted_table_identifier}', 'table'));
"""
# View definition is retrived in information_schema query only if role is owner of view. Hence this query is not used.
# https://community.snowflake.com/s/article/Is-it-possible-to-see-the-view-definition-in-information-schema-views-from-a-non-owner-role
@staticmethod

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import MutableSet, Optional
from datahub.ingestion.source.snowflake.constants import SnowflakeEdition
from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport
@ -12,6 +12,7 @@ class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlRepor
schemas_scanned: int = 0
databases_scanned: int = 0
tags_scanned: int = 0
include_usage_stats: bool = False
include_operational_stats: bool = False
@ -31,8 +32,16 @@ class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlRepor
num_get_views_for_schema_queries: int = 0
num_get_columns_for_table_queries: int = 0
# these will be non-zero if the user choses to enable the extract_tags = "with_lineage" option, which requires
# individual queries per object (database, schema, table) and an extra query per table to get the tags on the columns.
num_get_tags_for_object_queries: int = 0
num_get_tags_on_columns_for_table_queries: int = 0
rows_zero_objects_modified: int = 0
_processed_tags: MutableSet[str] = set()
_scanned_tags: MutableSet[str] = set()
edition: Optional[SnowflakeEdition] = None
def report_entity_scanned(self, name: str, ent_type: str = "table") -> None:
@ -47,5 +56,21 @@ class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlRepor
self.schemas_scanned += 1
elif ent_type == "database":
self.databases_scanned += 1
elif ent_type == "tag":
# the same tag can be assigned to multiple objects, so we need
# some extra logic account for each tag only once.
if self._is_tag_scanned(name):
return
self._scanned_tags.add(name)
self.tags_scanned += 1
else:
raise KeyError(f"Unknown entity {ent_type}.")
def is_tag_processed(self, tag_name: str) -> bool:
return tag_name in self._processed_tags
def _is_tag_scanned(self, tag_name: str) -> bool:
return tag_name in self._scanned_tags
def report_tag_processed(self, tag_name: str) -> None:
self._processed_tags.add(tag_name)

View File

@ -1,4 +1,5 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional
@ -6,6 +7,7 @@ from typing import Dict, List, Optional
import pandas as pd
from snowflake.connector import SnowflakeConnection
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin
from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView
@ -29,6 +31,20 @@ class SnowflakeFK:
referred_column_names: List[str]
@dataclass
class SnowflakeTag:
database: str
schema: str
name: str
value: str
def identifier(self) -> str:
return f"{self._id_prefix_as_str()}:{self.value}"
def _id_prefix_as_str(self) -> str:
return f"{self.database}.{self.schema}.{self.name}"
@dataclass(frozen=True, eq=True)
class SnowflakeColumn(BaseColumn):
character_maximum_length: Optional[int]
@ -61,12 +77,16 @@ class SnowflakeTable(BaseTable):
pk: Optional[SnowflakePK] = None
columns: List[SnowflakeColumn] = field(default_factory=list)
foreign_keys: List[SnowflakeFK] = field(default_factory=list)
tags: Optional[List[SnowflakeTag]] = None
column_tags: Dict[str, List[SnowflakeTag]] = field(default_factory=dict)
sample_data: Optional[pd.DataFrame] = None
@dataclass
class SnowflakeView(BaseView):
columns: List[SnowflakeColumn] = field(default_factory=list)
tags: Optional[List[SnowflakeTag]] = None
column_tags: Dict[str, List[SnowflakeTag]] = field(default_factory=dict)
@dataclass
@ -77,6 +97,7 @@ class SnowflakeSchema:
comment: Optional[str]
tables: List[SnowflakeTable] = field(default_factory=list)
views: List[SnowflakeView] = field(default_factory=list)
tags: Optional[List[SnowflakeTag]] = None
@dataclass
@ -86,6 +107,69 @@ class SnowflakeDatabase:
comment: Optional[str]
last_altered: Optional[datetime] = None
schemas: List[SnowflakeSchema] = field(default_factory=list)
tags: Optional[List[SnowflakeTag]] = None
class _SnowflakeTagCache:
def __init__(self) -> None:
# self._database_tags[<database_name>] = list of tags applied to database
self._database_tags: Dict[str, List[SnowflakeTag]] = defaultdict(list)
# self._schema_tags[<database_name>][<schema_name>] = list of tags applied to schema
self._schema_tags: Dict[str, Dict[str, List[SnowflakeTag]]] = defaultdict(
lambda: defaultdict(list)
)
# self._table_tags[<database_name>][<schema_name>][<table_name>] = list of tags applied to table
self._table_tags: Dict[
str, Dict[str, Dict[str, List[SnowflakeTag]]]
] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
# self._column_tags[<database_name>][<schema_name>][<table_name>][<column_name>] = list of tags applied to column
self._column_tags: Dict[
str, Dict[str, Dict[str, Dict[str, List[SnowflakeTag]]]]
] = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)
def add_database_tag(self, db_name: str, tag: SnowflakeTag) -> None:
self._database_tags[db_name].append(tag)
def get_database_tags(self, db_name: str) -> List[SnowflakeTag]:
return self._database_tags[db_name]
def add_schema_tag(self, schema_name: str, db_name: str, tag: SnowflakeTag) -> None:
self._schema_tags[db_name][schema_name].append(tag)
def get_schema_tags(self, schema_name: str, db_name: str) -> List[SnowflakeTag]:
return self._schema_tags.get(db_name, {}).get(schema_name, [])
def add_table_tag(
self, table_name: str, schema_name: str, db_name: str, tag: SnowflakeTag
) -> None:
self._table_tags[db_name][schema_name][table_name].append(tag)
def get_table_tags(
self, table_name: str, schema_name: str, db_name: str
) -> List[SnowflakeTag]:
return self._table_tags[db_name][schema_name][table_name]
def add_column_tag(
self,
column_name: str,
table_name: str,
schema_name: str,
db_name: str,
tag: SnowflakeTag,
) -> None:
self._column_tags[db_name][schema_name][table_name][column_name].append(tag)
def get_column_tags_for_table(
self, table_name: str, schema_name: str, db_name: str
) -> Dict[str, List[SnowflakeTag]]:
return (
self._column_tags.get(db_name, {}).get(schema_name, {}).get(table_name, {})
)
class SnowflakeDataDictionary(SnowflakeQueryMixin):
@ -358,3 +442,101 @@ class SnowflakeDataDictionary(SnowflakeQueryMixin):
constraints[row["fk_table_name"]].append(fk_constraints_map[row["fk_name"]])
return constraints
def get_tags_for_database_without_propagation(
self,
db_name: str,
) -> _SnowflakeTagCache:
cur = self.query(
SnowflakeQuery.get_all_tags_in_database_without_propagation(db_name)
)
tags = _SnowflakeTagCache()
for tag in cur:
snowflake_tag = SnowflakeTag(
database=tag["TAG_DATABASE"],
schema=tag["TAG_SCHEMA"],
name=tag["TAG_NAME"],
value=tag["TAG_VALUE"],
)
# This is the name of the object, unless the object is a column, in which
# case the name is in the `COLUMN_NAME` field.
object_name = tag["OBJECT_NAME"]
# This will be null if the object is a database or schema
object_schema = tag["OBJECT_SCHEMA"]
# This will be null if the object is a database
object_database = tag["OBJECT_DATABASE"]
domain = tag["DOMAIN"].lower()
if domain == SnowflakeObjectDomain.DATABASE:
tags.add_database_tag(object_name, snowflake_tag)
elif domain == SnowflakeObjectDomain.SCHEMA:
tags.add_schema_tag(object_name, object_database, snowflake_tag)
elif domain == SnowflakeObjectDomain.TABLE: # including views
tags.add_table_tag(
object_name, object_schema, object_database, snowflake_tag
)
elif domain == SnowflakeObjectDomain.COLUMN:
column_name = tag["COLUMN_NAME"]
tags.add_column_tag(
column_name,
object_name,
object_schema,
object_database,
snowflake_tag,
)
else:
# This should never happen.
self.logger.error(f"Encountered an unexpected domain: {domain}")
continue
return tags
def get_tags_for_object_with_propagation(
self,
domain: str,
quoted_identifier: str,
db_name: str,
) -> List[SnowflakeTag]:
tags: List[SnowflakeTag] = []
cur = self.query(
SnowflakeQuery.get_all_tags_on_object_with_propagation(
db_name, quoted_identifier, domain
),
)
for tag in cur:
tags.append(
SnowflakeTag(
database=tag["TAG_DATABASE"],
schema=tag["TAG_SCHEMA"],
name=tag["TAG_NAME"],
value=tag["TAG_VALUE"],
)
)
return tags
def get_tags_on_columns_for_table(
self, quoted_table_name: str, db_name: str
) -> Dict[str, List[SnowflakeTag]]:
tags: Dict[str, List[SnowflakeTag]] = defaultdict(list)
cur = self.query(
SnowflakeQuery.get_tags_on_columns_with_propagation(
db_name, quoted_table_name
),
)
for tag in cur:
column_name = tag["COLUMN_NAME"]
snowflake_tag = SnowflakeTag(
database=tag["TAG_DATABASE"],
schema=tag["TAG_SCHEMA"],
name=tag["TAG_NAME"],
value=tag["TAG_VALUE"],
)
tags[column_name].append(snowflake_tag)
return tags

View File

@ -0,0 +1,172 @@
import logging
from typing import Dict, List, Optional
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
from datahub.ingestion.source.snowflake.snowflake_config import (
SnowflakeV2Config,
TagOption,
)
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_schema import (
SnowflakeDataDictionary,
SnowflakeTag,
_SnowflakeTagCache,
)
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin
logger: logging.Logger = logging.getLogger(__name__)
class SnowflakeTagExtractor(SnowflakeCommonMixin):
def __init__(
self,
config: SnowflakeV2Config,
data_dictionary: SnowflakeDataDictionary,
report: SnowflakeV2Report,
) -> None:
self.config = config
self.data_dictionary = data_dictionary
self.report = report
self.logger = logger
self.tag_cache: Dict[str, _SnowflakeTagCache] = {}
def _get_tags_on_object_without_propagation(
self,
domain: str,
db_name: str,
schema_name: Optional[str],
table_name: Optional[str],
) -> List[SnowflakeTag]:
if db_name not in self.tag_cache:
self.tag_cache[
db_name
] = self.data_dictionary.get_tags_for_database_without_propagation(db_name)
if domain == SnowflakeObjectDomain.DATABASE:
return self.tag_cache[db_name].get_database_tags(db_name)
elif domain == SnowflakeObjectDomain.SCHEMA:
assert schema_name is not None
tags = self.tag_cache[db_name].get_schema_tags(schema_name, db_name)
elif (
domain == SnowflakeObjectDomain.TABLE
): # Views belong to this domain as well.
assert schema_name is not None
assert table_name is not None
tags = self.tag_cache[db_name].get_table_tags(
table_name, schema_name, db_name
)
else:
raise ValueError(f"Unknown domain {domain}")
return tags
def _get_tags_on_object_with_propagation(
self,
domain: str,
db_name: str,
schema_name: Optional[str],
table_name: Optional[str],
) -> List[SnowflakeTag]:
identifier = ""
if domain == SnowflakeObjectDomain.DATABASE:
identifier = self.get_quoted_identifier_for_database(db_name)
elif domain == SnowflakeObjectDomain.SCHEMA:
assert schema_name is not None
identifier = self.get_quoted_identifier_for_schema(db_name, schema_name)
elif (
domain == SnowflakeObjectDomain.TABLE
): # Views belong to this domain as well.
assert schema_name is not None
assert table_name is not None
identifier = self.get_quoted_identifier_for_table(
db_name, schema_name, table_name
)
else:
raise ValueError(f"Unknown domain {domain}")
assert identifier
self.report.num_get_tags_for_object_queries += 1
tags = self.data_dictionary.get_tags_for_object_with_propagation(
domain=domain, quoted_identifier=identifier, db_name=db_name
)
return tags
def get_tags_on_object(
self,
domain: str,
db_name: str,
schema_name: Optional[str] = None,
table_name: Optional[str] = None,
) -> List[SnowflakeTag]:
if self.config.extract_tags == TagOption.without_lineage:
tags = self._get_tags_on_object_without_propagation(
domain=domain,
db_name=db_name,
schema_name=schema_name,
table_name=table_name,
)
elif self.config.extract_tags == TagOption.with_lineage:
tags = self._get_tags_on_object_with_propagation(
domain=domain,
db_name=db_name,
schema_name=schema_name,
table_name=table_name,
)
else:
tags = []
allowed_tags = self._filter_tags(tags)
return allowed_tags if allowed_tags else []
def get_column_tags_for_table(
self,
table_name: str,
schema_name: str,
db_name: str,
) -> Dict[str, List[SnowflakeTag]]:
temp_column_tags: Dict[str, List[SnowflakeTag]] = {}
if self.config.extract_tags == TagOption.without_lineage:
if db_name not in self.tag_cache:
self.tag_cache[
db_name
] = self.data_dictionary.get_tags_for_database_without_propagation(
db_name
)
temp_column_tags = self.tag_cache[db_name].get_column_tags_for_table(
table_name, schema_name, db_name
)
elif self.config.extract_tags == TagOption.with_lineage:
self.report.num_get_tags_on_columns_for_table_queries += 1
temp_column_tags = self.data_dictionary.get_tags_on_columns_for_table(
quoted_table_name=self.get_quoted_identifier_for_table(
db_name, schema_name, table_name
),
db_name=db_name,
)
column_tags: Dict[str, List[SnowflakeTag]] = {}
for column_name in temp_column_tags:
tags = temp_column_tags[column_name]
allowed_tags = self._filter_tags(tags)
if allowed_tags:
column_tags[column_name] = allowed_tags
return column_tags
def _filter_tags(
self, tags: Optional[List[SnowflakeTag]]
) -> Optional[List[SnowflakeTag]]:
if tags is None:
return tags
allowed_tags = []
for tag in tags:
tag_identifier = tag.identifier()
self.report.report_entity_scanned(tag_identifier, "tag")
if not self.config.tag_pattern.allowed(tag_identifier):
self.report.report_dropped(tag_identifier)
allowed_tags.append(tag)
return allowed_tags

View File

@ -158,6 +158,18 @@ class SnowflakeCommonMixin:
return identifier.lower()
return identifier
@staticmethod
def get_quoted_identifier_for_database(db_name):
return f'"{db_name}"'
@staticmethod
def get_quoted_identifier_for_schema(db_name, schema_name):
return f'"{db_name}"."{schema_name}"'
@staticmethod
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
return f'"{db_name}"."{schema_name}"."{table_name}"'
def get_dataset_identifier(
self: SnowflakeCommonProtocol, table_name: str, schema_name: str, db_name: str
) -> str:

View File

@ -15,6 +15,7 @@ from datahub.emitter.mce_builder import (
make_dataset_urn_with_platform_instance,
make_domain_urn,
make_schema_field_urn,
make_tag_urn,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import (
@ -49,7 +50,10 @@ from datahub.ingestion.source.snowflake.constants import (
SnowflakeEdition,
SnowflakeObjectDomain,
)
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_config import (
SnowflakeV2Config,
TagOption,
)
from datahub.ingestion.source.snowflake.snowflake_lineage import (
SnowflakeLineageExtractor,
)
@ -64,8 +68,10 @@ from datahub.ingestion.source.snowflake.snowflake_schema import (
SnowflakeQuery,
SnowflakeSchema,
SnowflakeTable,
SnowflakeTag,
SnowflakeView,
)
from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor
from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
SnowflakeUsageExtractor,
)
@ -90,8 +96,10 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)
from datahub.metadata.com.linkedin.pegasus2avro.common import (
GlobalTags,
Status,
SubTypes,
TagAssociation,
TimeStamp,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
@ -114,6 +122,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
StringType,
TimeType,
)
from datahub.metadata.com.linkedin.pegasus2avro.tag import TagProperties
from datahub.metadata.schema_classes import ChangeTypeClass, DataPlatformInstanceClass
from datahub.utilities.registries.domain_registry import DomainRegistry
from datahub.utilities.time import datetime_to_ts_millis
@ -188,6 +197,11 @@ SNOWFLAKE_FIELD_TYPE_MAPPINGS = {
"Optionally enabled via `stateful_ingestion.remove_stale_metadata`",
supported=True,
)
@capability(
SourceCapability.TAGS,
"Optionally enabled via `extract_tags`",
supported=True,
)
class SnowflakeV2Source(
ClassificationMixin,
SnowflakeQueryMixin,
@ -235,6 +249,10 @@ class SnowflakeV2Source(
# For usage stats
self.usage_extractor = SnowflakeUsageExtractor(config, self.report)
self.tag_extractor = SnowflakeTagExtractor(
config, self.data_dictionary, self.report
)
self.profiling_state_handler: Optional[ProfilingHandler] = None
if self.config.store_last_profiling_timestamps:
self.profiling_state_handler = ProfilingHandler(
@ -358,6 +376,7 @@ class SnowflakeV2Source(
_report[SourceCapability.CONTAINERS] = CapabilityReport(
capable=True
)
_report[SourceCapability.TAGS] = CapabilityReport(capable=True)
elif privilege.object_type in (
"TABLE",
"VIEW",
@ -391,6 +410,8 @@ class SnowflakeV2Source(
_report[SourceCapability.USAGE_STATS] = CapabilityReport(
capable=True
)
_report[SourceCapability.TAGS] = CapabilityReport(capable=True)
# If all capabilities supported, no need to continue
if set(capabilities) == set(_report.keys()):
break
@ -414,6 +435,7 @@ class SnowflakeV2Source(
SourceCapability.LINEAGE_COARSE: "Current role does not have permissions to snowflake account usage views",
SourceCapability.LINEAGE_FINE: "Current role does not have permissions to snowflake account usage views",
SourceCapability.USAGE_STATS: "Current role does not have permissions to snowflake account usage views",
SourceCapability.TAGS: "Either no tags have been applied to objects, or the current role does not have permission to access the objects or to snowflake account usage views ",
}
for c in capabilities: # type:ignore
@ -425,6 +447,7 @@ class SnowflakeV2Source(
SourceCapability.LINEAGE_COARSE,
SourceCapability.LINEAGE_FINE,
SourceCapability.USAGE_STATS,
SourceCapability.TAGS,
):
failure_message = (
f"Current role {current_role} does not have permissions to use warehouse {connection_conf.warehouse}. Please check the grants associated with this role."
@ -471,6 +494,7 @@ class SnowflakeV2Source(
for snowflake_db in databases:
try:
yield from self._process_database(snowflake_db)
except SnowflakePermissionError as e:
# FIXME - This may break satetful ingestion if new tables than previous run are emitted above
# and stateful ingestion is enabled
@ -627,11 +651,20 @@ class SnowflakeV2Source(
)
return
if self.config.extract_tags != TagOption.skip:
snowflake_db.tags = self.tag_extractor.get_tags_on_object(
domain="database", db_name=db_name
)
if self.config.include_technical_schema:
yield from self.gen_database_containers(snowflake_db)
self.fetch_schemas_for_database(snowflake_db, db_name)
if self.config.include_technical_schema and snowflake_db.tags:
for tag in snowflake_db.tags:
yield from self._process_tag(tag)
for snowflake_schema in snowflake_db.schemas:
yield from self._process_schema(snowflake_schema, db_name)
@ -675,6 +708,12 @@ class SnowflakeV2Source(
return
schema_name = snowflake_schema.name
if self.config.extract_tags != TagOption.skip:
snowflake_schema.tags = self.tag_extractor.get_tags_on_object(
schema_name=schema_name, db_name=db_name, domain="schema"
)
if self.config.include_technical_schema:
yield from self.gen_schema_containers(snowflake_schema, db_name)
@ -692,6 +731,10 @@ class SnowflakeV2Source(
for view in snowflake_schema.views:
yield from self._process_view(view, schema_name, db_name)
if self.config.include_technical_schema and snowflake_schema.tags:
for tag in snowflake_schema.tags:
yield from self._process_tag(tag)
if not snowflake_schema.views and not snowflake_schema.tables:
self.report_warning(
"No tables/views found in schema. If tables exist, please grant REFERENCES or SELECT permissions on them.",
@ -762,6 +805,22 @@ class SnowflakeV2Source(
table, schema_name, db_name, dataset_name
)
if self.config.extract_tags != TagOption.skip:
table.tags = self.tag_extractor.get_tags_on_object(
table_name=table.name,
schema_name=schema_name,
db_name=db_name,
domain="table",
)
if self.config.include_technical_schema:
if table.tags:
for tag in table.tags:
yield from self._process_tag(tag)
for column_name in table.column_tags:
for tag in table.column_tags[column_name]:
yield from self._process_tag(tag)
yield from self.gen_dataset_workunits(table, schema_name, db_name)
def fetch_sample_data_for_classification(
@ -817,6 +876,10 @@ class SnowflakeV2Source(
def fetch_columns_for_table(self, table, schema_name, db_name, table_identifier):
try:
table.columns = self.get_columns_for_table(table.name, schema_name, db_name)
if self.config.extract_tags != TagOption.skip:
table.column_tags = self.tag_extractor.get_column_tags_for_table(
table.name, schema_name, db_name
)
except Exception as e:
logger.debug(
f"Failed to get columns for table {table_identifier} due to error {e}",
@ -840,6 +903,10 @@ class SnowflakeV2Source(
try:
view.columns = self.get_columns_for_table(view.name, schema_name, db_name)
if self.config.extract_tags != TagOption.skip:
view.column_tags = self.tag_extractor.get_column_tags_for_table(
view.name, schema_name, db_name
)
except Exception as e:
logger.debug(
f"Failed to get columns for view {view_name} due to error {e}",
@ -847,8 +914,34 @@ class SnowflakeV2Source(
)
self.report_warning("Failed to get columns for view", view_name)
if self.config.extract_tags != TagOption.skip:
view.tags = self.tag_extractor.get_tags_on_object(
table_name=view.name,
schema_name=schema_name,
db_name=db_name,
domain="table",
)
if self.config.include_technical_schema:
if view.tags:
for tag in view.tags:
yield from self._process_tag(tag)
for column_name in view.column_tags:
for tag in view.column_tags[column_name]:
yield from self._process_tag(tag)
yield from self.gen_dataset_workunits(view, schema_name, db_name)
def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
tag_identifier = tag.identifier()
if self.report.is_tag_processed(tag_identifier):
return
self.report.report_tag_processed(tag_identifier)
yield from self.gen_tag_workunits(tag)
def gen_dataset_workunits(
self,
table: Union[SnowflakeTable, SnowflakeView],
@ -908,6 +1001,15 @@ class SnowflakeV2Source(
entity_type="dataset",
)
if table.tags:
tag_associations = [
TagAssociation(tag=make_tag_urn(tag.identifier())) for tag in table.tags
]
global_tags = GlobalTags(tag_associations)
yield self.wrap_aspect_as_workunit(
"dataset", dataset_urn, "globalTags", global_tags
)
if (
isinstance(table, SnowflakeView)
and cast(SnowflakeView, table).view_definition is not None
@ -951,6 +1053,21 @@ class SnowflakeV2Source(
else None,
)
def gen_tag_workunits(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
tag_key = tag.identifier()
tag_urn = make_tag_urn(self.snowflake_identifier(tag_key))
tag_properties_aspect = TagProperties(
name=tag_key,
description=f"Represents the Snowflake tag `{tag._id_prefix_as_str()}` with value `{tag.value}`.",
)
self.stale_entity_removal_handler.add_entity_to_state("tag", tag_urn)
yield self.wrap_aspect_as_workunit(
"tag", tag_urn, "tagProperties", tag_properties_aspect
)
def get_schema_metadata(
self,
table: Union[SnowflakeTable, SnowflakeView],
@ -980,6 +1097,18 @@ class SnowflakeV2Source(
isPartOfKey=col.name in table.pk.column_names
if isinstance(table, SnowflakeTable) and table.pk is not None
else None,
globalTags=GlobalTags(
[
TagAssociation(
make_tag_urn(
self.snowflake_identifier(tag.identifier())
)
)
for tag in table.column_tags[col.name]
]
)
if col.name in table.column_tags
else None,
)
for col in table.columns
],
@ -1168,6 +1297,9 @@ class SnowflakeV2Source(
else int(database.created.timestamp() * 1000)
if database.created is not None
else None,
tags=[self.snowflake_identifier(tag.identifier()) for tag in database.tags]
if database.tags
else None,
)
self.stale_entity_removal_handler.add_entity_to_state(
@ -1215,6 +1347,9 @@ class SnowflakeV2Source(
else int(schema.created.timestamp() * 1000)
if schema.created is not None
else None,
tags=[self.snowflake_identifier(tag.identifier()) for tag in schema.tags]
if schema.tags
else None,
)
for wu in container_workunits:

View File

@ -352,5 +352,60 @@ def default_query_results(query): # noqa: C901
]:
return []
elif (
query
== snowflake_query.SnowflakeQuery.get_all_tags_in_database_without_propagation(
"TEST_DB"
)
):
return [
*[
{
"TAG_DATABASE": "TEST_DB",
"TAG_SCHEMA": "TEST_SCHEMA",
"TAG_NAME": f"my_tag_{ix}",
"TAG_VALUE": f"my_value_{ix}",
"OBJECT_DATABASE": "TEST_DB",
"OBJECT_SCHEMA": "TEST_SCHEMA",
"OBJECT_NAME": "VIEW_2",
"COLUMN_NAME": None,
"DOMAIN": "TABLE",
}
for ix in range(3)
],
{
"TAG_DATABASE": "TEST_DB",
"TAG_SCHEMA": "TEST_SCHEMA",
"TAG_NAME": "security",
"TAG_VALUE": "pii",
"OBJECT_DATABASE": "TEST_DB",
"OBJECT_SCHEMA": "TEST_SCHEMA",
"OBJECT_NAME": "VIEW_1",
"COLUMN_NAME": "COL_1",
"DOMAIN": "COLUMN",
},
{
"TAG_DATABASE": "OTHER_DB",
"TAG_SCHEMA": "OTHER_SCHEMA",
"TAG_NAME": "my_other_tag",
"TAG_VALUE": "other",
"OBJECT_DATABASE": "TEST_DB",
"OBJECT_SCHEMA": None,
"OBJECT_NAME": "TEST_SCHEMA",
"COLUMN_NAME": None,
"DOMAIN": "SCHEMA",
},
{
"TAG_DATABASE": "OTHER_DB",
"TAG_SCHEMA": "OTHER_SCHEMA",
"TAG_NAME": "my_other_tag",
"TAG_VALUE": "other",
"OBJECT_DATABASE": None,
"OBJECT_SCHEMA": None,
"OBJECT_NAME": "TEST_DB",
"COLUMN_NAME": None,
"DOMAIN": "DATABASE",
},
]
# Unreachable code
raise Exception(f"Unknown query {query}")

View File

@ -19,7 +19,10 @@ from datahub.ingestion.glossary.datahub_classifier import (
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_config import (
SnowflakeV2Config,
TagOption,
)
from tests.integration.snowflake.common import FROZEN_TIME, default_query_results
from tests.test_helpers import mce_helpers
@ -109,6 +112,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
profile_table_size_limit=None,
profile_table_level_only=True,
),
extract_tags=TagOption.without_lineage,
),
),
sink=DynamicTypedConfig(