2025-04-03 10:39:47 +05:30

178 lines
6.0 KiB
Python

# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Check context operations
"""
from unittest import TestCase
from metadata.generated.schema.api.classification.createClassification import (
CreateClassificationRequest,
)
from metadata.generated.schema.api.classification.createTag import CreateTagRequest
from metadata.generated.schema.api.data.createStoredProcedure import (
CreateStoredProcedureRequest,
)
from metadata.generated.schema.api.data.createTable import CreateTableRequest
from metadata.generated.schema.entity.data.storedProcedure import (
Language,
StoredProcedure,
StoredProcedureCode,
)
from metadata.generated.schema.entity.data.table import (
Column,
ColumnName,
DataType,
Table,
)
from metadata.generated.schema.type.basic import (
EntityName,
FullyQualifiedEntityName,
Markdown,
)
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.topology import NodeStage, TopologyContext
from metadata.ingestion.source.database.database_service import DatabaseServiceTopology
TABLE_STAGE = NodeStage(
type_=Table,
context="table",
processor="yield_table",
consumer=["database_service", "database", "database_schema"],
use_cache=True,
)
TAGS_STAGE = NodeStage(
type_=OMetaTagAndClassification,
context="tags",
processor="yield_table_tag_details",
nullable=True,
store_all_in_context=True,
)
PROCEDURES_STAGE = NodeStage(
type_=StoredProcedure,
context="stored_procedures",
processor="yield_stored_procedure",
consumer=["database_service", "database", "database_schema"],
store_all_in_context=True,
store_fqn=True,
use_cache=True,
)
class TopologyContextTest(TestCase):
"""Validate context ops"""
# Randomly picked up to test
db_service_topology = DatabaseServiceTopology()
def test_upsert(self):
"""We can add a new key and update its value"""
context = TopologyContext.create(self.db_service_topology)
context.upsert(key="new_key", value="something")
self.assertEqual(context.new_key, "something")
context.upsert(key="new_key", value="new")
self.assertEqual(context.new_key, "new")
def test_replace(self):
"""We can append results to a new key. If it does not exist, it'll instantiate a list"""
context = TopologyContext.create(self.db_service_topology)
context.append(key="new_key", value=1)
self.assertEqual(context.new_key, [1])
context.append(key="new_key", value=2)
self.assertEqual(context.new_key, [1, 2])
def test_clear(self):
"""We can clan up the context values"""
context = TopologyContext.create(self.db_service_topology)
my_stage = NodeStage(
type_=Table,
context="new_key",
processor="random",
)
context.append(key="new_key", value=1)
self.assertEqual(context.new_key, [1])
context.clear_stage(stage=my_stage)
self.assertIsNone(context.new_key)
def test_fqn_from_stage(self):
"""We build the right fqn at each stage"""
context = TopologyContext.create(self.db_service_topology)
context.upsert(key="database_service", value="service")
context.upsert(key="database", value="database")
context.upsert(key="database_schema", value="schema")
table_fqn = context.fqn_from_stage(stage=TABLE_STAGE, entity_name="table")
self.assertEqual(table_fqn, "service.database.schema.table")
def test_update_context_value(self):
"""We can update values directly"""
context = TopologyContext.create(self.db_service_topology)
classification_and_tag = OMetaTagAndClassification(
fqn=None,
classification_request=CreateClassificationRequest(
name=EntityName("my_classification"),
description=Markdown("something"),
),
tag_request=CreateTagRequest(
name=EntityName("my_tag"),
description=Markdown("something"),
),
)
context.update_context_value(stage=TAGS_STAGE, value=classification_and_tag)
self.assertEqual(context.tags, [classification_and_tag])
def test_update_context_name(self):
"""Check context updates for EntityName and FQN"""
context = TopologyContext.create(self.db_service_topology)
context.update_context_name(
stage=TABLE_STAGE,
right=CreateTableRequest(
name=EntityName("table"),
databaseSchema=FullyQualifiedEntityName("schema"),
columns=[Column(name=ColumnName("id"), dataType=DataType.BIGINT)],
),
)
self.assertEqual(context.table, "table")
context.upsert(key="database_service", value="service")
context.upsert(key="database", value="database")
context.upsert(key="database_schema", value="schema")
context.update_context_name(
stage=PROCEDURES_STAGE,
right=CreateStoredProcedureRequest(
name=EntityName("stored_proc"),
databaseSchema=FullyQualifiedEntityName("schema"),
storedProcedureCode=StoredProcedureCode(
language=Language.SQL,
code="SELECT * FROM AWESOME",
),
),
)
self.assertEqual(
context.stored_procedures, ["service.database.schema.stored_proc"]
)