Fixes #22112: Snowflake schema tags inheritance (#22979)

* Add schema-level tags and tag inheritance support for Snowflake

* Add tests for schema tag inheritance

* Lint fixes
This commit is contained in:
Mohit Tilala 2025-08-20 09:52:44 +05:30 committed by GitHub
parent 3bd3158bee
commit 26fedbaf0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 186 additions and 8 deletions

View File

@ -46,3 +46,6 @@ DEFAULT_STREAM_COLUMNS = [
"primary_key": False,
},
]
# Fixed description strings for tags ingested from Snowflake: the first is
# passed as the tag description, the second as the classification description.
SNOWFLAKE_TAG_DESCRIPTION = "SNOWFLAKE TAG VALUE"
SNOWFLAKE_CLASSIFICATION_DESCRIPTION = "SNOWFLAKE TAG NAME"

View File

@ -54,6 +54,7 @@ from metadata.generated.schema.type.basic import (
SourceUrl,
)
from metadata.generated.schema.type.entityReferenceList import EntityReferenceList
from metadata.generated.schema.type.tagLabel import TagLabel
from metadata.ingestion.api.delete import delete_entity_by_name
from metadata.ingestion.api.models import Either
from metadata.ingestion.api.steps import InvalidSourceException
@ -73,6 +74,8 @@ from metadata.ingestion.source.database.incremental_metadata_extraction import (
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
from metadata.ingestion.source.database.snowflake.constants import (
DEFAULT_STREAM_COLUMNS,
SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
SNOWFLAKE_TAG_DESCRIPTION,
)
from metadata.ingestion.source.database.snowflake.models import (
STORED_PROC_LANGUAGE_MAP,
@ -81,7 +84,8 @@ from metadata.ingestion.source.database.snowflake.models import (
from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_DESC_FUNCTION,
SNOWFLAKE_DESC_STORED_PROCEDURE,
SNOWFLAKE_FETCH_ALL_TAGS,
SNOWFLAKE_FETCH_SCHEMA_TAGS,
SNOWFLAKE_FETCH_TABLE_TAGS,
SNOWFLAKE_GET_CLUSTER_KEY,
SNOWFLAKE_GET_CURRENT_ACCOUNT,
SNOWFLAKE_GET_DATABASE_COMMENTS,
@ -123,7 +127,7 @@ from metadata.utils.sqlalchemy_utils import (
get_all_table_ddls,
get_all_view_definitions,
)
from metadata.utils.tag_utils import get_ometa_tag_and_classification
from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_label
class MAP(StructuredType):
@ -203,6 +207,7 @@ class SnowflakeSource(
self.schema_desc_map = {}
self.database_desc_map = {}
self.external_location_map = {}
self.schema_tags_map = {}
self._account: Optional[str] = None
self._org_name: Optional[str] = None
@ -299,6 +304,32 @@ class SnowflakeSource(
for row in results
}
def set_schema_tags_map(self, database_name: str) -> None:
    """
    Rebuild the cache of schema-level tags for ``database_name``.

    The cache (``self.schema_tags_map``) is always emptied first. When tag
    ingestion is disabled nothing else happens. Otherwise every row returned
    by the schema-tags query is grouped by schema name as a
    ``{"tag_name": ..., "tag_value": ...}`` dictionary.

    Any failure is logged and swallowed, so the cache may be left empty or
    partially filled.
    """
    self.schema_tags_map.clear()
    if not self.source_config.includeTags:
        return

    try:
        query = SNOWFLAKE_FETCH_SCHEMA_TAGS.format(
            database_name=database_name,
            account_usage=self.service_connection.accountUsageSchema,
        )
        for record in self.engine.execute(query).all():
            # setdefault creates the per-schema list on first sight of a schema
            self.schema_tags_map.setdefault(record.SCHEMA_NAME, []).append(
                {"tag_name": record.TAG_NAME, "tag_value": record.TAG_VALUE}
            )
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(f"Failed to fetch schema tags: {exc}")
def get_schema_description(self, schema_name: str) -> Optional[str]:
"""
Method to fetch the schema description
@ -329,6 +360,7 @@ class SnowflakeSource(
self.set_schema_description_map()
self.set_database_description_map()
self.set_external_location_map(configured_db)
self.set_schema_tags_map(configured_db)
yield configured_db
else:
for new_database in self.get_database_names_raw():
@ -357,6 +389,7 @@ class SnowflakeSource(
self.set_schema_description_map()
self.set_database_description_map()
self.set_external_location_map(new_database)
self.set_schema_tags_map(new_database)
yield new_database
except Exception as exc:
logger.debug(traceback.format_exc())
@ -451,11 +484,14 @@ class SnowflakeSource(
def yield_tag(
self, schema_name: str
) -> Iterable[Either[OMetaTagAndClassification]]:
"""
Yield tags for tables/columns and schemas.
"""
if self.source_config.includeTags:
result = []
try:
result = self.connection.execute(
SNOWFLAKE_FETCH_ALL_TAGS.format(
SNOWFLAKE_FETCH_TABLE_TAGS.format(
database_name=self.context.get().database,
schema_name=schema_name,
account_usage=self.service_connection.accountUsageSchema,
@ -469,7 +505,7 @@ class SnowflakeSource(
f"Error fetching tags {exc}. Trying with quoted names"
)
result = self.connection.execute(
SNOWFLAKE_FETCH_ALL_TAGS.format(
SNOWFLAKE_FETCH_TABLE_TAGS.format(
database_name=f'"{self.context.get().database}"',
schema_name=f'"{self.context.get().database_schema}"',
account_usage=self.service_connection.accountUsageSchema,
@ -495,12 +531,32 @@ class SnowflakeSource(
),
tags=[row[1]],
classification_name=row[0],
tag_description="SNOWFLAKE TAG VALUE",
classification_description="SNOWFLAKE TAG NAME",
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
metadata=self.metadata,
system_tags=True,
)
# Yield schema-level tags
if schema_name in self.schema_tags_map:
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.get().database_service,
database_name=self.context.get().database,
schema_name=schema_name,
)
for tag_info in self.schema_tags_map[schema_name]:
yield from get_ometa_tag_and_classification(
tag_fqn=FullyQualifiedEntityName(schema_fqn),
tags=[tag_info["tag_value"]],
classification_name=tag_info["tag_name"],
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
metadata=self.metadata,
system_tags=True,
)
def _get_table_names_and_types(
self, schema_name: str, table_type: TableType = TableType.Regular
) -> List[TableNameAndType]:
@ -945,3 +1001,49 @@ class SnowflakeSource(
logger.debug(
f"Processing ownership is not supported for {self.service_connection.type.name}"
)
def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]:
    """
    Return tags for schema entity including Snowflake schema-level tags.

    Labels are first resolved from the entries cached in
    ``self.schema_tags_map`` for ``schema_name``. The labels returned by the
    parent implementation are then appended, skipping any already present.

    Args:
        schema_name: name of the schema whose tag labels are requested.

    Returns:
        The combined list of distinct tag labels, or ``None`` when there
        are none.
    """
    schema_tags: List[TagLabel] = []

    for tag_info in self.schema_tags_map.get(schema_name, []):
        tag_label = get_tag_label(
            metadata=self.metadata,
            tag_name=tag_info["tag_value"],
            classification_name=tag_info["tag_name"],
        )
        # The cached entries carry only tag name and value, so two entries
        # can resolve to the same label; keep a single copy, as
        # get_tag_labels does for inherited schema tags.
        if tag_label and tag_label not in schema_tags:
            schema_tags.append(tag_label)

    # Include parent tags from context
    parent_tags = super().get_schema_tag_labels(schema_name) or []
    for tag in parent_tags:
        if tag not in schema_tags:
            schema_tags.append(tag)

    return schema_tags or None
def get_tag_labels(self, table_name: str) -> Optional[List[TagLabel]]:
    """
    Return the table's tag labels, extended with tags inherited from its schema.

    The result starts from the labels produced by the parent implementation
    for ``table_name``. Labels resolved from the schema-level tags cached for
    the schema currently held in the context are then appended when they are
    not already present. ``None`` is returned when no label applies.
    """
    labels = super().get_tag_labels(table_name) or []
    current_schema = self.context.get().database_schema

    # Nothing to inherit: no schema in context or no cached tags for it
    if not current_schema or current_schema not in self.schema_tags_map:
        return labels or None

    for entry in self.schema_tags_map[current_schema]:
        inherited = get_tag_label(
            metadata=self.metadata,
            tag_name=entry["tag_value"],
            classification_name=entry["tag_name"],
        )
        if not inherited or inherited in labels:
            continue
        labels.append(inherited)

    return labels or None

View File

@ -37,7 +37,7 @@ SNOWFLAKE_SQL_STATEMENT = textwrap.dedent(
SNOWFLAKE_SESSION_TAG_QUERY = 'ALTER SESSION SET QUERY_TAG="{query_tag}"'
SNOWFLAKE_FETCH_ALL_TAGS = textwrap.dedent(
SNOWFLAKE_FETCH_TABLE_TAGS = textwrap.dedent(
"""
select TAG_NAME, TAG_VALUE, OBJECT_DATABASE, OBJECT_SCHEMA, OBJECT_NAME, COLUMN_NAME
from {account_usage}.tag_references
@ -46,6 +46,18 @@ SNOWFLAKE_FETCH_ALL_TAGS = textwrap.dedent(
"""
)
# Query template returning the schema-level tag references of one database.
# Placeholders: {account_usage} (schema that holds tag_references) and
# {database_name}. Only rows with DOMAIN = 'SCHEMA' are kept, and OBJECT_NAME
# is aliased to SCHEMA_NAME so each row exposes TAG_NAME, TAG_VALUE, SCHEMA_NAME.
SNOWFLAKE_FETCH_SCHEMA_TAGS = textwrap.dedent(
"""
select TAG_NAME, TAG_VALUE, OBJECT_NAME as SCHEMA_NAME
from {account_usage}.tag_references
where OBJECT_DATABASE = '{database_name}'
and OBJECT_SCHEMA IS NULL
and OBJECT_NAME IS NOT NULL
and COLUMN_NAME IS NULL
and DOMAIN = 'SCHEMA'
"""
)
SNOWFLAKE_GET_EXTERNAL_TABLE_NAMES = """
select TABLE_NAME, NULL from information_schema.tables
where TABLE_SCHEMA = '{schema}' AND TABLE_TYPE = 'EXTERNAL TABLE'

View File

@ -14,7 +14,7 @@ snowflake unit tests
"""
# pylint: disable=line-too-long
from unittest import TestCase
from unittest.mock import PropertyMock, patch
from unittest.mock import Mock, PropertyMock, patch
import sqlalchemy.types as sqltypes
@ -25,6 +25,12 @@ from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipel
from metadata.generated.schema.metadataIngestion.workflow import (
OpenMetadataWorkflowConfig,
)
from metadata.generated.schema.type.tagLabel import (
LabelType,
State,
TagLabel,
TagSource,
)
from metadata.ingestion.source.database.snowflake.metadata import MAP, SnowflakeSource
from metadata.ingestion.source.database.snowflake.models import SnowflakeStoredProcedure
@ -305,3 +311,58 @@ class SnowflakeUnitTest(TestCase):
self.assertEqual(map_type.key_type, key_type)
self.assertEqual(map_type.value_type, sqltypes.VARCHAR) # default
self.assertFalse(map_type.not_null) # default
@patch(
    "metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels"
)
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
def test_schema_tag_inheritance(
    self, mock_get_tag_label, mock_parent_get_tag_labels
):
    """Schema tags are cached, exposed on the schema and inherited by tables"""
    for source in self.sources.values():
        # Step 1: the schema-tags query result is grouped into the cache
        tag_rows = [
            Mock(SCHEMA_NAME="TEST_SCHEMA", TAG_NAME="SCHEMA_TAG", TAG_VALUE="VALUE"),
        ]
        query_result = Mock(**{"all.return_value": tag_rows})
        source.engine.execute = Mock(return_value=query_result)

        source.set_schema_tags_map("TEST_DATABASE")

        cached_tags = source.schema_tags_map["TEST_SCHEMA"]
        self.assertEqual(len(cached_tags), 1)
        self.assertEqual(
            cached_tags[0],
            {"tag_name": "SCHEMA_TAG", "tag_value": "VALUE"},
        )

        # Step 2: the cached tag is resolved into a label for the schema
        mock_get_tag_label.return_value = TagLabel(
            tagFQN="SnowflakeTag.SCHEMA_TAG",
            labelType=LabelType.Automated,
            state=State.Suggested,
            source=TagSource.Classification,
        )

        schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
        self.assertIsNotNone(schema_labels)
        self.assertEqual(len(schema_labels), 1)

        # Step 3: a table in that schema gets its own tag plus the schema tag
        source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
        table_level_label = TagLabel(
            tagFQN="SnowflakeTag.TABLE_TAG",
            labelType=LabelType.Automated,
            state=State.Suggested,
            source=TagSource.Classification,
        )
        mock_parent_get_tag_labels.return_value = [table_level_label]

        table_labels = source.get_tag_labels(table_name="TEST_TABLE")
        self.assertEqual(len(table_labels), 2)

        found_fqns = {label.tagFQN.root for label in table_labels}
        self.assertIn("SnowflakeTag.SCHEMA_TAG", found_fqns)
        self.assertIn("SnowflakeTag.TABLE_TAG", found_fqns)