mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-25 17:08:29 +00:00
feat(ingest): allow extracting snowflake tags (#6500)
This commit is contained in:
parent
6bc85502ba
commit
e0aa812621
@ -16,7 +16,7 @@ grant usage on DATABASE "<your-database>" to role datahub_role;
|
||||
grant usage on all schemas in database "<your-database>" to role datahub_role;
|
||||
grant usage on future schemas in database "<your-database>" to role datahub_role;
|
||||
|
||||
// If you are NOT using Snowflake Profiling or Classification feature: Grant references privileges to your tables and views
|
||||
// If you are NOT using Snowflake Profiling or Classification feature: Grant references privileges to your tables and views
|
||||
grant references on all tables in database "<your-database>" to role datahub_role;
|
||||
grant references on future tables in database "<your-database>" to role datahub_role;
|
||||
grant references on all external tables in database "<your-database>" to role datahub_role;
|
||||
@ -30,10 +30,10 @@ grant select on future tables in database "<your-database>" to role datahub_role
|
||||
grant select on all external tables in database "<your-database>" to role datahub_role;
|
||||
grant select on future external tables in database "<your-database>" to role datahub_role;
|
||||
|
||||
// Create a new DataHub user and assign the DataHub role to it
|
||||
// Create a new DataHub user and assign the DataHub role to it
|
||||
create user datahub_user display_name = 'DataHub' password='' default_role = datahub_role default_warehouse = '<your-warehouse>';
|
||||
|
||||
// Grant the datahub_role to the new DataHub user.
|
||||
// Grant the datahub_role to the new DataHub user.
|
||||
grant role datahub_role to user datahub_user;
|
||||
```
|
||||
|
||||
@ -50,7 +50,7 @@ grant usage on schema "<your-database>"."<your-schema>" to role datahub_role;
|
||||
|
||||
This represents the bare minimum privileges required to extract databases, schemas, views, tables from Snowflake.
|
||||
|
||||
If you plan to enable extraction of table lineage, via the `include_table_lineage` config flag or extraction of usage statistics, via the `include_usage_stats` config, you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables, using which the DataHub source extracts information. This can be done by granting access to the `snowflake` database.
|
||||
If you plan to enable extraction of table lineage, via the `include_table_lineage` config flag, extraction of usage statistics, via the `include_usage_stats` config, or extraction of tags (without lineage), via the `extract_tags` config, you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables, using which the DataHub source extracts information. This can be done by granting access to the `snowflake` database.
|
||||
|
||||
```sql
|
||||
grant imported privileges on database snowflake to role datahub_role;
|
||||
|
||||
@ -36,6 +36,9 @@ class SnowflakeObjectDomain(str, Enum):
|
||||
EXTERNAL_TABLE = "external table"
|
||||
VIEW = "view"
|
||||
MATERIALIZED_VIEW = "materialized view"
|
||||
DATABASE = "database"
|
||||
SCHEMA = "schema"
|
||||
COLUMN = "column"
|
||||
|
||||
|
||||
GENERIC_PERMISSION_ERROR_KEY = "permission-error"
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
from pydantic import Field, SecretStr, root_validator, validator
|
||||
@ -19,6 +20,12 @@ from datahub.ingestion.source_config.usage.snowflake_usage import SnowflakeUsage
|
||||
logger = logging.Logger(__name__)
|
||||
|
||||
|
||||
class TagOption(str, Enum):
    """Selects how (and whether) Snowflake tags are extracted during ingestion."""

    # Extract directly applied tags as well as tags propagated through tag lineage.
    with_lineage = "with_lineage"
    # Extract only the tags applied directly to the given entity.
    without_lineage = "without_lineage"
    # Do not extract tags.
    skip = "skip"
|
||||
|
||||
|
||||
class SnowflakeV2Config(
|
||||
SnowflakeConfig,
|
||||
SnowflakeUsageConfig,
|
||||
@ -53,6 +60,14 @@ class SnowflakeV2Config(
|
||||
default=None, description="Not supported"
|
||||
)
|
||||
|
||||
extract_tags: TagOption = Field(
|
||||
default=TagOption.skip,
|
||||
description="""Optional. Allowed values are `without_lineage`, `with_lineage`, and `skip` (default).
|
||||
`without_lineage` only extracts tags that have been applied directly to the given entity.
|
||||
`with_lineage` extracts both directly applied and propagated tags, but will be significantly slower.
|
||||
See the [Snowflake documentation](https://docs.snowflake.com/en/user-guide/object-tagging.html#tag-lineage) for information about tag lineage/propagation. """,
|
||||
)
|
||||
|
||||
classification: Optional[ClassificationConfig] = Field(
|
||||
default=None,
|
||||
description="For details, refer [Classification](../../../../metadata-ingestion/docs/dev_guides/classification.md).",
|
||||
@ -76,6 +91,11 @@ class SnowflakeV2Config(
|
||||
)
|
||||
return v
|
||||
|
||||
tag_pattern: AllowDenyPattern = Field(
|
||||
default=AllowDenyPattern.allow_all(),
|
||||
description="List of regex patterns for tags to include in ingestion. Only used if `extract_tags` is enabled.",
|
||||
)
|
||||
|
||||
@root_validator(pre=False)
|
||||
def validate_unsupported_configs(cls, values: Dict) -> Dict:
|
||||
value = values.get("provision_role")
|
||||
|
||||
@ -105,6 +105,52 @@ class SnowflakeQuery:
|
||||
and table_type in ('BASE TABLE', 'EXTERNAL TABLE')
|
||||
order by table_schema, table_name"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_tags_on_object_with_propagation(
|
||||
db_name: str, quoted_identifier: str, domain: str
|
||||
) -> str:
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/tag_references.html
|
||||
return f"""
|
||||
SELECT tag_database as "TAG_DATABASE",
|
||||
tag_schema AS "TAG_SCHEMA",
|
||||
tag_name AS "TAG_NAME",
|
||||
tag_value AS "TAG_VALUE"
|
||||
FROM table("{db_name}".information_schema.tag_references('{quoted_identifier}', '{domain}'));
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_tags_in_database_without_propagation(db_name: str) -> str:
|
||||
# https://docs.snowflake.com/en/sql-reference/account-usage/tag_references.html
|
||||
return f"""
|
||||
SELECT tag_database as "TAG_DATABASE",
|
||||
tag_schema AS "TAG_SCHEMA",
|
||||
tag_name AS "TAG_NAME",
|
||||
tag_value AS "TAG_VALUE",
|
||||
object_database as "OBJECT_DATABASE",
|
||||
object_schema AS "OBJECT_SCHEMA",
|
||||
object_name AS "OBJECT_NAME",
|
||||
column_name AS "COLUMN_NAME",
|
||||
domain as "DOMAIN"
|
||||
FROM snowflake.account_usage.tag_references
|
||||
WHERE (object_database = '{db_name}' OR object_name = '{db_name}')
|
||||
AND domain in ('DATABASE', 'SCHEMA', 'TABLE', 'COLUMN')
|
||||
AND object_deleted IS NULL;
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_tags_on_columns_with_propagation(
|
||||
db_name: str, quoted_table_identifier: str
|
||||
) -> str:
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/tag_references_all_columns.html
|
||||
return f"""
|
||||
SELECT tag_database as "TAG_DATABASE",
|
||||
tag_schema AS "TAG_SCHEMA",
|
||||
tag_name AS "TAG_NAME",
|
||||
tag_value AS "TAG_VALUE",
|
||||
column_name AS "COLUMN_NAME"
|
||||
FROM table("{db_name}".information_schema.tag_references_all_columns('{quoted_table_identifier}', 'table'));
|
||||
"""
|
||||
|
||||
# View definition is retrieved in information_schema query only if role is owner of view. Hence this query is not used.
|
||||
# https://community.snowflake.com/s/article/Is-it-possible-to-see-the-view-definition-in-information-schema-views-from-a-non-owner-role
|
||||
@staticmethod
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import MutableSet, Optional
|
||||
|
||||
from datahub.ingestion.source.snowflake.constants import SnowflakeEdition
|
||||
from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport
|
||||
@ -12,6 +12,7 @@ class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlRepor
|
||||
|
||||
schemas_scanned: int = 0
|
||||
databases_scanned: int = 0
|
||||
tags_scanned: int = 0
|
||||
|
||||
include_usage_stats: bool = False
|
||||
include_operational_stats: bool = False
|
||||
@ -31,8 +32,16 @@ class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlRepor
|
||||
num_get_views_for_schema_queries: int = 0
|
||||
num_get_columns_for_table_queries: int = 0
|
||||
|
||||
# these will be non-zero if the user chooses to enable the extract_tags = "with_lineage" option, which requires
|
||||
# individual queries per object (database, schema, table) and an extra query per table to get the tags on the columns.
|
||||
num_get_tags_for_object_queries: int = 0
|
||||
num_get_tags_on_columns_for_table_queries: int = 0
|
||||
|
||||
rows_zero_objects_modified: int = 0
|
||||
|
||||
_processed_tags: MutableSet[str] = set()
|
||||
_scanned_tags: MutableSet[str] = set()
|
||||
|
||||
edition: Optional[SnowflakeEdition] = None
|
||||
|
||||
def report_entity_scanned(self, name: str, ent_type: str = "table") -> None:
|
||||
@ -47,5 +56,21 @@ class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlRepor
|
||||
self.schemas_scanned += 1
|
||||
elif ent_type == "database":
|
||||
self.databases_scanned += 1
|
||||
elif ent_type == "tag":
|
||||
# the same tag can be assigned to multiple objects, so we need
|
||||
# some extra logic to account for each tag only once.
|
||||
if self._is_tag_scanned(name):
|
||||
return
|
||||
self._scanned_tags.add(name)
|
||||
self.tags_scanned += 1
|
||||
else:
|
||||
raise KeyError(f"Unknown entity {ent_type}.")
|
||||
|
||||
def is_tag_processed(self, tag_name: str) -> bool:
    """Return True when ``tag_name`` is present in the set of processed tags."""
    processed = self._processed_tags
    return tag_name in processed
|
||||
|
||||
def _is_tag_scanned(self, tag_name: str) -> bool:
|
||||
return tag_name in self._scanned_tags
|
||||
|
||||
def report_tag_processed(self, tag_name: str) -> None:
    """Record ``tag_name`` in the set of processed tags."""
    processed = self._processed_tags
    processed.add(tag_name)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
@ -6,6 +7,7 @@ from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
|
||||
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
|
||||
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
|
||||
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin
|
||||
from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView
|
||||
@ -29,6 +31,20 @@ class SnowflakeFK:
|
||||
referred_column_names: List[str]
|
||||
|
||||
|
||||
@dataclass
class SnowflakeTag:
    """A tag (located by database, schema and name) together with its assigned value."""

    database: str
    schema: str
    name: str
    value: str

    def identifier(self) -> str:
        """Return ``<database>.<schema>.<name>:<value>``."""
        prefix = self._id_prefix_as_str()
        return "{}:{}".format(prefix, self.value)

    def _id_prefix_as_str(self) -> str:
        """Return the dotted ``<database>.<schema>.<name>`` part of the identifier."""
        return "{}.{}.{}".format(self.database, self.schema, self.name)
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=True)
|
||||
class SnowflakeColumn(BaseColumn):
|
||||
character_maximum_length: Optional[int]
|
||||
@ -61,12 +77,16 @@ class SnowflakeTable(BaseTable):
|
||||
pk: Optional[SnowflakePK] = None
|
||||
columns: List[SnowflakeColumn] = field(default_factory=list)
|
||||
foreign_keys: List[SnowflakeFK] = field(default_factory=list)
|
||||
tags: Optional[List[SnowflakeTag]] = None
|
||||
column_tags: Dict[str, List[SnowflakeTag]] = field(default_factory=dict)
|
||||
sample_data: Optional[pd.DataFrame] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnowflakeView(BaseView):
|
||||
columns: List[SnowflakeColumn] = field(default_factory=list)
|
||||
tags: Optional[List[SnowflakeTag]] = None
|
||||
column_tags: Dict[str, List[SnowflakeTag]] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -77,6 +97,7 @@ class SnowflakeSchema:
|
||||
comment: Optional[str]
|
||||
tables: List[SnowflakeTable] = field(default_factory=list)
|
||||
views: List[SnowflakeView] = field(default_factory=list)
|
||||
tags: Optional[List[SnowflakeTag]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -86,6 +107,69 @@ class SnowflakeDatabase:
|
||||
comment: Optional[str]
|
||||
last_altered: Optional[datetime] = None
|
||||
schemas: List[SnowflakeSchema] = field(default_factory=list)
|
||||
tags: Optional[List[SnowflakeTag]] = None
|
||||
|
||||
|
||||
class _SnowflakeTagCache:
|
||||
def __init__(self) -> None:
|
||||
# self._database_tags[<database_name>] = list of tags applied to database
|
||||
self._database_tags: Dict[str, List[SnowflakeTag]] = defaultdict(list)
|
||||
|
||||
# self._schema_tags[<database_name>][<schema_name>] = list of tags applied to schema
|
||||
self._schema_tags: Dict[str, Dict[str, List[SnowflakeTag]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
# self._table_tags[<database_name>][<schema_name>][<table_name>] = list of tags applied to table
|
||||
self._table_tags: Dict[
|
||||
str, Dict[str, Dict[str, List[SnowflakeTag]]]
|
||||
] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
||||
|
||||
# self._column_tags[<database_name>][<schema_name>][<table_name>][<column_name>] = list of tags applied to column
|
||||
self._column_tags: Dict[
|
||||
str, Dict[str, Dict[str, Dict[str, List[SnowflakeTag]]]]
|
||||
] = defaultdict(
|
||||
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
||||
)
|
||||
|
||||
def add_database_tag(self, db_name: str, tag: SnowflakeTag) -> None:
|
||||
self._database_tags[db_name].append(tag)
|
||||
|
||||
def get_database_tags(self, db_name: str) -> List[SnowflakeTag]:
|
||||
return self._database_tags[db_name]
|
||||
|
||||
def add_schema_tag(self, schema_name: str, db_name: str, tag: SnowflakeTag) -> None:
|
||||
self._schema_tags[db_name][schema_name].append(tag)
|
||||
|
||||
def get_schema_tags(self, schema_name: str, db_name: str) -> List[SnowflakeTag]:
|
||||
return self._schema_tags.get(db_name, {}).get(schema_name, [])
|
||||
|
||||
def add_table_tag(
|
||||
self, table_name: str, schema_name: str, db_name: str, tag: SnowflakeTag
|
||||
) -> None:
|
||||
self._table_tags[db_name][schema_name][table_name].append(tag)
|
||||
|
||||
def get_table_tags(
|
||||
self, table_name: str, schema_name: str, db_name: str
|
||||
) -> List[SnowflakeTag]:
|
||||
return self._table_tags[db_name][schema_name][table_name]
|
||||
|
||||
def add_column_tag(
|
||||
self,
|
||||
column_name: str,
|
||||
table_name: str,
|
||||
schema_name: str,
|
||||
db_name: str,
|
||||
tag: SnowflakeTag,
|
||||
) -> None:
|
||||
self._column_tags[db_name][schema_name][table_name][column_name].append(tag)
|
||||
|
||||
def get_column_tags_for_table(
|
||||
self, table_name: str, schema_name: str, db_name: str
|
||||
) -> Dict[str, List[SnowflakeTag]]:
|
||||
return (
|
||||
self._column_tags.get(db_name, {}).get(schema_name, {}).get(table_name, {})
|
||||
)
|
||||
|
||||
|
||||
class SnowflakeDataDictionary(SnowflakeQueryMixin):
|
||||
@ -358,3 +442,101 @@ class SnowflakeDataDictionary(SnowflakeQueryMixin):
|
||||
constraints[row["fk_table_name"]].append(fk_constraints_map[row["fk_name"]])
|
||||
|
||||
return constraints
|
||||
|
||||
def get_tags_for_database_without_propagation(
    self,
    db_name: str,
) -> _SnowflakeTagCache:
    """Load the tag references of one database into a fresh tag cache.

    Runs the "without propagation" query for ``db_name`` once and files each
    returned row under the database, schema, table or column it refers to,
    based on the row's ``DOMAIN``. Rows with any other domain are logged as
    errors and ignored.

    Returns:
        A new ``_SnowflakeTagCache`` holding every row that was filed.
    """
    rows = self.query(
        SnowflakeQuery.get_all_tags_in_database_without_propagation(db_name)
    )

    cache = _SnowflakeTagCache()

    for row in rows:
        applied_tag = SnowflakeTag(
            database=row["TAG_DATABASE"],
            schema=row["TAG_SCHEMA"],
            name=row["TAG_NAME"],
            value=row["TAG_VALUE"],
        )

        # Name of the tagged object; for a column this is the owning object
        # and the column itself is given by `COLUMN_NAME`.
        target_name = row["OBJECT_NAME"]
        # Null when the tagged object is a database or a schema.
        target_schema = row["OBJECT_SCHEMA"]
        # Null when the tagged object is a database.
        target_database = row["OBJECT_DATABASE"]

        object_domain = row["DOMAIN"].lower()

        if object_domain == SnowflakeObjectDomain.DATABASE:
            cache.add_database_tag(target_name, applied_tag)
        elif object_domain == SnowflakeObjectDomain.SCHEMA:
            cache.add_schema_tag(target_name, target_database, applied_tag)
        elif object_domain == SnowflakeObjectDomain.TABLE:  # including views
            cache.add_table_tag(
                target_name, target_schema, target_database, applied_tag
            )
        elif object_domain == SnowflakeObjectDomain.COLUMN:
            cache.add_column_tag(
                row["COLUMN_NAME"],
                target_name,
                target_schema,
                target_database,
                applied_tag,
            )
        else:
            # The query only selects the four domains above, so this is not expected.
            self.logger.error(f"Encountered an unexpected domain: {object_domain}")

    return cache
|
||||
|
||||
def get_tags_for_object_with_propagation(
    self,
    domain: str,
    quoted_identifier: str,
    db_name: str,
) -> List[SnowflakeTag]:
    """Query the tags of a single object and return them as ``SnowflakeTag`` objects.

    Args:
        domain: Object domain passed through to the query.
        quoted_identifier: Quoted identifier of the object.
        db_name: Database whose ``information_schema`` is queried.

    Returns:
        One ``SnowflakeTag`` per row returned by the query, in row order.
    """
    sql = SnowflakeQuery.get_all_tags_on_object_with_propagation(
        db_name, quoted_identifier, domain
    )
    return [
        SnowflakeTag(
            database=row["TAG_DATABASE"],
            schema=row["TAG_SCHEMA"],
            name=row["TAG_NAME"],
            value=row["TAG_VALUE"],
        )
        for row in self.query(sql)
    ]
|
||||
|
||||
def get_tags_on_columns_for_table(
    self, quoted_table_name: str, db_name: str
) -> Dict[str, List[SnowflakeTag]]:
    """Query the column tags of one table and group them by column name.

    Args:
        quoted_table_name: Quoted identifier of the table.
        db_name: Database whose ``information_schema`` is queried.

    Returns:
        A ``defaultdict`` mapping column name to the tags found on it.
    """
    tags_by_column: Dict[str, List[SnowflakeTag]] = defaultdict(list)
    sql = SnowflakeQuery.get_tags_on_columns_with_propagation(
        db_name, quoted_table_name
    )

    for row in self.query(sql):
        column = row["COLUMN_NAME"]
        column_tag = SnowflakeTag(
            database=row["TAG_DATABASE"],
            schema=row["TAG_SCHEMA"],
            name=row["TAG_NAME"],
            value=row["TAG_VALUE"],
        )
        tags_by_column[column].append(column_tag)

    return tags_by_column
|
||||
|
||||
@ -0,0 +1,172 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
|
||||
from datahub.ingestion.source.snowflake.snowflake_config import (
|
||||
SnowflakeV2Config,
|
||||
TagOption,
|
||||
)
|
||||
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
|
||||
from datahub.ingestion.source.snowflake.snowflake_schema import (
|
||||
SnowflakeDataDictionary,
|
||||
SnowflakeTag,
|
||||
_SnowflakeTagCache,
|
||||
)
|
||||
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SnowflakeTagExtractor(SnowflakeCommonMixin):
    """Fetches the Snowflake tags applied to databases, schemas, tables/views and columns.

    The strategy depends on ``config.extract_tags``:

    * ``without_lineage``: the tags of a whole database are loaded once into a
      per-database cache and then served from that cache.
    * ``with_lineage``: one query is issued per object, plus one per table for
      its columns; the report counters track how many were issued.
    * anything else: no tags are returned.

    Every result is filtered through ``config.tag_pattern``.
    """

    def __init__(
        self,
        config: SnowflakeV2Config,
        data_dictionary: SnowflakeDataDictionary,
        report: SnowflakeV2Report,
    ) -> None:
        self.config = config
        self.data_dictionary = data_dictionary
        self.report = report
        self.logger = logger

        # database name -> tags loaded for that database (without_lineage mode only)
        self.tag_cache: Dict[str, _SnowflakeTagCache] = {}

    def _get_tag_cache_for_database(self, db_name: str) -> _SnowflakeTagCache:
        """Return the tag cache of ``db_name``, loading it on first use."""
        if db_name not in self.tag_cache:
            self.tag_cache[
                db_name
            ] = self.data_dictionary.get_tags_for_database_without_propagation(db_name)
        return self.tag_cache[db_name]

    def _get_tags_on_object_without_propagation(
        self,
        domain: str,
        db_name: str,
        schema_name: Optional[str],
        table_name: Optional[str],
    ) -> List[SnowflakeTag]:
        """Look up the tags of one object in the per-database cache.

        Raises:
            ValueError: if ``domain`` is not database, schema or table.
        """
        # The cache is populated before the domain is validated, as before.
        cache = self._get_tag_cache_for_database(db_name)

        if domain == SnowflakeObjectDomain.DATABASE:
            return cache.get_database_tags(db_name)
        elif domain == SnowflakeObjectDomain.SCHEMA:
            assert schema_name is not None
            return cache.get_schema_tags(schema_name, db_name)
        elif (
            domain == SnowflakeObjectDomain.TABLE
        ):  # Views belong to this domain as well.
            assert schema_name is not None
            assert table_name is not None
            return cache.get_table_tags(table_name, schema_name, db_name)
        else:
            raise ValueError(f"Unknown domain {domain}")

    def _get_tags_on_object_with_propagation(
        self,
        domain: str,
        db_name: str,
        schema_name: Optional[str],
        table_name: Optional[str],
    ) -> List[SnowflakeTag]:
        """Query the tags of one object, counting the query in the report.

        Raises:
            ValueError: if ``domain`` is not database, schema or table.
        """
        identifier = ""
        if domain == SnowflakeObjectDomain.DATABASE:
            identifier = self.get_quoted_identifier_for_database(db_name)
        elif domain == SnowflakeObjectDomain.SCHEMA:
            assert schema_name is not None
            identifier = self.get_quoted_identifier_for_schema(db_name, schema_name)
        elif (
            domain == SnowflakeObjectDomain.TABLE
        ):  # Views belong to this domain as well.
            assert schema_name is not None
            assert table_name is not None
            identifier = self.get_quoted_identifier_for_table(
                db_name, schema_name, table_name
            )
        else:
            raise ValueError(f"Unknown domain {domain}")
        assert identifier

        self.report.num_get_tags_for_object_queries += 1
        tags = self.data_dictionary.get_tags_for_object_with_propagation(
            domain=domain, quoted_identifier=identifier, db_name=db_name
        )
        return tags

    def get_tags_on_object(
        self,
        domain: str,
        db_name: str,
        schema_name: Optional[str] = None,
        table_name: Optional[str] = None,
    ) -> List[SnowflakeTag]:
        """Return the allowed tags on a database, schema or table/view.

        Args:
            domain: ``database``, ``schema`` or ``table`` (views use ``table``).
            db_name: Database of the object (or the database itself).
            schema_name: Required for the schema and table domains.
            table_name: Required for the table domain.

        Returns:
            The tags that pass ``config.tag_pattern``; empty when tag
            extraction is not enabled.
        """
        if self.config.extract_tags == TagOption.without_lineage:
            tags = self._get_tags_on_object_without_propagation(
                domain=domain,
                db_name=db_name,
                schema_name=schema_name,
                table_name=table_name,
            )

        elif self.config.extract_tags == TagOption.with_lineage:
            tags = self._get_tags_on_object_with_propagation(
                domain=domain,
                db_name=db_name,
                schema_name=schema_name,
                table_name=table_name,
            )
        else:
            tags = []

        allowed_tags = self._filter_tags(tags)

        return allowed_tags if allowed_tags else []

    def get_column_tags_for_table(
        self,
        table_name: str,
        schema_name: str,
        db_name: str,
    ) -> Dict[str, List[SnowflakeTag]]:
        """Return ``column name -> allowed tags`` for one table or view.

        Columns whose tags are all filtered out are omitted from the result.
        """
        temp_column_tags: Dict[str, List[SnowflakeTag]] = {}
        if self.config.extract_tags == TagOption.without_lineage:
            temp_column_tags = self._get_tag_cache_for_database(
                db_name
            ).get_column_tags_for_table(table_name, schema_name, db_name)
        elif self.config.extract_tags == TagOption.with_lineage:
            self.report.num_get_tags_on_columns_for_table_queries += 1
            temp_column_tags = self.data_dictionary.get_tags_on_columns_for_table(
                quoted_table_name=self.get_quoted_identifier_for_table(
                    db_name, schema_name, table_name
                ),
                db_name=db_name,
            )

        column_tags: Dict[str, List[SnowflakeTag]] = {}

        for column_name, tags in temp_column_tags.items():
            allowed_tags = self._filter_tags(tags)
            if allowed_tags:
                column_tags[column_name] = allowed_tags

        return column_tags

    def _filter_tags(
        self, tags: Optional[List[SnowflakeTag]]
    ) -> Optional[List[SnowflakeTag]]:
        """Keep only the tags whose identifier is allowed by ``config.tag_pattern``.

        Every tag is reported as scanned; tags rejected by the pattern are
        reported as dropped and left out of the result. ``None`` is passed
        through unchanged.
        """
        if tags is None:
            return tags

        allowed_tags = []
        for tag in tags:
            tag_identifier = tag.identifier()
            self.report.report_entity_scanned(tag_identifier, "tag")
            if not self.config.tag_pattern.allowed(tag_identifier):
                self.report.report_dropped(tag_identifier)
                # A dropped tag must not be emitted; previously it fell through
                # and was appended anyway, making the deny pattern ineffective.
                continue
            allowed_tags.append(tag)
        return allowed_tags
|
||||
@ -158,6 +158,18 @@ class SnowflakeCommonMixin:
|
||||
return identifier.lower()
|
||||
return identifier
|
||||
|
||||
@staticmethod
|
||||
def get_quoted_identifier_for_database(db_name):
|
||||
return f'"{db_name}"'
|
||||
|
||||
@staticmethod
|
||||
def get_quoted_identifier_for_schema(db_name, schema_name):
|
||||
return f'"{db_name}"."{schema_name}"'
|
||||
|
||||
@staticmethod
|
||||
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
|
||||
return f'"{db_name}"."{schema_name}"."{table_name}"'
|
||||
|
||||
def get_dataset_identifier(
|
||||
self: SnowflakeCommonProtocol, table_name: str, schema_name: str, db_name: str
|
||||
) -> str:
|
||||
|
||||
@ -15,6 +15,7 @@ from datahub.emitter.mce_builder import (
|
||||
make_dataset_urn_with_platform_instance,
|
||||
make_domain_urn,
|
||||
make_schema_field_urn,
|
||||
make_tag_urn,
|
||||
)
|
||||
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
||||
from datahub.emitter.mcp_builder import (
|
||||
@ -49,7 +50,10 @@ from datahub.ingestion.source.snowflake.constants import (
|
||||
SnowflakeEdition,
|
||||
SnowflakeObjectDomain,
|
||||
)
|
||||
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
|
||||
from datahub.ingestion.source.snowflake.snowflake_config import (
|
||||
SnowflakeV2Config,
|
||||
TagOption,
|
||||
)
|
||||
from datahub.ingestion.source.snowflake.snowflake_lineage import (
|
||||
SnowflakeLineageExtractor,
|
||||
)
|
||||
@ -64,8 +68,10 @@ from datahub.ingestion.source.snowflake.snowflake_schema import (
|
||||
SnowflakeQuery,
|
||||
SnowflakeSchema,
|
||||
SnowflakeTable,
|
||||
SnowflakeTag,
|
||||
SnowflakeView,
|
||||
)
|
||||
from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor
|
||||
from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
|
||||
SnowflakeUsageExtractor,
|
||||
)
|
||||
@ -90,8 +96,10 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
|
||||
StatefulIngestionSourceBase,
|
||||
)
|
||||
from datahub.metadata.com.linkedin.pegasus2avro.common import (
|
||||
GlobalTags,
|
||||
Status,
|
||||
SubTypes,
|
||||
TagAssociation,
|
||||
TimeStamp,
|
||||
)
|
||||
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
|
||||
@ -114,6 +122,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
|
||||
StringType,
|
||||
TimeType,
|
||||
)
|
||||
from datahub.metadata.com.linkedin.pegasus2avro.tag import TagProperties
|
||||
from datahub.metadata.schema_classes import ChangeTypeClass, DataPlatformInstanceClass
|
||||
from datahub.utilities.registries.domain_registry import DomainRegistry
|
||||
from datahub.utilities.time import datetime_to_ts_millis
|
||||
@ -188,6 +197,11 @@ SNOWFLAKE_FIELD_TYPE_MAPPINGS = {
|
||||
"Optionally enabled via `stateful_ingestion.remove_stale_metadata`",
|
||||
supported=True,
|
||||
)
|
||||
@capability(
|
||||
SourceCapability.TAGS,
|
||||
"Optionally enabled via `extract_tags`",
|
||||
supported=True,
|
||||
)
|
||||
class SnowflakeV2Source(
|
||||
ClassificationMixin,
|
||||
SnowflakeQueryMixin,
|
||||
@ -235,6 +249,10 @@ class SnowflakeV2Source(
|
||||
# For usage stats
|
||||
self.usage_extractor = SnowflakeUsageExtractor(config, self.report)
|
||||
|
||||
self.tag_extractor = SnowflakeTagExtractor(
|
||||
config, self.data_dictionary, self.report
|
||||
)
|
||||
|
||||
self.profiling_state_handler: Optional[ProfilingHandler] = None
|
||||
if self.config.store_last_profiling_timestamps:
|
||||
self.profiling_state_handler = ProfilingHandler(
|
||||
@ -358,6 +376,7 @@ class SnowflakeV2Source(
|
||||
_report[SourceCapability.CONTAINERS] = CapabilityReport(
|
||||
capable=True
|
||||
)
|
||||
_report[SourceCapability.TAGS] = CapabilityReport(capable=True)
|
||||
elif privilege.object_type in (
|
||||
"TABLE",
|
||||
"VIEW",
|
||||
@ -391,6 +410,8 @@ class SnowflakeV2Source(
|
||||
_report[SourceCapability.USAGE_STATS] = CapabilityReport(
|
||||
capable=True
|
||||
)
|
||||
_report[SourceCapability.TAGS] = CapabilityReport(capable=True)
|
||||
|
||||
# If all capabilities supported, no need to continue
|
||||
if set(capabilities) == set(_report.keys()):
|
||||
break
|
||||
@ -414,6 +435,7 @@ class SnowflakeV2Source(
|
||||
SourceCapability.LINEAGE_COARSE: "Current role does not have permissions to snowflake account usage views",
|
||||
SourceCapability.LINEAGE_FINE: "Current role does not have permissions to snowflake account usage views",
|
||||
SourceCapability.USAGE_STATS: "Current role does not have permissions to snowflake account usage views",
|
||||
SourceCapability.TAGS: "Either no tags have been applied to objects, or the current role does not have permission to access the objects or to snowflake account usage views ",
|
||||
}
|
||||
|
||||
for c in capabilities: # type:ignore
|
||||
@ -425,6 +447,7 @@ class SnowflakeV2Source(
|
||||
SourceCapability.LINEAGE_COARSE,
|
||||
SourceCapability.LINEAGE_FINE,
|
||||
SourceCapability.USAGE_STATS,
|
||||
SourceCapability.TAGS,
|
||||
):
|
||||
failure_message = (
|
||||
f"Current role {current_role} does not have permissions to use warehouse {connection_conf.warehouse}. Please check the grants associated with this role."
|
||||
@ -471,6 +494,7 @@ class SnowflakeV2Source(
|
||||
for snowflake_db in databases:
|
||||
try:
|
||||
yield from self._process_database(snowflake_db)
|
||||
|
||||
except SnowflakePermissionError as e:
|
||||
# FIXME - This may break stateful ingestion if tables that were not in the previous run are emitted above
|
||||
# and stateful ingestion is enabled
|
||||
@ -627,11 +651,20 @@ class SnowflakeV2Source(
|
||||
)
|
||||
return
|
||||
|
||||
if self.config.extract_tags != TagOption.skip:
|
||||
snowflake_db.tags = self.tag_extractor.get_tags_on_object(
|
||||
domain="database", db_name=db_name
|
||||
)
|
||||
|
||||
if self.config.include_technical_schema:
|
||||
yield from self.gen_database_containers(snowflake_db)
|
||||
|
||||
self.fetch_schemas_for_database(snowflake_db, db_name)
|
||||
|
||||
if self.config.include_technical_schema and snowflake_db.tags:
|
||||
for tag in snowflake_db.tags:
|
||||
yield from self._process_tag(tag)
|
||||
|
||||
for snowflake_schema in snowflake_db.schemas:
|
||||
yield from self._process_schema(snowflake_schema, db_name)
|
||||
|
||||
@ -675,6 +708,12 @@ class SnowflakeV2Source(
|
||||
return
|
||||
|
||||
schema_name = snowflake_schema.name
|
||||
|
||||
if self.config.extract_tags != TagOption.skip:
|
||||
snowflake_schema.tags = self.tag_extractor.get_tags_on_object(
|
||||
schema_name=schema_name, db_name=db_name, domain="schema"
|
||||
)
|
||||
|
||||
if self.config.include_technical_schema:
|
||||
yield from self.gen_schema_containers(snowflake_schema, db_name)
|
||||
|
||||
@ -692,6 +731,10 @@ class SnowflakeV2Source(
|
||||
for view in snowflake_schema.views:
|
||||
yield from self._process_view(view, schema_name, db_name)
|
||||
|
||||
if self.config.include_technical_schema and snowflake_schema.tags:
|
||||
for tag in snowflake_schema.tags:
|
||||
yield from self._process_tag(tag)
|
||||
|
||||
if not snowflake_schema.views and not snowflake_schema.tables:
|
||||
self.report_warning(
|
||||
"No tables/views found in schema. If tables exist, please grant REFERENCES or SELECT permissions on them.",
|
||||
@ -762,6 +805,22 @@ class SnowflakeV2Source(
|
||||
table, schema_name, db_name, dataset_name
|
||||
)
|
||||
|
||||
if self.config.extract_tags != TagOption.skip:
|
||||
table.tags = self.tag_extractor.get_tags_on_object(
|
||||
table_name=table.name,
|
||||
schema_name=schema_name,
|
||||
db_name=db_name,
|
||||
domain="table",
|
||||
)
|
||||
|
||||
if self.config.include_technical_schema:
|
||||
if table.tags:
|
||||
for tag in table.tags:
|
||||
yield from self._process_tag(tag)
|
||||
for column_name in table.column_tags:
|
||||
for tag in table.column_tags[column_name]:
|
||||
yield from self._process_tag(tag)
|
||||
|
||||
yield from self.gen_dataset_workunits(table, schema_name, db_name)
|
||||
|
||||
def fetch_sample_data_for_classification(
|
||||
@ -817,6 +876,10 @@ class SnowflakeV2Source(
|
||||
def fetch_columns_for_table(self, table, schema_name, db_name, table_identifier):
|
||||
try:
|
||||
table.columns = self.get_columns_for_table(table.name, schema_name, db_name)
|
||||
if self.config.extract_tags != TagOption.skip:
|
||||
table.column_tags = self.tag_extractor.get_column_tags_for_table(
|
||||
table.name, schema_name, db_name
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to get columns for table {table_identifier} due to error {e}",
|
||||
@ -840,6 +903,10 @@ class SnowflakeV2Source(
|
||||
|
||||
try:
|
||||
view.columns = self.get_columns_for_table(view.name, schema_name, db_name)
|
||||
if self.config.extract_tags != TagOption.skip:
|
||||
view.column_tags = self.tag_extractor.get_column_tags_for_table(
|
||||
view.name, schema_name, db_name
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to get columns for view {view_name} due to error {e}",
|
||||
@ -847,8 +914,34 @@ class SnowflakeV2Source(
|
||||
)
|
||||
self.report_warning("Failed to get columns for view", view_name)
|
||||
|
||||
if self.config.extract_tags != TagOption.skip:
|
||||
view.tags = self.tag_extractor.get_tags_on_object(
|
||||
table_name=view.name,
|
||||
schema_name=schema_name,
|
||||
db_name=db_name,
|
||||
domain="table",
|
||||
)
|
||||
|
||||
if self.config.include_technical_schema:
|
||||
if view.tags:
|
||||
for tag in view.tags:
|
||||
yield from self._process_tag(tag)
|
||||
for column_name in view.column_tags:
|
||||
for tag in view.column_tags[column_name]:
|
||||
yield from self._process_tag(tag)
|
||||
|
||||
yield from self.gen_dataset_workunits(view, schema_name, db_name)
|
||||
|
||||
def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
    """Yield the work units for ``tag`` unless it was already processed.

    The tag's identifier is checked against ``self.report``; when the report
    already lists it as processed, nothing is yielded. Otherwise the
    identifier is recorded in the report and the work units produced by
    ``gen_tag_workunits`` are yielded.
    """
    identifier = tag.identifier()

    if not self.report.is_tag_processed(identifier):
        # Record the identifier before generating, matching the order in
        # which the report is updated and the work units are produced.
        self.report.report_tag_processed(identifier)
        yield from self.gen_tag_workunits(tag)
|
||||
|
||||
def gen_dataset_workunits(
|
||||
self,
|
||||
table: Union[SnowflakeTable, SnowflakeView],
|
||||
@ -908,6 +1001,15 @@ class SnowflakeV2Source(
|
||||
entity_type="dataset",
|
||||
)
|
||||
|
||||
if table.tags:
|
||||
tag_associations = [
|
||||
TagAssociation(tag=make_tag_urn(tag.identifier())) for tag in table.tags
|
||||
]
|
||||
global_tags = GlobalTags(tag_associations)
|
||||
yield self.wrap_aspect_as_workunit(
|
||||
"dataset", dataset_urn, "globalTags", global_tags
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(table, SnowflakeView)
|
||||
and cast(SnowflakeView, table).view_definition is not None
|
||||
@ -951,6 +1053,21 @@ class SnowflakeV2Source(
|
||||
else None,
|
||||
)
|
||||
|
||||
def gen_tag_workunits(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]:
    """Yield the ``tagProperties`` work unit describing ``tag``.

    The tag URN is derived from the tag's identifier after it has been passed
    through ``self.snowflake_identifier``; the un-normalized identifier is
    used as the tag's display name. The URN is also registered with the
    stale entity removal handler under the ``"tag"`` type before the work
    unit is yielded.
    """
    identifier = tag.identifier()
    urn = make_tag_urn(self.snowflake_identifier(identifier))

    # Build the aspect first so that nothing is registered in the state
    # handler if constructing the properties fails.
    description = f"Represents the Snowflake tag `{tag._id_prefix_as_str()}` with value `{tag.value}`."
    properties = TagProperties(name=identifier, description=description)

    self.stale_entity_removal_handler.add_entity_to_state("tag", urn)

    yield self.wrap_aspect_as_workunit("tag", urn, "tagProperties", properties)
|
||||
|
||||
def get_schema_metadata(
|
||||
self,
|
||||
table: Union[SnowflakeTable, SnowflakeView],
|
||||
@ -980,6 +1097,18 @@ class SnowflakeV2Source(
|
||||
isPartOfKey=col.name in table.pk.column_names
|
||||
if isinstance(table, SnowflakeTable) and table.pk is not None
|
||||
else None,
|
||||
globalTags=GlobalTags(
|
||||
[
|
||||
TagAssociation(
|
||||
make_tag_urn(
|
||||
self.snowflake_identifier(tag.identifier())
|
||||
)
|
||||
)
|
||||
for tag in table.column_tags[col.name]
|
||||
]
|
||||
)
|
||||
if col.name in table.column_tags
|
||||
else None,
|
||||
)
|
||||
for col in table.columns
|
||||
],
|
||||
@ -1168,6 +1297,9 @@ class SnowflakeV2Source(
|
||||
else int(database.created.timestamp() * 1000)
|
||||
if database.created is not None
|
||||
else None,
|
||||
tags=[self.snowflake_identifier(tag.identifier()) for tag in database.tags]
|
||||
if database.tags
|
||||
else None,
|
||||
)
|
||||
|
||||
self.stale_entity_removal_handler.add_entity_to_state(
|
||||
@ -1215,6 +1347,9 @@ class SnowflakeV2Source(
|
||||
else int(schema.created.timestamp() * 1000)
|
||||
if schema.created is not None
|
||||
else None,
|
||||
tags=[self.snowflake_identifier(tag.identifier()) for tag in schema.tags]
|
||||
if schema.tags
|
||||
else None,
|
||||
)
|
||||
|
||||
for wu in container_workunits:
|
||||
|
||||
@ -352,5 +352,60 @@ def default_query_results(query): # noqa: C901
|
||||
]:
|
||||
return []
|
||||
|
||||
elif (
|
||||
query
|
||||
== snowflake_query.SnowflakeQuery.get_all_tags_in_database_without_propagation(
|
||||
"TEST_DB"
|
||||
)
|
||||
):
|
||||
return [
|
||||
*[
|
||||
{
|
||||
"TAG_DATABASE": "TEST_DB",
|
||||
"TAG_SCHEMA": "TEST_SCHEMA",
|
||||
"TAG_NAME": f"my_tag_{ix}",
|
||||
"TAG_VALUE": f"my_value_{ix}",
|
||||
"OBJECT_DATABASE": "TEST_DB",
|
||||
"OBJECT_SCHEMA": "TEST_SCHEMA",
|
||||
"OBJECT_NAME": "VIEW_2",
|
||||
"COLUMN_NAME": None,
|
||||
"DOMAIN": "TABLE",
|
||||
}
|
||||
for ix in range(3)
|
||||
],
|
||||
{
|
||||
"TAG_DATABASE": "TEST_DB",
|
||||
"TAG_SCHEMA": "TEST_SCHEMA",
|
||||
"TAG_NAME": "security",
|
||||
"TAG_VALUE": "pii",
|
||||
"OBJECT_DATABASE": "TEST_DB",
|
||||
"OBJECT_SCHEMA": "TEST_SCHEMA",
|
||||
"OBJECT_NAME": "VIEW_1",
|
||||
"COLUMN_NAME": "COL_1",
|
||||
"DOMAIN": "COLUMN",
|
||||
},
|
||||
{
|
||||
"TAG_DATABASE": "OTHER_DB",
|
||||
"TAG_SCHEMA": "OTHER_SCHEMA",
|
||||
"TAG_NAME": "my_other_tag",
|
||||
"TAG_VALUE": "other",
|
||||
"OBJECT_DATABASE": "TEST_DB",
|
||||
"OBJECT_SCHEMA": None,
|
||||
"OBJECT_NAME": "TEST_SCHEMA",
|
||||
"COLUMN_NAME": None,
|
||||
"DOMAIN": "SCHEMA",
|
||||
},
|
||||
{
|
||||
"TAG_DATABASE": "OTHER_DB",
|
||||
"TAG_SCHEMA": "OTHER_SCHEMA",
|
||||
"TAG_NAME": "my_other_tag",
|
||||
"TAG_VALUE": "other",
|
||||
"OBJECT_DATABASE": None,
|
||||
"OBJECT_SCHEMA": None,
|
||||
"OBJECT_NAME": "TEST_DB",
|
||||
"COLUMN_NAME": None,
|
||||
"DOMAIN": "DATABASE",
|
||||
},
|
||||
]
|
||||
# Unreachable code
|
||||
raise Exception(f"Unknown query {query}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -19,7 +19,10 @@ from datahub.ingestion.glossary.datahub_classifier import (
|
||||
from datahub.ingestion.run.pipeline import Pipeline
|
||||
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
|
||||
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
|
||||
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
|
||||
from datahub.ingestion.source.snowflake.snowflake_config import (
|
||||
SnowflakeV2Config,
|
||||
TagOption,
|
||||
)
|
||||
from tests.integration.snowflake.common import FROZEN_TIME, default_query_results
|
||||
from tests.test_helpers import mce_helpers
|
||||
|
||||
@ -109,6 +112,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
|
||||
profile_table_size_limit=None,
|
||||
profile_table_level_only=True,
|
||||
),
|
||||
extract_tags=TagOption.without_lineage,
|
||||
),
|
||||
),
|
||||
sink=DynamicTypedConfig(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user