# Mirror of https://github.com/datahub-project/datahub.git (synced 2025-11-04 04:39:10 +00:00)
			
		
		
		
	
		
			
				
	
	
		
			282 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			282 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import json
 | 
						|
import unittest
 | 
						|
from typing import Any, List, Optional
 | 
						|
 | 
						|
from datahub.emitter.aspect import JSON_CONTENT_TYPE
 | 
						|
from datahub.emitter.mcp import MetadataChangeProposalWrapper
 | 
						|
from datahub.ingestion.api.common import EndOfStream, PipelineContext, RecordEnvelope
 | 
						|
from datahub.ingestion.transformer.generic_aspect_transformer import (
 | 
						|
    GenericAspectTransformer,
 | 
						|
)
 | 
						|
from datahub.metadata.schema_classes import (
 | 
						|
    DataJobSnapshotClass,
 | 
						|
    DatasetSnapshotClass,
 | 
						|
    GenericAspectClass,
 | 
						|
    MetadataChangeEventClass,
 | 
						|
    MetadataChangeProposalClass,
 | 
						|
    StatusClass,
 | 
						|
)
 | 
						|
from datahub.utilities.urns.urn import Urn
 | 
						|
 | 
						|
 | 
						|
def make_mce_dataset(
    entity_urn: str = "urn:li:dataset:(urn:li:dataPlatform:bigquery,example1,PROD)",
    aspects: Optional[List[Any]] = None,
) -> MetadataChangeEventClass:
    """Build a MetadataChangeEvent wrapping a dataset snapshot.

    Args:
        entity_urn: URN assigned to the dataset snapshot.
        aspects: Aspects placed on the snapshot. When ``None``, a single
            ``StatusClass(removed=False)`` is used.

    Returns:
        An MCE whose ``proposedSnapshot`` is a ``DatasetSnapshotClass``.
    """
    snapshot_aspects = [StatusClass(removed=False)] if aspects is None else aspects
    snapshot = DatasetSnapshotClass(urn=entity_urn, aspects=snapshot_aspects)
    return MetadataChangeEventClass(proposedSnapshot=snapshot)
 | 
						|
 | 
						|
 | 
						|
def make_mce_datajob(
    entity_urn: str = "urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)",
    aspects: Optional[List[Any]] = None,
) -> MetadataChangeEventClass:
    """Build a MetadataChangeEvent wrapping a data job snapshot.

    Args:
        entity_urn: URN assigned to the data job snapshot.
        aspects: Aspects placed on the snapshot. When ``None``, a single
            ``StatusClass(removed=False)`` is used.

    Returns:
        An MCE whose ``proposedSnapshot`` is a ``DataJobSnapshotClass``.
    """
    snapshot_aspects = [StatusClass(removed=False)] if aspects is None else aspects
    snapshot = DataJobSnapshotClass(urn=entity_urn, aspects=snapshot_aspects)
    return MetadataChangeEventClass(proposedSnapshot=snapshot)
 | 
						|
 | 
						|
 | 
						|
def make_status_mcpw(
    entity_urn: str = "urn:li:dataset:(urn:li:dataPlatform:bigquery,example1,PROD)",
    aspect: Optional[Any] = None,
) -> MetadataChangeProposalWrapper:
    """Build a MetadataChangeProposalWrapper for ``entity_urn``.

    Args:
        entity_urn: URN of the entity the proposal targets.
        aspect: Aspect carried by the proposal. When omitted (``None``), a new
            ``StatusClass(removed=False)`` is created for this call.

    Returns:
        The wrapper built from ``entity_urn`` and the chosen aspect.
    """
    # Create the default inside the function: a default evaluated in the
    # signature would be one shared, mutable instance reused by every call.
    if aspect is None:
        aspect = StatusClass(removed=False)
    return MetadataChangeProposalWrapper(
        entityUrn=entity_urn,
        aspect=aspect,
    )
 | 
						|
 | 
						|
 | 
						|
def make_mcpc(
    entity_urn: str = "urn:li:dataset:(urn:li:dataPlatform:bigquery,example1,PROD)",
    aspect_name: str = "status",
    aspect: Optional[Any] = None,
) -> MetadataChangeProposalClass:
    """Build an UPSERT MetadataChangeProposalClass for ``entity_urn``.

    Args:
        entity_urn: URN of the entity the proposal targets; its entity type is
            derived by parsing this URN.
        aspect_name: Name recorded as the proposal's ``aspectName``.
        aspect: Aspect carried by the proposal. When omitted (``None``), a new
            ``StatusClass(removed=False)`` is created for this call.

    Returns:
        The proposal with ``changeType`` set to ``"UPSERT"``.
    """
    # Create the default inside the function: a default evaluated in the
    # signature would be one shared, mutable instance reused by every call.
    if aspect is None:
        aspect = StatusClass(removed=False)
    return MetadataChangeProposalClass(
        entityUrn=entity_urn,
        entityType=Urn.from_string(entity_urn).entity_type,
        aspectName=aspect_name,
        changeType="UPSERT",
        aspect=aspect,
    )
 | 
						|
 | 
						|
 | 
						|
class DummyGenericAspectTransformer(GenericAspectTransformer):
    """Test transformer that declares entity type ``dataset`` and aspect ``customAspect``.

    ``transform_generic_aspect`` keeps the value of an aspect it is handed and
    otherwise produces the JSON payload ``{"customAspect": 10}``.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def create(
        cls, config_dict: dict, ctx: PipelineContext
    ) -> "DummyGenericAspectTransformer":
        """Return a new instance; ``config_dict`` and ``ctx`` are not consulted."""
        return cls()

    def entity_types(self) -> List[str]:
        """Return the entity types this transformer declares."""
        return ["dataset"]

    def aspect_name(self) -> str:
        """Return the name of the aspect this transformer declares."""
        return "customAspect"

    def transform_generic_aspect(
        self, entity_urn: str, aspect_name: str, aspect: Optional[GenericAspectClass]
    ) -> Optional[GenericAspectClass]:
        """Return a JSON generic aspect, reusing the incoming value when present.

        A falsy ``aspect`` results in the encoded default payload
        ``{"customAspect": 10}``.
        """
        if aspect:
            payload = aspect.value
        else:
            payload = json.dumps({"customAspect": 10}).encode("utf-8")
        return GenericAspectClass(contentType=JSON_CONTENT_TYPE, value=payload)
 | 
						|
 | 
						|
 | 
						|
class TestDummyGenericAspectTransformer(unittest.TestCase):
    """Checks ``DummyGenericAspectTransformer`` output for MCE, MCPW and MCPC inputs."""

    @staticmethod
    def _transform(records):
        """Wrap ``records`` in envelopes, run the transformer and collect its output."""
        transformer = DummyGenericAspectTransformer()
        envelopes = [RecordEnvelope(record, metadata={}) for record in records]
        return list(transformer.transform(envelopes))

    @staticmethod
    def _assert_custom_aspect(event, expected_urn, expected_value):
        """Assert ``event`` is an MCP for ``expected_urn`` holding the given customAspect."""
        assert event.entityUrn == expected_urn
        assert isinstance(event, MetadataChangeProposalClass)
        assert isinstance(event.aspect, GenericAspectClass)
        assert (json.loads(event.aspect.value))["customAspect"] == expected_value

    def test_add_generic_aspect_when_mce_received(self):
        dataset_event = make_mce_dataset()
        datajob_event = make_mce_datajob()
        records = [dataset_event, datajob_event, EndOfStream()]
        results = self._transform(records)

        # One extra record is expected: the generated aspect for the dataset.
        assert len(results) == len(records) + 1
        # The dataset MCE comes out as it went in.
        assert records[0] == results[0].record
        # So does the data job MCE.
        assert records[1] == results[1].record
        # The dataset receives a generated customAspect proposal.
        self._assert_custom_aspect(
            results[2].record, dataset_event.proposedSnapshot.urn, 10
        )
        # EndOfStream closes the output.
        assert records[2] == results[3].record

    def test_add_generic_aspect_when_mcpw_received(self):
        dataset_proposal = make_status_mcpw()
        datajob_proposal = make_status_mcpw(
            entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)"
        )
        records = [dataset_proposal, datajob_proposal, EndOfStream()]
        results = self._transform(records)

        # One extra record is expected: the generated aspect for the dataset.
        assert len(results) == len(records) + 1
        # The dataset proposal comes out as it went in.
        assert records[0] == results[0].record
        # So does the data job proposal.
        assert records[1] == results[1].record
        # The dataset receives a generated customAspect proposal.
        self._assert_custom_aspect(results[2].record, dataset_proposal.entityUrn, 10)
        # EndOfStream closes the output.
        assert records[2] == results[3].record

    def test_add_generic_aspect_when_mcpc_received(self):
        dataset_proposal = make_mcpc()
        datajob_proposal = make_mcpc(
            entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)"
        )
        records = [dataset_proposal, datajob_proposal, EndOfStream()]
        results = self._transform(records)

        # One extra record is expected: the generated aspect for the dataset.
        assert len(results) == len(records) + 1
        # The dataset proposal comes out as it went in.
        assert records[0] == results[0].record
        # So does the data job proposal.
        assert records[1] == results[1].record
        # The dataset receives a generated customAspect proposal.
        self._assert_custom_aspect(results[2].record, dataset_proposal.entityUrn, 10)
        # EndOfStream closes the output.
        assert records[2] == results[3].record

    def test_modify_generic_aspect_when_mcpc_received(self):
        plain_dataset = make_mcpc()
        dataset_with_aspect = make_mcpc(
            entity_urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,example1,PROD)",
            aspect_name="customAspect",
            aspect=GenericAspectClass(
                contentType=JSON_CONTENT_TYPE,
                value=json.dumps({"customAspect": 5}).encode("utf-8"),
            ),
        )
        datajob_proposal = make_mcpc(
            entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)"
        )
        records = [
            plain_dataset,
            dataset_with_aspect,
            datajob_proposal,
            EndOfStream(),
        ]
        results = self._transform(records)

        # One extra record is expected: the generated aspect for the plain dataset.
        assert len(results) == len(records) + 1
        # The plain dataset proposal comes out as it went in.
        assert records[0] == results[0].record
        # The dataset that already had customAspect keeps its original value.
        self._assert_custom_aspect(results[1].record, dataset_with_aspect.entityUrn, 5)
        # The data job proposal comes out as it went in.
        assert records[2] == results[2].record
        # The plain dataset receives a generated customAspect proposal.
        self._assert_custom_aspect(results[3].record, plain_dataset.entityUrn, 10)
        # EndOfStream closes the output.
        assert records[3] == results[4].record
 | 
						|
 | 
						|
 | 
						|
class DummyRemoveGenericAspectTransformer(GenericAspectTransformer):
    """Test transformer whose ``transform_generic_aspect`` always returns ``None``.

    It declares entity type ``dataset`` and aspect name ``customAspect``.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def create(
        cls, config_dict: dict, ctx: PipelineContext
    ) -> "DummyRemoveGenericAspectTransformer":
        """Return a new instance; ``config_dict`` and ``ctx`` are not consulted."""
        return cls()

    def entity_types(self) -> List[str]:
        """Return the entity types this transformer declares."""
        return ["dataset"]

    def aspect_name(self) -> str:
        """Return the name of the aspect this transformer declares."""
        return "customAspect"

    def transform_generic_aspect(
        self, entity_urn: str, aspect_name: str, aspect: Optional[GenericAspectClass]
    ) -> Optional[GenericAspectClass]:
        """Return ``None`` regardless of the incoming aspect."""
        return None
 | 
						|
 | 
						|
 | 
						|
class TestDummyRemoveGenericAspectTransformer(unittest.TestCase):
    """Checks that ``DummyRemoveGenericAspectTransformer`` drops customAspect proposals."""

    def test_remove_generic_aspect_when_mcpc_received(self):
        plain_dataset = make_mcpc()
        custom_payload = json.dumps({"customAspect": 5}).encode("utf-8")
        dataset_with_aspect = make_mcpc(
            entity_urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,example1,PROD)",
            aspect_name="customAspect",
            aspect=GenericAspectClass(
                contentType=JSON_CONTENT_TYPE,
                value=custom_payload,
            ),
        )
        datajob_proposal = make_mcpc(
            entity_urn="urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)"
        )
        records = [
            plain_dataset,
            dataset_with_aspect,
            datajob_proposal,
            EndOfStream(),
        ]
        transformer = DummyRemoveGenericAspectTransformer()
        envelopes = [RecordEnvelope(record, metadata={}) for record in records]
        results = list(transformer.transform(envelopes))

        # The proposal carrying customAspect must be gone, so one record fewer.
        assert len(results) == len(records) - 1
        # The plain dataset proposal comes out as it went in.
        assert records[0] == results[0].record
        # The data job proposal moves up into the removed record's slot.
        assert records[2] == results[1].record
        # EndOfStream closes the output.
        assert records[3] == results[2].record
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()
 |