import logging
import time
from typing import Any, Dict, List, Optional, Union
import datahub.metadata.schema_classes as models
from datahub.api.entities.dataset.dataset import Dataset
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import (
DataProcessInstanceInput,
DataProcessInstanceOutput,
)
from datahub.metadata.schema_classes import (
ChangeTypeClass,
DataProcessInstanceRunResultClass,
DataProcessRunStatusClass,
EdgeClass,
)
from datahub.metadata.urns import (
ContainerUrn,
DataPlatformUrn,
MlModelGroupUrn,
MlModelUrn,
VersionSetUrn,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DatahubAIClient:
"""Client for creating and managing MLflow metadata in DataHub."""
def __init__(
self,
token: Optional[str] = None,
server_url: str = "http://localhost:8080",
platform: str = "mlflow",
) -> None:
"""Initialize the DataHub AI client.
Args:
token: DataHub access token
server_url: DataHub server URL (defaults to http://localhost:8080)
platform: Platform name (defaults to mlflow)
"""
self.token = token
self.server_url = server_url
self.platform = platform
        self.graph = DataHubGraph(
            DatahubClientConfig(
                server=server_url,
                token=token,
                # Only send an Authorization header when a token is provided;
                # otherwise the header would read "Bearer None".
                extra_headers=(
                    {"Authorization": f"Bearer {token}"} if token else None
                ),
            )
        )
def _create_timestamp(
self, timestamp: Optional[int] = None
) -> models.TimeStampClass:
"""Helper to create timestamp with current time if not provided"""
return models.TimeStampClass(
time=timestamp or int(time.time() * 1000), actor="urn:li:corpuser:datahub"
)
def _emit_mcps(self, mcps: Union[MetadataChangeProposalWrapper, List[Any]]) -> None:
"""Helper to emit MCPs with proper connection handling"""
if not isinstance(mcps, list):
mcps = [mcps]
with self.graph:
for mcp in mcps:
self.graph.emit(mcp)
def _get_aspect(
self, entity_urn: str, aspect_type: Any, default_constructor: Any = None
) -> Any:
"""Helper to safely get an aspect with fallback"""
try:
return self.graph.get_aspect(entity_urn=entity_urn, aspect_type=aspect_type)
except Exception as e:
logger.warning(f"Could not fetch aspect for {entity_urn}: {e}")
return default_constructor() if default_constructor else None
def _create_properties_class(
self, props_class: Any, props_dict: Optional[Dict[str, Any]] = None
) -> Any:
"""Helper to create properties class with provided values"""
if props_dict is None:
props_dict = {}
filtered_props = {k: v for k, v in props_dict.items() if v is not None}
if hasattr(props_class, "created"):
filtered_props.setdefault("created", self._create_timestamp())
if hasattr(props_class, "lastModified"):
filtered_props.setdefault("lastModified", self._create_timestamp())
return props_class(**filtered_props)
def _update_list_property(
self, existing_list: Optional[List[str]], new_item: str
) -> List[str]:
"""Helper to update a list property while maintaining uniqueness"""
items = set(existing_list if existing_list else [])
items.add(new_item)
return list(items)
def _create_mcp(
self,
entity_urn: str,
aspect: Any,
entity_type: Optional[str] = None,
aspect_name: Optional[str] = None,
change_type: str = ChangeTypeClass.UPSERT,
) -> MetadataChangeProposalWrapper:
"""Helper to create an MCP with standard parameters"""
mcp_args = {"entityUrn": entity_urn, "aspect": aspect}
if entity_type:
mcp_args["entityType"] = entity_type
if aspect_name:
mcp_args["aspectName"] = aspect_name
mcp_args["changeType"] = change_type
return MetadataChangeProposalWrapper(**mcp_args)
def _update_entity_properties(
self,
entity_urn: str,
aspect_type: Any,
updates: Dict[str, Any],
entity_type: str,
skip_properties: Optional[List[str]] = None,
) -> None:
"""Helper to update entity properties while preserving existing ones"""
existing_props = self._get_aspect(entity_urn, aspect_type, aspect_type)
skip_list = [] if skip_properties is None else skip_properties
props = self._copy_existing_properties(existing_props, skip_list) or {}
for key, value in updates.items():
if isinstance(value, str) and hasattr(existing_props, key):
existing_value = getattr(existing_props, key, [])
props[key] = self._update_list_property(existing_value, value)
else:
props[key] = value
updated_props = self._create_properties_class(aspect_type, props)
mcp = self._create_mcp(
entity_urn, updated_props, entity_type, f"{entity_type}Properties"
)
self._emit_mcps(mcp)
def _copy_existing_properties(
self, existing_props: Any, skip_properties: Optional[List[str]] = None
) -> Dict[str, Any]:
"""Helper to copy existing properties while skipping specified ones"""
        # Copy the caller's list so extending it here has no side effects, then add
        # the codegen-internal attributes that should never be copied as properties.
        skip_list = list(skip_properties) if skip_properties else []
        internal_props = {
            "ASPECT_INFO",
            "ASPECT_NAME",
            "ASPECT_TYPE",
            "RECORD_SCHEMA",
        }
        skip_list.extend(internal_props)
props: Dict[str, Any] = {}
if existing_props:
for prop in dir(existing_props):
if (
prop.startswith("_")
or callable(getattr(existing_props, prop))
or prop in skip_list
):
continue
value = getattr(existing_props, prop)
if value is not None:
props[prop] = value
if hasattr(existing_props, "created"):
props.setdefault("created", self._create_timestamp())
if hasattr(existing_props, "lastModified"):
props.setdefault("lastModified", self._create_timestamp())
return props
def _create_run_event(
self,
status: str,
timestamp: int,
result: Optional[str] = None,
duration_millis: Optional[int] = None,
) -> models.DataProcessInstanceRunEventClass:
"""Helper to create run event with common parameters."""
event_args: Dict[str, Any] = {
"timestampMillis": timestamp,
"status": status,
"attempt": 1,
}
if result:
event_args["result"] = DataProcessInstanceRunResultClass(
type=result, nativeResultType=str(result)
)
        # Explicit None check so a zero-millisecond duration is still recorded
        if duration_millis is not None:
event_args["durationMillis"] = duration_millis
return models.DataProcessInstanceRunEventClass(**event_args)
def create_model_group(
self,
group_id: str,
properties: Optional[models.MLModelGroupPropertiesClass] = None,
**kwargs: Any,
) -> str:
"""Create an ML model group with either property class or kwargs."""
model_group_urn = MlModelGroupUrn(platform=self.platform, name=group_id)
if properties is None:
properties = self._create_properties_class(
models.MLModelGroupPropertiesClass, kwargs
)
mcp = self._create_mcp(
str(model_group_urn), properties, "mlModelGroup", "mlModelGroupProperties"
)
self._emit_mcps(mcp)
logger.info(f"Created model group: {model_group_urn}")
return str(model_group_urn)
def create_model(
self,
model_id: str,
version: str,
alias: Optional[str] = None,
properties: Optional[models.MLModelPropertiesClass] = None,
**kwargs: Any,
) -> str:
"""Create an ML model with either property classes or kwargs."""
model_urn = MlModelUrn(platform=self.platform, name=model_id)
version_set_urn = VersionSetUrn(
id=f"mlmodel_{model_id}_versions", entity_type="mlModel"
)
# Handle model properties
if properties is None:
# If no properties provided, create from kwargs
properties = self._create_properties_class(
models.MLModelPropertiesClass, kwargs
)
# Ensure version is set in model properties
version_tag = models.VersionTagClass(versionTag=str(version))
properties.version = version_tag
        # Create version properties; sortId is the zero-padded version string so that
        # lexicographic ordering matches version ordering
        version_props = {
            "version": version_tag,
            "versionSet": str(version_set_urn),
            "sortId": str(version).zfill(10),
        }
# Add alias if provided
if alias:
version_props["aliases"] = [models.VersionTagClass(versionTag=alias)]
version_properties = self._create_properties_class(
models.VersionPropertiesClass, version_props
)
mcps = [
self._create_mcp(
str(model_urn), properties, "mlModel", "mlModelProperties"
),
self._create_mcp(
str(model_urn), version_properties, "mlModel", "versionProperties"
),
]
self._emit_mcps(mcps)
logger.info(f"Created model: {model_urn}")
return str(model_urn)
def create_experiment(
self,
experiment_id: str,
properties: Optional[models.ContainerPropertiesClass] = None,
**kwargs: Any,
) -> str:
"""Create an ML experiment with either property class or kwargs."""
container_urn = ContainerUrn(guid=experiment_id)
platform_urn = DataPlatformUrn(platform_name=self.platform)
if properties is None:
properties = self._create_properties_class(
models.ContainerPropertiesClass, kwargs
)
container_subtype = models.SubTypesClass(typeNames=["ML Experiment"])
browse_path = models.BrowsePathsV2Class(path=[])
platform_instance = models.DataPlatformInstanceClass(platform=str(platform_urn))
mcps = MetadataChangeProposalWrapper.construct_many(
entityUrn=str(container_urn),
aspects=[container_subtype, properties, browse_path, platform_instance],
)
self._emit_mcps(mcps)
logger.info(f"Created experiment: {container_urn}")
return str(container_urn)
def create_training_run(
self,
run_id: str,
properties: Optional[models.DataProcessInstancePropertiesClass] = None,
training_run_properties: Optional[models.MLTrainingRunPropertiesClass] = None,
run_result: Optional[str] = None,
start_timestamp: Optional[int] = None,
end_timestamp: Optional[int] = None,
**kwargs: Any,
) -> str:
"""Create a training run with properties and events."""
dpi_urn = f"urn:li:dataProcessInstance:{run_id}"
# Create basic properties and aspects
aspects: List[Any] = []
# Only add properties if they are provided
if properties is not None:
aspects.append(properties)
# Always add the subtype
aspects.append(models.SubTypesClass(typeNames=["ML Training Run"]))
# Add training run properties if provided
if training_run_properties:
aspects.append(training_run_properties)
# Handle run events
current_time = int(time.time() * 1000)
start_ts = start_timestamp or current_time
end_ts = end_timestamp or current_time
# Create events
aspects.append(
self._create_run_event(
status=DataProcessRunStatusClass.STARTED, timestamp=start_ts
)
)
if run_result:
aspects.append(
self._create_run_event(
status=DataProcessRunStatusClass.COMPLETE,
timestamp=end_ts,
result=run_result,
duration_millis=end_ts - start_ts,
)
)
# Create and emit MCPs
mcps = [self._create_mcp(dpi_urn, aspect) for aspect in aspects]
self._emit_mcps(mcps)
logger.info(f"Created training run: {dpi_urn}")
return dpi_urn
def create_dataset(self, name: str, platform: str, **kwargs: Any) -> str:
"""Create a dataset with flexible properties."""
dataset = Dataset(id=name, platform=platform, name=name, **kwargs)
mcps = list(dataset.generate_mcp())
self._emit_mcps(mcps)
if dataset.urn is None:
raise ValueError(f"Failed to create dataset URN for {name}")
return dataset.urn
def add_run_to_model(self, model_urn: str, run_urn: str) -> None:
"""Add a run to a model while preserving existing properties."""
self._update_entity_properties(
entity_urn=model_urn,
aspect_type=models.MLModelPropertiesClass,
updates={"trainingJobs": run_urn},
entity_type="mlModel",
skip_properties=["trainingJobs"],
)
logger.info(f"Added run {run_urn} to model {model_urn}")
def add_run_to_model_group(self, model_group_urn: str, run_urn: str) -> None:
"""Add a run to a model group while preserving existing properties."""
self._update_entity_properties(
entity_urn=model_group_urn,
aspect_type=models.MLModelGroupPropertiesClass,
updates={"trainingJobs": run_urn},
entity_type="mlModelGroup",
skip_properties=["trainingJobs"],
)
logger.info(f"Added run {run_urn} to model group {model_group_urn}")
def add_model_to_model_group(self, model_urn: str, group_urn: str) -> None:
"""Add a model to a group while preserving existing properties"""
self._update_entity_properties(
entity_urn=model_urn,
aspect_type=models.MLModelPropertiesClass,
updates={"groups": group_urn},
entity_type="mlModel",
skip_properties=["groups"],
)
logger.info(f"Added model {model_urn} to group {group_urn}")
def add_run_to_experiment(self, run_urn: str, experiment_urn: str) -> None:
"""Add a run to an experiment"""
mcp = self._create_mcp(
entity_urn=run_urn, aspect=models.ContainerClass(container=experiment_urn)
)
self._emit_mcps(mcp)
logger.info(f"Added run {run_urn} to experiment {experiment_urn}")
def add_input_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
"""Add input datasets to a run"""
mcp = self._create_mcp(
entity_urn=run_urn,
entity_type="dataProcessInstance",
aspect_name="dataProcessInstanceInput",
aspect=DataProcessInstanceInput(
inputs=[],
inputEdges=[
EdgeClass(destinationUrn=str(dataset_urn))
for dataset_urn in dataset_urns
],
),
)
self._emit_mcps(mcp)
logger.info(f"Added input datasets to run {run_urn}")
def add_output_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
"""Add output datasets to a run"""
mcp = self._create_mcp(
entity_urn=run_urn,
entity_type="dataProcessInstance",
aspect_name="dataProcessInstanceOutput",
aspect=DataProcessInstanceOutput(
outputEdges=[
EdgeClass(destinationUrn=str(dataset_urn))
for dataset_urn in dataset_urns
],
outputs=[],
),
)
self._emit_mcps(mcp)
logger.info(f"Added output datasets to run {run_urn}")