feat: RelationshipChangeEvent model + attribution action graph + kafka msk iam (all from SaaS) (#14938)

Sergio Gómez Villamor 2025-10-07 11:10:05 +02:00 committed by GitHub
parent 5d007f04c4
commit 335290dfec
11 changed files with 192 additions and 37 deletions

View File

@@ -14,6 +14,7 @@
 import json
 import logging
+import time
 import urllib.parse
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
@@ -22,6 +23,7 @@ from datahub.configuration.common import OperationalError
 from datahub.ingestion.graph.client import DataHubGraph
 from datahub.metadata.schema_classes import (
     GlossaryTermAssociationClass,
+    MetadataAttributionClass,
     TagAssociationClass,
 )
 from datahub.specific.dataset import DatasetPatchBuilder
@@ -250,20 +252,57 @@ query listIngestionSources($input: ListIngestionSourcesInput!, $execution_start:
             return target_urn in entities
         return False

+    def _create_attribution_from_context(
+        self, context: Optional[Dict]
+    ) -> Optional[MetadataAttributionClass]:
+        """Create MetadataAttributionClass from context if action source is present."""
+        if not context:
+            return None
+
+        # Extract action source from context if present
+        action_source = context.get("propagation_source") or context.get("source")
+        if not action_source:
+            return None
+
+        return MetadataAttributionClass(
+            source=action_source,
+            time=int(time.time() * 1000.0),
+            actor=context.get("actor", "urn:li:corpuser:__datahub_system"),
+            sourceDetail=context,
+        )
+
     def add_tags_to_dataset(
         self,
         entity_urn: str,
         dataset_tags: List[str],
         field_tags: Optional[Dict] = None,
         context: Optional[Dict] = None,
+        action_urn: Optional[str] = None,
     ) -> None:
         if field_tags is None:
             field_tags = {}
+
+        # Create attribution - prefer action_urn parameter, fallback to context
+        attribution = None
+        if action_urn:
+            attribution = MetadataAttributionClass(
+                source=action_urn,
+                time=int(time.time() * 1000.0),
+                actor=context.get("actor", "urn:li:corpuser:__datahub_system")
+                if context
+                else "urn:li:corpuser:__datahub_system",
+                sourceDetail=context if context else {},
+            )
+        else:
+            attribution = self._create_attribution_from_context(context)
+
         dataset = DatasetPatchBuilder(entity_urn)
         for t in dataset_tags:
             dataset.add_tag(
                 tag=TagAssociationClass(
-                    tag=t, context=json.dumps(context) if context else None
+                    tag=t,
+                    context=json.dumps(context) if context else None,
+                    attribution=attribution,
                 )
             )
@@ -272,7 +311,9 @@ query listIngestionSources($input: ListIngestionSourcesInput!, $execution_start:
             for tag in tags:
                 field_builder.add_tag(
                     tag=TagAssociationClass(
-                        tag=tag, context=json.dumps(context) if context else None
+                        tag=tag,
+                        context=json.dumps(context) if context else None,
+                        attribution=attribution,
                     )
                 )
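In short, the new `action_urn` parameter takes precedence over any attribution derived from `context`. A minimal usage sketch; the `graph` instance and the action URN value below are illustrative assumptions, not part of this change:

# Hypothetical caller; `graph` is assumed to be an instance of the patched helper class.
graph.add_tags_to_dataset(
    entity_urn="urn:li:dataset:(urn:li:dataPlatform:hive,fct_orders,PROD)",
    dataset_tags=["urn:li:tag:pii"],
    field_tags={"customer_email": ["urn:li:tag:pii"]},
    context={
        "actor": "urn:li:corpuser:__datahub_system",
        "propagation_source": "urn:li:action:tag-propagation",  # illustrative URN
    },
    action_urn="urn:li:action:tag-propagation",  # wins over context-derived attribution
)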

View File

@@ -18,6 +18,7 @@ from datahub.ingestion.api.registry import PluginRegistry
 from datahub.metadata.schema_classes import (
     EntityChangeEventClass,
     MetadataChangeLogClass,
+    RelationshipChangeEventClass,
 )
 from datahub_actions.event.event import Event
@@ -80,10 +81,35 @@ class EntityChangeEvent(EntityChangeEventClass, Event):
             json_obj["parameters"] = self._inner_dict["__parameters_json"]
         return json.dumps(json_obj)

+    @property
+    def safe_parameters(self) -> dict:
+        return self.parameters or self.get("__parameters_json") or {}  # type: ignore
+
+
+class RelationshipChangeEvent(RelationshipChangeEventClass, Event):
+    @classmethod
+    def from_class(
+        cls, clazz: RelationshipChangeEventClass
+    ) -> "RelationshipChangeEvent":
+        instance = cls._construct({})
+        instance._restore_defaults()
+        # Shallow map inner dictionaries.
+        instance._inner_dict = clazz._inner_dict
+        return instance
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "Event":
+        json_obj = json.loads(json_str)
+        return cls.from_class(cls.from_obj(json_obj))
+
+    def as_json(self) -> str:
+        return json.dumps(self.to_obj())
+
+
 # Standard Event Types for easy reference.
 ENTITY_CHANGE_EVENT_V1_TYPE = "EntityChangeEvent_v1"
 METADATA_CHANGE_LOG_EVENT_V1_TYPE = "MetadataChangeLogEvent_v1"
+RELATIONSHIP_CHANGE_EVENT_V1_TYPE = "RelationshipChangeEvent_v1"

 # Lightweight Event Registry
 event_registry = PluginRegistry[Event]()
@@ -91,3 +117,4 @@ event_registry = PluginRegistry[Event]()
 # Register standard event library. Each type can be considered a separate "stream" / "topic"
 event_registry.register(METADATA_CHANGE_LOG_EVENT_V1_TYPE, MetadataChangeLogEvent)
 event_registry.register(ENTITY_CHANGE_EVENT_V1_TYPE, EntityChangeEvent)
+event_registry.register(RELATIONSHIP_CHANGE_EVENT_V1_TYPE, RelationshipChangeEvent)
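A short sketch of how an action could consume the newly registered event type from an EventEnvelope. The envelope and event attribute names follow the classes above and the PDL schema added later in this commit; the `handle_event` function itself is illustrative:

from datahub_actions.event.event_envelope import EventEnvelope
from datahub_actions.event.event_registry import (
    RELATIONSHIP_CHANGE_EVENT_V1_TYPE,
    RelationshipChangeEvent,
)


def handle_event(envelope: EventEnvelope) -> None:
    # Route on the registered event type string.
    if envelope.event_type == RELATIONSHIP_CHANGE_EVENT_V1_TYPE:
        event = envelope.event
        assert isinstance(event, RelationshipChangeEvent)
        print(
            event.operation,
            event.relationshipType,
            event.sourceUrn,
            event.destinationUrn,
        )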

View File

@@ -2,3 +2,4 @@ PLATFORM_EVENT_TOPIC_NAME = "PlatformEvent_v1"
 METADATA_CHANGE_LOG_VERSIONED_TOPIC_NAME = "MetadataChangeLog_Versioned_v1"
 METADATA_CHANGE_LOG_TIMESERIES_TOPIC_NAME = "MetadataChangeLog_Timeseries_v1"
 ENTITY_CHANGE_EVENT_NAME = "entityChangeEvent"
+RELATIONSHIP_CHANGE_EVENT_NAME = "relationshipChangeEvent"

View File

@@ -14,8 +14,10 @@ from datahub_actions.event.event_envelope import EventEnvelope
 from datahub_actions.event.event_registry import (
     ENTITY_CHANGE_EVENT_V1_TYPE,
     METADATA_CHANGE_LOG_EVENT_V1_TYPE,
+    RELATIONSHIP_CHANGE_EVENT_V1_TYPE,
     EntityChangeEvent,
     MetadataChangeLogEvent,
+    RelationshipChangeEvent,
 )

 # May or may not need these.
@@ -25,6 +27,7 @@ from datahub_actions.plugin.source.acryl.constants import (
     METADATA_CHANGE_LOG_TIMESERIES_TOPIC_NAME,
     METADATA_CHANGE_LOG_VERSIONED_TOPIC_NAME,
     PLATFORM_EVENT_TOPIC_NAME,
+    RELATIONSHIP_CHANGE_EVENT_NAME,
 )
 from datahub_actions.plugin.source.acryl.datahub_cloud_events_ack_manager import (
     AckManager,
@@ -261,8 +264,11 @@ class DataHubEventSource(EventSource):
                 post_json_transform(value["payload"])
             )
             if ENTITY_CHANGE_EVENT_NAME == value["name"]:
-                event = build_entity_change_event(payload)
-                yield EventEnvelope(ENTITY_CHANGE_EVENT_V1_TYPE, event, {})
+                ece = build_entity_change_event(payload)
+                yield EventEnvelope(ENTITY_CHANGE_EVENT_V1_TYPE, ece, {})
+            elif RELATIONSHIP_CHANGE_EVENT_NAME == value["name"]:
+                rce = RelationshipChangeEvent.from_json(payload.get("value"))
+                yield EventEnvelope(RELATIONSHIP_CHANGE_EVENT_V1_TYPE, rce, {})

     @staticmethod
     def handle_mcl(msg: ExternalEvent) -> Iterable[EventEnvelope]:

View File

@@ -33,8 +33,10 @@ from datahub_actions.event.event_envelope import EventEnvelope
 from datahub_actions.event.event_registry import (
     ENTITY_CHANGE_EVENT_V1_TYPE,
     METADATA_CHANGE_LOG_EVENT_V1_TYPE,
+    RELATIONSHIP_CHANGE_EVENT_V1_TYPE,
     EntityChangeEvent,
     MetadataChangeLogEvent,
+    RelationshipChangeEvent,
 )

 # May or may not need these.
@@ -46,6 +48,7 @@ logger = logging.getLogger(__name__)
 ENTITY_CHANGE_EVENT_NAME = "entityChangeEvent"
+RELATIONSHIP_CHANGE_EVENT_NAME = "relationshipChangeEvent"

 DEFAULT_TOPIC_ROUTES = {
     "mcl": "MetadataChangeLog_Versioned_v1",
     "mcl_timeseries": "MetadataChangeLog_Timeseries_v1",
@@ -216,9 +219,13 @@ class KafkaEventSource(EventSource):
                 post_json_transform(value["payload"])
             )
             if ENTITY_CHANGE_EVENT_NAME == value["name"]:
-                event = build_entity_change_event(payload)
+                ece = build_entity_change_event(payload)
                 kafka_meta = build_kafka_meta(msg)
-                yield EventEnvelope(ENTITY_CHANGE_EVENT_V1_TYPE, event, kafka_meta)
+                yield EventEnvelope(ENTITY_CHANGE_EVENT_V1_TYPE, ece, kafka_meta)
+            elif RELATIONSHIP_CHANGE_EVENT_NAME == value["name"]:
+                rce = RelationshipChangeEvent.from_json(payload.get("value"))
+                kafka_meta = build_kafka_meta(msg)
+                yield EventEnvelope(RELATIONSHIP_CHANGE_EVENT_V1_TYPE, rce, kafka_meta)

     def close(self) -> None:
         if self.consumer:
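Both event sources route on the `name` field of the incoming PlatformEvent and deserialize the inner JSON string under `payload.value`. An illustrative message shape, inferred from the handling code above; field values are placeholders:

import json

platform_event = {
    "name": "relationshipChangeEvent",
    "payload": {
        "contentType": "application/json",
        # The inner value is itself a JSON string; RelationshipChangeEvent.from_json parses it.
        "value": json.dumps(
            {
                "sourceUrn": "urn:li:dataset:(urn:li:dataPlatform:hive,upstream,PROD)",
                "destinationUrn": "urn:li:dataset:(urn:li:dataPlatform:hive,downstream,PROD)",
                "operation": "ADD",
                "relationshipType": "DownstreamOf",
                "auditStamp": {
                    "time": 1700000000000,
                    "actor": "urn:li:corpuser:__datahub_system",
                },
            }
        ),
    },
}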

View File

@@ -1,13 +1,14 @@
 """Module for AWS MSK IAM authentication."""

 import logging
+import os

-from aws_msk_iam_sasl_signer_python.msk_iam_sasl_signer import MSKAuthTokenProvider
+from aws_msk_iam_sasl_signer import MSKAuthTokenProvider

 logger = logging.getLogger(__name__)


-def oauth_cb(oauth_config):
+def oauth_cb(oauth_config: dict) -> tuple[str, float]:
     """
     OAuth callback function for AWS MSK IAM authentication.
@@ -15,12 +16,17 @@ def oauth_cb(oauth_config):
     for authentication with AWS MSK using IAM.

     Returns:
-        tuple: (auth_token, expiry_time_seconds)
+        tuple[str, float]: (auth_token, expiry_time_seconds)
     """
     try:
-        auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token()
+        region = (
+            os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") or "us-east-1"
+        )
+        auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token(region=region)
         # Convert expiry from milliseconds to seconds as required by Kafka client
-        return auth_token, expiry_ms / 1000
+        return auth_token, float(expiry_ms) / 1000
     except Exception as e:
-        logger.error(f"Error generating AWS MSK IAM authentication token: {e}")
+        logger.error(
+            f"Error generating AWS MSK IAM authentication token: {e}", exc_info=True
+        )
         raise
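The callback plugs into confluent-kafka's SASL/OAUTHBEARER support via the `oauth_cb` config property. A minimal consumer sketch; the bootstrap address, group id, and topic below are placeholders:

from confluent_kafka import Consumer

from datahub_actions.utils.kafka_msk_iam import oauth_cb

consumer = Consumer(
    {
        "bootstrap.servers": "b-1.example-msk.us-east-1.amazonaws.com:9098",  # placeholder
        "group.id": "datahub-actions",
        "security.protocol": "SASL_SSL",
        "sasl.mechanisms": "OAUTHBEARER",
        "oauth_cb": oauth_cb,  # librdkafka invokes this to fetch/refresh the IAM token
    }
)
consumer.subscribe(["PlatformEvent_v1"])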

View File

@@ -7,46 +7,36 @@ from typing import Any, cast
 import pytest

 MODULE_UNDER_TEST = "datahub_actions.utils.kafka_msk_iam"
-VENDOR_MODULE = "aws_msk_iam_sasl_signer_python.msk_iam_sasl_signer"
+VENDOR_MODULE = "aws_msk_iam_sasl_signer"


-def ensure_fake_vendor(monkeypatch):
+def ensure_fake_vendor(monkeypatch: Any) -> Any:
     """
     Ensure a fake MSKAuthTokenProvider is available at import path
-    aws_msk_iam_sasl_signer_python.msk_iam_sasl_signer for environments
-    where the vendor package is not installed.
+    aws_msk_iam_sasl_signer for environments where the vendor package is not installed.

     Returns the fake module so tests can monkeypatch its behavior.
     """
     # If already present (package installed), just return the real module
     if VENDOR_MODULE in sys.modules:
         return sys.modules[VENDOR_MODULE]

-    # Build parent package structure: aws_msk_iam_sasl_signer_python.msk_iam_sasl_signer
-    parent_name = "aws_msk_iam_sasl_signer_python"
-    if parent_name not in sys.modules:
-        parent: Any = types.ModuleType(parent_name)
-        parent.__path__ = []  # mark as package
-        monkeypatch.setitem(sys.modules, parent_name, parent)
-    else:
-        parent = cast(Any, sys.modules[parent_name])
-
+    # Create a minimal fake module matching the direct import path
     fake_mod: Any = types.ModuleType(VENDOR_MODULE)

     class MSKAuthTokenProvider:
         @staticmethod
-        def generate_auth_token():  # will be monkeypatched per test
+        def generate_auth_token(
+            region: str | None = None,
+        ) -> None:  # will be monkeypatched per test
             raise NotImplementedError

     fake_mod.MSKAuthTokenProvider = MSKAuthTokenProvider
     monkeypatch.setitem(sys.modules, VENDOR_MODULE, fake_mod)
-    # Also ensure attribute exists on parent to allow from ... import ...
-    parent.msk_iam_sasl_signer = fake_mod
     return fake_mod


-def import_sut(monkeypatch):
+def import_sut(monkeypatch: Any) -> Any:
     """Import or reload the module under test after ensuring the vendor symbol exists."""
     ensure_fake_vendor(monkeypatch)
     if MODULE_UNDER_TEST in sys.modules:
@@ -54,13 +44,13 @@ def import_sut(monkeypatch):
     return importlib.import_module(MODULE_UNDER_TEST)


-def test_oauth_cb_success_converts_ms_to_seconds(monkeypatch):
+def test_oauth_cb_success_converts_ms_to_seconds(monkeypatch: Any) -> None:
     sut = import_sut(monkeypatch)

     # Monkeypatch the provider to return a known token and expiry in ms
     provider = cast(Any, sut).MSKAuthTokenProvider

-    def fake_generate():
+    def fake_generate(region: str | None = None) -> tuple[str, int]:
         return "my-token", 12_345  # ms

     monkeypatch.setattr(provider, "generate_auth_token", staticmethod(fake_generate))
@@ -71,10 +61,10 @@ def test_oauth_cb_success_converts_ms_to_seconds(monkeypatch):
     assert expiry_seconds == 12.345  # ms to seconds via division


-def test_oauth_cb_raises_and_logs_on_error(monkeypatch, caplog):
+def test_oauth_cb_raises_and_logs_on_error(monkeypatch: Any, caplog: Any) -> None:
     sut = import_sut(monkeypatch)

-    def boom():
+    def boom(region: str | None = None) -> None:
         raise RuntimeError("signer blew up")

     provider = cast(Any, sut).MSKAuthTokenProvider
@@ -93,14 +83,14 @@ def test_oauth_cb_raises_and_logs_on_error(monkeypatch, caplog):
     )


-def test_oauth_cb_returns_tuple_types(monkeypatch):
+def test_oauth_cb_returns_tuple_types(monkeypatch: Any) -> None:
     sut = import_sut(monkeypatch)
     provider = cast(Any, sut).MSKAuthTokenProvider

     monkeypatch.setattr(
         provider,
         "generate_auth_token",
-        staticmethod(lambda: ("tkn", 1_000)),  # 1000 ms
+        staticmethod(lambda region=None: ("tkn", 1_000)),  # 1000 ms
     )

     result = sut.oauth_cb(None)

View File

@@ -45,6 +45,7 @@ def load_schemas(schemas_path: str) -> Dict[str, dict]:
     "mxe/MetadataChangeLog.avsc",
     "mxe/PlatformEvent.avsc",
     "platform/event/v1/EntityChangeEvent.avsc",
+    "platform/event/v1/RelationshipChangeEvent.avsc",
     "metadata/query/filter/Filter.avsc",  # temporarily added to test reserved keywords support
 }

View File

@@ -37,7 +37,7 @@ from datahub.ingestion.source.sql.sql_config import (
 logger = logging.getLogger(__name__)

-oracledb.version = "8.3.0"
+oracledb.version = "8.3.0"  # type: ignore[assignment]
 sys.modules["cx_Oracle"] = oracledb

 extra_oracle_types = {

View File

@@ -0,0 +1,69 @@
namespace com.linkedin.platform.event.v1
import com.linkedin.avro2pegasus.events.KafkaAuditHeader
import com.linkedin.common.AuditStamp
import com.linkedin.common.Urn
/**
* Kafka event for proposing a relationship change between two entities.
* For example, when dataset1 establishes a new downstream relationship with dataset2.
*/
@Event = {
"name": "relationshipChangeEvent"
}
record RelationshipChangeEvent {
/**
* Kafka audit header containing metadata about the message itself.
* Includes information like message ID, timestamp, and server details.
*/
auditHeader: optional KafkaAuditHeader
/**
* The URN (Uniform Resource Name) of the source entity in the relationship.
* In a downstream relationship example, this would be the URN of the upstream dataset.
*/
sourceUrn: Urn
/**
* The URN of the destination entity in the relationship.
* In a downstream relationship example, this would be the URN of the downstream dataset.
*/
destinationUrn: Urn
/**
* The operation being performed on this relationship.
* Typically includes operations like ADD, REMOVE, or RESTATE.
*/
operation: RelationshipChangeOperation
/**
* The type/category of relationship being established or modified.
* Examples: "DownstreamOf", "Contains", "OwnedBy", "DerivedFrom", etc.
*/
relationshipType: string
/**
* The system or service responsible for managing the lifecycle of this relationship.
* This helps identify which component has authority over the relationship.
*/
lifecycleOwner: optional string
/**
* Information about how or through what means this relationship was established.
* Could indicate a specific pipeline, process, or tool that discovered/created the relationship.
*/
via: optional string
/**
* Additional custom properties associated with this relationship.
* Allows for flexible extension without changing the schema.
*/
properties: optional map[string, string]
/**
* Stores information about who made this change and when.
* Contains the actor (user or system) that performed the action and the timestamp.
*/
auditStamp: AuditStamp
}

View File

@@ -0,0 +1,7 @@
namespace com.linkedin.platform.event.v1
enum RelationshipChangeOperation {
ADD
REMOVE
RESTATE
}
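For reference, a hedged sketch of building this event with the generated Python classes. Field names follow the PDL record above; of the class names, only RelationshipChangeEventClass is confirmed by the imports earlier in this commit, the rest assume the standard PDL-to-schema_classes codegen:

from datahub.metadata.schema_classes import (
    AuditStampClass,
    RelationshipChangeEventClass,
)

event = RelationshipChangeEventClass(
    sourceUrn="urn:li:dataset:(urn:li:dataPlatform:hive,upstream,PROD)",
    destinationUrn="urn:li:dataset:(urn:li:dataPlatform:hive,downstream,PROD)",
    operation="ADD",  # one of the RelationshipChangeOperation values
    relationshipType="DownstreamOf",
    auditStamp=AuditStampClass(
        time=1700000000000, actor="urn:li:corpuser:__datahub_system"
    ),
)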