mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2026-01-06 04:26:57 +00:00
* Add schema-level tags and tag inheritance support for snowflake * Add tests for schema tag inheritance * Lint fixes
This commit is contained in:
parent
3bd3158bee
commit
26fedbaf0e
@ -46,3 +46,6 @@ DEFAULT_STREAM_COLUMNS = [
|
||||
"primary_key": False,
|
||||
},
|
||||
]
|
||||
|
||||
SNOWFLAKE_TAG_DESCRIPTION = "SNOWFLAKE TAG VALUE"
|
||||
SNOWFLAKE_CLASSIFICATION_DESCRIPTION = "SNOWFLAKE TAG NAME"
|
||||
|
||||
@ -54,6 +54,7 @@ from metadata.generated.schema.type.basic import (
|
||||
SourceUrl,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReferenceList import EntityReferenceList
|
||||
from metadata.generated.schema.type.tagLabel import TagLabel
|
||||
from metadata.ingestion.api.delete import delete_entity_by_name
|
||||
from metadata.ingestion.api.models import Either
|
||||
from metadata.ingestion.api.steps import InvalidSourceException
|
||||
@ -73,6 +74,8 @@ from metadata.ingestion.source.database.incremental_metadata_extraction import (
|
||||
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
|
||||
from metadata.ingestion.source.database.snowflake.constants import (
|
||||
DEFAULT_STREAM_COLUMNS,
|
||||
SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
|
||||
SNOWFLAKE_TAG_DESCRIPTION,
|
||||
)
|
||||
from metadata.ingestion.source.database.snowflake.models import (
|
||||
STORED_PROC_LANGUAGE_MAP,
|
||||
@ -81,7 +84,8 @@ from metadata.ingestion.source.database.snowflake.models import (
|
||||
from metadata.ingestion.source.database.snowflake.queries import (
|
||||
SNOWFLAKE_DESC_FUNCTION,
|
||||
SNOWFLAKE_DESC_STORED_PROCEDURE,
|
||||
SNOWFLAKE_FETCH_ALL_TAGS,
|
||||
SNOWFLAKE_FETCH_SCHEMA_TAGS,
|
||||
SNOWFLAKE_FETCH_TABLE_TAGS,
|
||||
SNOWFLAKE_GET_CLUSTER_KEY,
|
||||
SNOWFLAKE_GET_CURRENT_ACCOUNT,
|
||||
SNOWFLAKE_GET_DATABASE_COMMENTS,
|
||||
@ -123,7 +127,7 @@ from metadata.utils.sqlalchemy_utils import (
|
||||
get_all_table_ddls,
|
||||
get_all_view_definitions,
|
||||
)
|
||||
from metadata.utils.tag_utils import get_ometa_tag_and_classification
|
||||
from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_label
|
||||
|
||||
|
||||
class MAP(StructuredType):
|
||||
@ -203,6 +207,7 @@ class SnowflakeSource(
|
||||
self.schema_desc_map = {}
|
||||
self.database_desc_map = {}
|
||||
self.external_location_map = {}
|
||||
self.schema_tags_map = {}
|
||||
|
||||
self._account: Optional[str] = None
|
||||
self._org_name: Optional[str] = None
|
||||
@ -299,6 +304,32 @@ class SnowflakeSource(
|
||||
for row in results
|
||||
}
|
||||
|
||||
def set_schema_tags_map(self, database_name: str) -> None:
|
||||
"""Fetch and store all schema-level tags for the current database"""
|
||||
self.schema_tags_map.clear()
|
||||
if not self.source_config.includeTags:
|
||||
return
|
||||
|
||||
try:
|
||||
results = self.engine.execute(
|
||||
SNOWFLAKE_FETCH_SCHEMA_TAGS.format(
|
||||
database_name=database_name,
|
||||
account_usage=self.service_connection.accountUsageSchema,
|
||||
)
|
||||
).all()
|
||||
|
||||
for row in results:
|
||||
schema_name = row.SCHEMA_NAME
|
||||
if schema_name not in self.schema_tags_map:
|
||||
self.schema_tags_map[schema_name] = []
|
||||
self.schema_tags_map[schema_name].append(
|
||||
{"tag_name": row.TAG_NAME, "tag_value": row.TAG_VALUE}
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.warning(f"Failed to fetch schema tags: {exc}")
|
||||
|
||||
def get_schema_description(self, schema_name: str) -> Optional[str]:
|
||||
"""
|
||||
Method to fetch the schema description
|
||||
@ -329,6 +360,7 @@ class SnowflakeSource(
|
||||
self.set_schema_description_map()
|
||||
self.set_database_description_map()
|
||||
self.set_external_location_map(configured_db)
|
||||
self.set_schema_tags_map(configured_db)
|
||||
yield configured_db
|
||||
else:
|
||||
for new_database in self.get_database_names_raw():
|
||||
@ -357,6 +389,7 @@ class SnowflakeSource(
|
||||
self.set_schema_description_map()
|
||||
self.set_database_description_map()
|
||||
self.set_external_location_map(new_database)
|
||||
self.set_schema_tags_map(new_database)
|
||||
yield new_database
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
@ -451,11 +484,14 @@ class SnowflakeSource(
|
||||
def yield_tag(
|
||||
self, schema_name: str
|
||||
) -> Iterable[Either[OMetaTagAndClassification]]:
|
||||
"""
|
||||
Yield tags for tables/columns and schemas.
|
||||
"""
|
||||
if self.source_config.includeTags:
|
||||
result = []
|
||||
try:
|
||||
result = self.connection.execute(
|
||||
SNOWFLAKE_FETCH_ALL_TAGS.format(
|
||||
SNOWFLAKE_FETCH_TABLE_TAGS.format(
|
||||
database_name=self.context.get().database,
|
||||
schema_name=schema_name,
|
||||
account_usage=self.service_connection.accountUsageSchema,
|
||||
@ -469,7 +505,7 @@ class SnowflakeSource(
|
||||
f"Error fetching tags {exc}. Trying with quoted names"
|
||||
)
|
||||
result = self.connection.execute(
|
||||
SNOWFLAKE_FETCH_ALL_TAGS.format(
|
||||
SNOWFLAKE_FETCH_TABLE_TAGS.format(
|
||||
database_name=f'"{self.context.get().database}"',
|
||||
schema_name=f'"{self.context.get().database_schema}"',
|
||||
account_usage=self.service_connection.accountUsageSchema,
|
||||
@ -495,12 +531,32 @@ class SnowflakeSource(
|
||||
),
|
||||
tags=[row[1]],
|
||||
classification_name=row[0],
|
||||
tag_description="SNOWFLAKE TAG VALUE",
|
||||
classification_description="SNOWFLAKE TAG NAME",
|
||||
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
|
||||
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
|
||||
metadata=self.metadata,
|
||||
system_tags=True,
|
||||
)
|
||||
|
||||
# Yield schema-level tags
|
||||
if schema_name in self.schema_tags_map:
|
||||
schema_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=self.context.get().database_service,
|
||||
database_name=self.context.get().database,
|
||||
schema_name=schema_name,
|
||||
)
|
||||
for tag_info in self.schema_tags_map[schema_name]:
|
||||
yield from get_ometa_tag_and_classification(
|
||||
tag_fqn=FullyQualifiedEntityName(schema_fqn),
|
||||
tags=[tag_info["tag_value"]],
|
||||
classification_name=tag_info["tag_name"],
|
||||
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
|
||||
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
|
||||
metadata=self.metadata,
|
||||
system_tags=True,
|
||||
)
|
||||
|
||||
def _get_table_names_and_types(
|
||||
self, schema_name: str, table_type: TableType = TableType.Regular
|
||||
) -> List[TableNameAndType]:
|
||||
@ -945,3 +1001,49 @@ class SnowflakeSource(
|
||||
logger.debug(
|
||||
f"Processing ownership is not supported for {self.service_connection.type.name}"
|
||||
)
|
||||
|
||||
def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]:
|
||||
"""
|
||||
Return tags for schema entity including Snowflake schema-level tags.
|
||||
"""
|
||||
schema_tags = []
|
||||
|
||||
if schema_name in self.schema_tags_map:
|
||||
for tag_info in self.schema_tags_map[schema_name]:
|
||||
tag_label = get_tag_label(
|
||||
metadata=self.metadata,
|
||||
tag_name=tag_info["tag_value"],
|
||||
classification_name=tag_info["tag_name"],
|
||||
)
|
||||
if tag_label:
|
||||
schema_tags.append(tag_label)
|
||||
|
||||
# Include parent tags from context
|
||||
parent_tags = super().get_schema_tag_labels(schema_name) or []
|
||||
for tag in parent_tags:
|
||||
if tag not in schema_tags:
|
||||
schema_tags.append(tag)
|
||||
|
||||
return schema_tags if schema_tags else None
|
||||
|
||||
def get_tag_labels(self, table_name: str) -> Optional[List[TagLabel]]:
|
||||
"""
|
||||
Override to include schema-level tags inherited by tables.
|
||||
This method combines:
|
||||
1. Tags directly assigned to the table (from parent implementation)
|
||||
2. Tags inherited from the schema level
|
||||
"""
|
||||
table_tags = super().get_tag_labels(table_name) or []
|
||||
|
||||
schema_name = self.context.get().database_schema
|
||||
if schema_name and schema_name in self.schema_tags_map:
|
||||
for tag_info in self.schema_tags_map[schema_name]:
|
||||
tag_label = get_tag_label(
|
||||
metadata=self.metadata,
|
||||
tag_name=tag_info["tag_value"],
|
||||
classification_name=tag_info["tag_name"],
|
||||
)
|
||||
if tag_label and tag_label not in table_tags:
|
||||
table_tags.append(tag_label)
|
||||
|
||||
return table_tags if table_tags else None
|
||||
|
||||
@ -37,7 +37,7 @@ SNOWFLAKE_SQL_STATEMENT = textwrap.dedent(
|
||||
|
||||
SNOWFLAKE_SESSION_TAG_QUERY = 'ALTER SESSION SET QUERY_TAG="{query_tag}"'
|
||||
|
||||
SNOWFLAKE_FETCH_ALL_TAGS = textwrap.dedent(
|
||||
SNOWFLAKE_FETCH_TABLE_TAGS = textwrap.dedent(
|
||||
"""
|
||||
select TAG_NAME, TAG_VALUE, OBJECT_DATABASE, OBJECT_SCHEMA, OBJECT_NAME, COLUMN_NAME
|
||||
from {account_usage}.tag_references
|
||||
@ -46,6 +46,18 @@ SNOWFLAKE_FETCH_ALL_TAGS = textwrap.dedent(
|
||||
"""
|
||||
)
|
||||
|
||||
SNOWFLAKE_FETCH_SCHEMA_TAGS = textwrap.dedent(
|
||||
"""
|
||||
select TAG_NAME, TAG_VALUE, OBJECT_NAME as SCHEMA_NAME
|
||||
from {account_usage}.tag_references
|
||||
where OBJECT_DATABASE = '{database_name}'
|
||||
and OBJECT_SCHEMA IS NULL
|
||||
and OBJECT_NAME IS NOT NULL
|
||||
and COLUMN_NAME IS NULL
|
||||
and DOMAIN = 'SCHEMA'
|
||||
"""
|
||||
)
|
||||
|
||||
SNOWFLAKE_GET_EXTERNAL_TABLE_NAMES = """
|
||||
select TABLE_NAME, NULL from information_schema.tables
|
||||
where TABLE_SCHEMA = '{schema}' AND TABLE_TYPE = 'EXTERNAL TABLE'
|
||||
|
||||
@ -14,7 +14,7 @@ snowflake unit tests
|
||||
"""
|
||||
# pylint: disable=line-too-long
|
||||
from unittest import TestCase
|
||||
from unittest.mock import PropertyMock, patch
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import sqlalchemy.types as sqltypes
|
||||
|
||||
@ -25,6 +25,12 @@ from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipel
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
OpenMetadataWorkflowConfig,
|
||||
)
|
||||
from metadata.generated.schema.type.tagLabel import (
|
||||
LabelType,
|
||||
State,
|
||||
TagLabel,
|
||||
TagSource,
|
||||
)
|
||||
from metadata.ingestion.source.database.snowflake.metadata import MAP, SnowflakeSource
|
||||
from metadata.ingestion.source.database.snowflake.models import SnowflakeStoredProcedure
|
||||
|
||||
@ -305,3 +311,58 @@ class SnowflakeUnitTest(TestCase):
|
||||
self.assertEqual(map_type.key_type, key_type)
|
||||
self.assertEqual(map_type.value_type, sqltypes.VARCHAR) # default
|
||||
self.assertFalse(map_type.not_null) # default
|
||||
|
||||
@patch(
|
||||
"metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels"
|
||||
)
|
||||
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
|
||||
def test_schema_tag_inheritance(
|
||||
self, mock_get_tag_label, mock_parent_get_tag_labels
|
||||
):
|
||||
"""Test schema tag inheritance"""
|
||||
for source in self.sources.values():
|
||||
# Verify tags are fetched and stored
|
||||
mock_schema_tags = [
|
||||
Mock(
|
||||
SCHEMA_NAME="TEST_SCHEMA", TAG_NAME="SCHEMA_TAG", TAG_VALUE="VALUE"
|
||||
),
|
||||
]
|
||||
mock_execute = Mock()
|
||||
mock_execute.all.return_value = mock_schema_tags
|
||||
source.engine.execute = Mock(return_value=mock_execute)
|
||||
|
||||
source.set_schema_tags_map("TEST_DATABASE")
|
||||
self.assertEqual(len(source.schema_tags_map["TEST_SCHEMA"]), 1)
|
||||
self.assertEqual(
|
||||
source.schema_tags_map["TEST_SCHEMA"][0],
|
||||
{"tag_name": "SCHEMA_TAG", "tag_value": "VALUE"},
|
||||
)
|
||||
|
||||
# Verify schema tag labels
|
||||
mock_get_tag_label.return_value = TagLabel(
|
||||
tagFQN="SnowflakeTag.SCHEMA_TAG",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
|
||||
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
|
||||
self.assertIsNotNone(schema_labels)
|
||||
self.assertEqual(len(schema_labels), 1)
|
||||
|
||||
# Verify tag inheritance
|
||||
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
|
||||
mock_parent_get_tag_labels.return_value = [
|
||||
TagLabel(
|
||||
tagFQN="SnowflakeTag.TABLE_TAG",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
]
|
||||
|
||||
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
|
||||
self.assertEqual(len(table_labels), 2)
|
||||
tag_fqns = [tag.tagFQN.root for tag in table_labels]
|
||||
self.assertIn("SnowflakeTag.SCHEMA_TAG", tag_fqns)
|
||||
self.assertIn("SnowflakeTag.TABLE_TAG", tag_fqns)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user