diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/constants.py b/ingestion/src/metadata/ingestion/source/database/snowflake/constants.py index 0b2391fb847..0c6739f88ee 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/constants.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/constants.py @@ -46,3 +46,6 @@ DEFAULT_STREAM_COLUMNS = [ "primary_key": False, }, ] + +SNOWFLAKE_TAG_DESCRIPTION = "SNOWFLAKE TAG VALUE" +SNOWFLAKE_CLASSIFICATION_DESCRIPTION = "SNOWFLAKE TAG NAME" diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py index 57b9eea8bce..4568ce34200 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py @@ -54,6 +54,7 @@ from metadata.generated.schema.type.basic import ( SourceUrl, ) from metadata.generated.schema.type.entityReferenceList import EntityReferenceList +from metadata.generated.schema.type.tagLabel import TagLabel from metadata.ingestion.api.delete import delete_entity_by_name from metadata.ingestion.api.models import Either from metadata.ingestion.api.steps import InvalidSourceException @@ -73,6 +74,8 @@ from metadata.ingestion.source.database.incremental_metadata_extraction import ( from metadata.ingestion.source.database.multi_db_source import MultiDBSource from metadata.ingestion.source.database.snowflake.constants import ( DEFAULT_STREAM_COLUMNS, + SNOWFLAKE_CLASSIFICATION_DESCRIPTION, + SNOWFLAKE_TAG_DESCRIPTION, ) from metadata.ingestion.source.database.snowflake.models import ( STORED_PROC_LANGUAGE_MAP, @@ -81,7 +84,8 @@ from metadata.ingestion.source.database.snowflake.models import ( from metadata.ingestion.source.database.snowflake.queries import ( SNOWFLAKE_DESC_FUNCTION, SNOWFLAKE_DESC_STORED_PROCEDURE, - SNOWFLAKE_FETCH_ALL_TAGS, + SNOWFLAKE_FETCH_SCHEMA_TAGS, + 
def set_schema_tags_map(self, database_name: str) -> None:
    """
    Rebuild ``self.schema_tags_map`` with the schema-level tags of a database.

    The map is always emptied first. When ``includeTags`` is disabled in the
    source config nothing else happens. Otherwise ``SNOWFLAKE_FETCH_SCHEMA_TAGS``
    is executed for ``database_name`` and each returned row is grouped under
    its ``SCHEMA_NAME`` as a ``{"tag_name": ..., "tag_value": ...}`` dict.

    Any exception raised while querying or reading the rows is logged and
    swallowed; entries stored before the failure are kept.
    """
    self.schema_tags_map.clear()
    if self.source_config.includeTags:
        try:
            query = SNOWFLAKE_FETCH_SCHEMA_TAGS.format(
                database_name=database_name,
                account_usage=self.service_connection.accountUsageSchema,
            )
            for tag_row in self.engine.execute(query).all():
                # One list of tag dicts per schema, created on first sight.
                self.schema_tags_map.setdefault(tag_row.SCHEMA_NAME, []).append(
                    {"tag_name": tag_row.TAG_NAME, "tag_value": tag_row.TAG_VALUE}
                )
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Failed to fetch schema tags: {exc}")
self.set_external_location_map(new_database) + self.set_schema_tags_map(new_database) yield new_database except Exception as exc: logger.debug(traceback.format_exc()) @@ -451,11 +484,14 @@ class SnowflakeSource( def yield_tag( self, schema_name: str ) -> Iterable[Either[OMetaTagAndClassification]]: + """ + Yield tags for tables/columns and schemas. + """ if self.source_config.includeTags: result = [] try: result = self.connection.execute( - SNOWFLAKE_FETCH_ALL_TAGS.format( + SNOWFLAKE_FETCH_TABLE_TAGS.format( database_name=self.context.get().database, schema_name=schema_name, account_usage=self.service_connection.accountUsageSchema, @@ -469,7 +505,7 @@ class SnowflakeSource( f"Error fetching tags {exc}. Trying with quoted names" ) result = self.connection.execute( - SNOWFLAKE_FETCH_ALL_TAGS.format( + SNOWFLAKE_FETCH_TABLE_TAGS.format( database_name=f'"{self.context.get().database}"', schema_name=f'"{self.context.get().database_schema}"', account_usage=self.service_connection.accountUsageSchema, @@ -495,12 +531,32 @@ class SnowflakeSource( ), tags=[row[1]], classification_name=row[0], - tag_description="SNOWFLAKE TAG VALUE", - classification_description="SNOWFLAKE TAG NAME", + tag_description=SNOWFLAKE_TAG_DESCRIPTION, + classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION, metadata=self.metadata, system_tags=True, ) + # Yield schema-level tags + if schema_name in self.schema_tags_map: + schema_fqn = fqn.build( + self.metadata, + entity_type=DatabaseSchema, + service_name=self.context.get().database_service, + database_name=self.context.get().database, + schema_name=schema_name, + ) + for tag_info in self.schema_tags_map[schema_name]: + yield from get_ometa_tag_and_classification( + tag_fqn=FullyQualifiedEntityName(schema_fqn), + tags=[tag_info["tag_value"]], + classification_name=tag_info["tag_name"], + tag_description=SNOWFLAKE_TAG_DESCRIPTION, + classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION, + metadata=self.metadata, + 
def _get_snowflake_schema_tag_labels(self, schema_name: str) -> List[TagLabel]:
    """
    Build the tag labels for the Snowflake tags stored for ``schema_name``.

    Reads the entries of ``self.schema_tags_map`` for the schema and resolves
    each one through ``get_tag_label`` (tag value as tag name, tag name as
    classification name). Falsy results and labels already collected are
    skipped, so the returned list has no duplicates. Returns an empty list
    when the schema has no entry in the map.
    """
    labels: List[TagLabel] = []
    for tag_info in self.schema_tags_map.get(schema_name, []):
        tag_label = get_tag_label(
            metadata=self.metadata,
            tag_name=tag_info["tag_value"],
            classification_name=tag_info["tag_name"],
        )
        if tag_label and tag_label not in labels:
            labels.append(tag_label)
    return labels


def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]:
    """
    Return the tag labels for a schema entity.

    Snowflake schema-level labels come first, followed by any label returned
    by the parent implementation that is not already present. Returns
    ``None`` when there are no labels at all.
    """
    schema_tags = self._get_snowflake_schema_tag_labels(schema_name)

    for tag in super().get_schema_tag_labels(schema_name) or []:
        if tag not in schema_tags:
            schema_tags.append(tag)

    return schema_tags or None


def get_tag_labels(self, table_name: str) -> Optional[List[TagLabel]]:
    """
    Return the tag labels for a table, including schema-level tags.

    The result combines:
    1. The labels returned by the parent implementation for ``table_name``.
    2. The Snowflake labels of the schema currently set in the context that
       are not already in that list.

    Returns ``None`` when there are no labels at all.
    """
    # Copy so the list handed back by the parent is never mutated in place.
    table_tags = list(super().get_tag_labels(table_name) or [])

    schema_name = self.context.get().database_schema
    if schema_name:
        for tag_label in self._get_snowflake_schema_tag_labels(schema_name):
            if tag_label not in table_tags:
                table_tags.append(tag_label)

    return table_tags or None
@patch(
    "metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels"
)
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
def test_schema_tag_inheritance(
    self, mock_get_tag_label, mock_parent_get_tag_labels
):
    """
    Check that schema tags are stored, exposed as schema labels and
    added to the labels of the tables in that schema.
    """

    def automated_label(tag_fqn):
        # Every label in this test differs only by its FQN.
        return TagLabel(
            tagFQN=tag_fqn,
            labelType=LabelType.Automated,
            state=State.Suggested,
            source=TagSource.Classification,
        )

    for source in self.sources.values():
        # 1. Rows returned by the schema tag query end up in the map.
        tag_row = Mock(
            SCHEMA_NAME="TEST_SCHEMA", TAG_NAME="SCHEMA_TAG", TAG_VALUE="VALUE"
        )
        query_result = Mock()
        query_result.all.return_value = [tag_row]
        source.engine.execute = Mock(return_value=query_result)

        source.set_schema_tags_map("TEST_DATABASE")
        stored_tags = source.schema_tags_map["TEST_SCHEMA"]
        self.assertEqual(len(stored_tags), 1)
        self.assertEqual(
            stored_tags[0],
            {"tag_name": "SCHEMA_TAG", "tag_value": "VALUE"},
        )

        # 2. The stored tag is exposed as a schema-level label.
        mock_get_tag_label.return_value = automated_label("SnowflakeTag.SCHEMA_TAG")

        schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
        self.assertIsNotNone(schema_labels)
        self.assertEqual(len(schema_labels), 1)

        # 3. Tables in the schema get their own label plus the schema label.
        source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
        mock_parent_get_tag_labels.return_value = [
            automated_label("SnowflakeTag.TABLE_TAG")
        ]

        table_labels = source.get_tag_labels(table_name="TEST_TABLE")
        self.assertEqual(len(table_labels), 2)
        tag_fqns = [label.tagFQN.root for label in table_labels]
        for expected_fqn in ("SnowflakeTag.SCHEMA_TAG", "SnowflakeTag.TABLE_TAG"):
            self.assertIn(expected_fqn, tag_fqns)