mirror of
https://github.com/datahub-project/datahub.git
synced 2025-07-04 07:34:44 +00:00
447 lines
16 KiB
Python
447 lines
16 KiB
Python
import logging
|
|
import time
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import datahub.metadata.schema_classes as models
|
|
from datahub.api.entities.dataset.dataset import Dataset
|
|
from datahub.emitter.mcp import MetadataChangeProposalWrapper
|
|
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
|
|
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import (
|
|
DataProcessInstanceInput,
|
|
DataProcessInstanceOutput,
|
|
)
|
|
from datahub.metadata.schema_classes import (
|
|
ChangeTypeClass,
|
|
DataProcessInstanceRunResultClass,
|
|
DataProcessRunStatusClass,
|
|
EdgeClass,
|
|
)
|
|
from datahub.metadata.urns import (
|
|
ContainerUrn,
|
|
DataPlatformUrn,
|
|
MlModelGroupUrn,
|
|
MlModelUrn,
|
|
VersionSetUrn,
|
|
)
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatahubAIClient:
|
|
"""Client for creating and managing MLflow metadata in DataHub."""
|
|
|
|
def __init__(
|
|
self,
|
|
token: Optional[str] = None,
|
|
server_url: str = "http://localhost:8080",
|
|
platform: str = "mlflow",
|
|
) -> None:
|
|
"""Initialize the DataHub AI client.
|
|
|
|
Args:
|
|
token: DataHub access token
|
|
server_url: DataHub server URL (defaults to http://localhost:8080)
|
|
platform: Platform name (defaults to mlflow)
|
|
"""
|
|
self.token = token
|
|
self.server_url = server_url
|
|
self.platform = platform
|
|
self.graph = DataHubGraph(
|
|
DatahubClientConfig(
|
|
server=server_url,
|
|
token=token,
|
|
extra_headers={"Authorization": f"Bearer {token}"},
|
|
)
|
|
)
|
|
|
|
def _create_timestamp(
|
|
self, timestamp: Optional[int] = None
|
|
) -> models.TimeStampClass:
|
|
"""Helper to create timestamp with current time if not provided"""
|
|
return models.TimeStampClass(
|
|
time=timestamp or int(time.time() * 1000), actor="urn:li:corpuser:datahub"
|
|
)
|
|
|
|
def _emit_mcps(self, mcps: Union[MetadataChangeProposalWrapper, List[Any]]) -> None:
|
|
"""Helper to emit MCPs with proper connection handling"""
|
|
if not isinstance(mcps, list):
|
|
mcps = [mcps]
|
|
with self.graph:
|
|
for mcp in mcps:
|
|
self.graph.emit(mcp)
|
|
|
|
def _get_aspect(
|
|
self, entity_urn: str, aspect_type: Any, default_constructor: Any = None
|
|
) -> Any:
|
|
"""Helper to safely get an aspect with fallback"""
|
|
try:
|
|
return self.graph.get_aspect(entity_urn=entity_urn, aspect_type=aspect_type)
|
|
except Exception as e:
|
|
logger.warning(f"Could not fetch aspect for {entity_urn}: {e}")
|
|
return default_constructor() if default_constructor else None
|
|
|
|
def _create_properties_class(
|
|
self, props_class: Any, props_dict: Optional[Dict[str, Any]] = None
|
|
) -> Any:
|
|
"""Helper to create properties class with provided values"""
|
|
if props_dict is None:
|
|
props_dict = {}
|
|
|
|
filtered_props = {k: v for k, v in props_dict.items() if v is not None}
|
|
|
|
if hasattr(props_class, "created"):
|
|
filtered_props.setdefault("created", self._create_timestamp())
|
|
if hasattr(props_class, "lastModified"):
|
|
filtered_props.setdefault("lastModified", self._create_timestamp())
|
|
|
|
return props_class(**filtered_props)
|
|
|
|
def _update_list_property(
|
|
self, existing_list: Optional[List[str]], new_item: str
|
|
) -> List[str]:
|
|
"""Helper to update a list property while maintaining uniqueness"""
|
|
items = set(existing_list if existing_list else [])
|
|
items.add(new_item)
|
|
return list(items)
|
|
|
|
def _create_mcp(
|
|
self,
|
|
entity_urn: str,
|
|
aspect: Any,
|
|
entity_type: Optional[str] = None,
|
|
aspect_name: Optional[str] = None,
|
|
change_type: str = ChangeTypeClass.UPSERT,
|
|
) -> MetadataChangeProposalWrapper:
|
|
"""Helper to create an MCP with standard parameters"""
|
|
mcp_args = {"entityUrn": entity_urn, "aspect": aspect}
|
|
if entity_type:
|
|
mcp_args["entityType"] = entity_type
|
|
if aspect_name:
|
|
mcp_args["aspectName"] = aspect_name
|
|
mcp_args["changeType"] = change_type
|
|
return MetadataChangeProposalWrapper(**mcp_args)
|
|
|
|
def _update_entity_properties(
|
|
self,
|
|
entity_urn: str,
|
|
aspect_type: Any,
|
|
updates: Dict[str, Any],
|
|
entity_type: str,
|
|
skip_properties: Optional[List[str]] = None,
|
|
) -> None:
|
|
"""Helper to update entity properties while preserving existing ones"""
|
|
existing_props = self._get_aspect(entity_urn, aspect_type, aspect_type)
|
|
skip_list = [] if skip_properties is None else skip_properties
|
|
props = self._copy_existing_properties(existing_props, skip_list) or {}
|
|
|
|
for key, value in updates.items():
|
|
if isinstance(value, str) and hasattr(existing_props, key):
|
|
existing_value = getattr(existing_props, key, [])
|
|
props[key] = self._update_list_property(existing_value, value)
|
|
else:
|
|
props[key] = value
|
|
|
|
updated_props = self._create_properties_class(aspect_type, props)
|
|
mcp = self._create_mcp(
|
|
entity_urn, updated_props, entity_type, f"{entity_type}Properties"
|
|
)
|
|
self._emit_mcps(mcp)
|
|
|
|
def _copy_existing_properties(
|
|
self, existing_props: Any, skip_properties: Optional[List[str]] = None
|
|
) -> Dict[str, Any]:
|
|
"""Helper to copy existing properties while skipping specified ones"""
|
|
skip_list = [] if skip_properties is None else skip_properties
|
|
|
|
internal_props = {
|
|
"ASPECT_INFO",
|
|
"ASPECT_NAME",
|
|
"ASPECT_TYPE",
|
|
"RECORD_SCHEMA",
|
|
}
|
|
skip_list.extend(internal_props)
|
|
|
|
props: Dict[str, Any] = {}
|
|
if existing_props:
|
|
for prop in dir(existing_props):
|
|
if (
|
|
prop.startswith("_")
|
|
or callable(getattr(existing_props, prop))
|
|
or prop in skip_list
|
|
):
|
|
continue
|
|
|
|
value = getattr(existing_props, prop)
|
|
if value is not None:
|
|
props[prop] = value
|
|
|
|
if hasattr(existing_props, "created"):
|
|
props.setdefault("created", self._create_timestamp())
|
|
if hasattr(existing_props, "lastModified"):
|
|
props.setdefault("lastModified", self._create_timestamp())
|
|
|
|
return props
|
|
|
|
def _create_run_event(
|
|
self,
|
|
status: str,
|
|
timestamp: int,
|
|
result: Optional[str] = None,
|
|
duration_millis: Optional[int] = None,
|
|
) -> models.DataProcessInstanceRunEventClass:
|
|
"""Helper to create run event with common parameters."""
|
|
event_args: Dict[str, Any] = {
|
|
"timestampMillis": timestamp,
|
|
"status": status,
|
|
"attempt": 1,
|
|
}
|
|
|
|
if result:
|
|
event_args["result"] = DataProcessInstanceRunResultClass(
|
|
type=result, nativeResultType=str(result)
|
|
)
|
|
if duration_millis:
|
|
event_args["durationMillis"] = duration_millis
|
|
|
|
return models.DataProcessInstanceRunEventClass(**event_args)
|
|
|
|
def create_model_group(
|
|
self,
|
|
group_id: str,
|
|
properties: Optional[models.MLModelGroupPropertiesClass] = None,
|
|
**kwargs: Any,
|
|
) -> str:
|
|
"""Create an ML model group with either property class or kwargs."""
|
|
model_group_urn = MlModelGroupUrn(platform=self.platform, name=group_id)
|
|
|
|
if properties is None:
|
|
properties = self._create_properties_class(
|
|
models.MLModelGroupPropertiesClass, kwargs
|
|
)
|
|
|
|
mcp = self._create_mcp(
|
|
str(model_group_urn), properties, "mlModelGroup", "mlModelGroupProperties"
|
|
)
|
|
self._emit_mcps(mcp)
|
|
logger.info(f"Created model group: {model_group_urn}")
|
|
return str(model_group_urn)
|
|
|
|
def create_model(
|
|
self,
|
|
model_id: str,
|
|
version: str,
|
|
alias: Optional[str] = None,
|
|
properties: Optional[models.MLModelPropertiesClass] = None,
|
|
**kwargs: Any,
|
|
) -> str:
|
|
"""Create an ML model with either property classes or kwargs."""
|
|
model_urn = MlModelUrn(platform=self.platform, name=model_id)
|
|
version_set_urn = VersionSetUrn(
|
|
id=f"mlmodel_{model_id}_versions", entity_type="mlModel"
|
|
)
|
|
|
|
# Handle model properties
|
|
if properties is None:
|
|
# If no properties provided, create from kwargs
|
|
properties = self._create_properties_class(
|
|
models.MLModelPropertiesClass, kwargs
|
|
)
|
|
|
|
# Ensure version is set in model properties
|
|
version_tag = models.VersionTagClass(versionTag=str(version))
|
|
properties.version = version_tag
|
|
|
|
# Create version properties
|
|
version_props = {
|
|
"version": version_tag,
|
|
"versionSet": str(version_set_urn),
|
|
"sortId": str(version_tag).zfill(10),
|
|
}
|
|
|
|
# Add alias if provided
|
|
if alias:
|
|
version_props["aliases"] = [models.VersionTagClass(versionTag=alias)]
|
|
|
|
version_properties = self._create_properties_class(
|
|
models.VersionPropertiesClass, version_props
|
|
)
|
|
|
|
mcps = [
|
|
self._create_mcp(
|
|
str(model_urn), properties, "mlModel", "mlModelProperties"
|
|
),
|
|
self._create_mcp(
|
|
str(model_urn), version_properties, "mlModel", "versionProperties"
|
|
),
|
|
]
|
|
self._emit_mcps(mcps)
|
|
logger.info(f"Created model: {model_urn}")
|
|
return str(model_urn)
|
|
|
|
def create_experiment(
|
|
self,
|
|
experiment_id: str,
|
|
properties: Optional[models.ContainerPropertiesClass] = None,
|
|
**kwargs: Any,
|
|
) -> str:
|
|
"""Create an ML experiment with either property class or kwargs."""
|
|
container_urn = ContainerUrn(guid=experiment_id)
|
|
platform_urn = DataPlatformUrn(platform_name=self.platform)
|
|
|
|
if properties is None:
|
|
properties = self._create_properties_class(
|
|
models.ContainerPropertiesClass, kwargs
|
|
)
|
|
|
|
container_subtype = models.SubTypesClass(typeNames=["ML Experiment"])
|
|
browse_path = models.BrowsePathsV2Class(path=[])
|
|
platform_instance = models.DataPlatformInstanceClass(platform=str(platform_urn))
|
|
|
|
mcps = MetadataChangeProposalWrapper.construct_many(
|
|
entityUrn=str(container_urn),
|
|
aspects=[container_subtype, properties, browse_path, platform_instance],
|
|
)
|
|
self._emit_mcps(mcps)
|
|
logger.info(f"Created experiment: {container_urn}")
|
|
return str(container_urn)
|
|
|
|
def create_training_run(
|
|
self,
|
|
run_id: str,
|
|
properties: Optional[models.DataProcessInstancePropertiesClass] = None,
|
|
training_run_properties: Optional[models.MLTrainingRunPropertiesClass] = None,
|
|
run_result: Optional[str] = None,
|
|
start_timestamp: Optional[int] = None,
|
|
end_timestamp: Optional[int] = None,
|
|
**kwargs: Any,
|
|
) -> str:
|
|
"""Create a training run with properties and events."""
|
|
dpi_urn = f"urn:li:dataProcessInstance:{run_id}"
|
|
|
|
# Create basic properties and aspects
|
|
aspects: List[Any] = []
|
|
|
|
# Only add properties if they are provided
|
|
if properties is not None:
|
|
aspects.append(properties)
|
|
|
|
# Always add the subtype
|
|
aspects.append(models.SubTypesClass(typeNames=["ML Training Run"]))
|
|
|
|
# Add training run properties if provided
|
|
if training_run_properties:
|
|
aspects.append(training_run_properties)
|
|
|
|
# Handle run events
|
|
current_time = int(time.time() * 1000)
|
|
start_ts = start_timestamp or current_time
|
|
end_ts = end_timestamp or current_time
|
|
|
|
# Create events
|
|
aspects.append(
|
|
self._create_run_event(
|
|
status=DataProcessRunStatusClass.STARTED, timestamp=start_ts
|
|
)
|
|
)
|
|
|
|
if run_result:
|
|
aspects.append(
|
|
self._create_run_event(
|
|
status=DataProcessRunStatusClass.COMPLETE,
|
|
timestamp=end_ts,
|
|
result=run_result,
|
|
duration_millis=end_ts - start_ts,
|
|
)
|
|
)
|
|
|
|
# Create and emit MCPs
|
|
mcps = [self._create_mcp(dpi_urn, aspect) for aspect in aspects]
|
|
self._emit_mcps(mcps)
|
|
logger.info(f"Created training run: {dpi_urn}")
|
|
return dpi_urn
|
|
|
|
def create_dataset(self, name: str, platform: str, **kwargs: Any) -> str:
|
|
"""Create a dataset with flexible properties."""
|
|
dataset = Dataset(id=name, platform=platform, name=name, **kwargs)
|
|
mcps = list(dataset.generate_mcp())
|
|
self._emit_mcps(mcps)
|
|
if dataset.urn is None:
|
|
raise ValueError(f"Failed to create dataset URN for {name}")
|
|
return dataset.urn
|
|
|
|
def add_run_to_model(self, model_urn: str, run_urn: str) -> None:
|
|
"""Add a run to a model while preserving existing properties."""
|
|
self._update_entity_properties(
|
|
entity_urn=model_urn,
|
|
aspect_type=models.MLModelPropertiesClass,
|
|
updates={"trainingJobs": run_urn},
|
|
entity_type="mlModel",
|
|
skip_properties=["trainingJobs"],
|
|
)
|
|
logger.info(f"Added run {run_urn} to model {model_urn}")
|
|
|
|
def add_run_to_model_group(self, model_group_urn: str, run_urn: str) -> None:
|
|
"""Add a run to a model group while preserving existing properties."""
|
|
self._update_entity_properties(
|
|
entity_urn=model_group_urn,
|
|
aspect_type=models.MLModelGroupPropertiesClass,
|
|
updates={"trainingJobs": run_urn},
|
|
entity_type="mlModelGroup",
|
|
skip_properties=["trainingJobs"],
|
|
)
|
|
logger.info(f"Added run {run_urn} to model group {model_group_urn}")
|
|
|
|
def add_model_to_model_group(self, model_urn: str, group_urn: str) -> None:
|
|
"""Add a model to a group while preserving existing properties"""
|
|
self._update_entity_properties(
|
|
entity_urn=model_urn,
|
|
aspect_type=models.MLModelPropertiesClass,
|
|
updates={"groups": group_urn},
|
|
entity_type="mlModel",
|
|
skip_properties=["groups"],
|
|
)
|
|
logger.info(f"Added model {model_urn} to group {group_urn}")
|
|
|
|
def add_run_to_experiment(self, run_urn: str, experiment_urn: str) -> None:
|
|
"""Add a run to an experiment"""
|
|
mcp = self._create_mcp(
|
|
entity_urn=run_urn, aspect=models.ContainerClass(container=experiment_urn)
|
|
)
|
|
self._emit_mcps(mcp)
|
|
logger.info(f"Added run {run_urn} to experiment {experiment_urn}")
|
|
|
|
def add_input_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
|
|
"""Add input datasets to a run"""
|
|
mcp = self._create_mcp(
|
|
entity_urn=run_urn,
|
|
entity_type="dataProcessInstance",
|
|
aspect_name="dataProcessInstanceInput",
|
|
aspect=DataProcessInstanceInput(
|
|
inputs=[],
|
|
inputEdges=[
|
|
EdgeClass(destinationUrn=str(dataset_urn))
|
|
for dataset_urn in dataset_urns
|
|
],
|
|
),
|
|
)
|
|
self._emit_mcps(mcp)
|
|
logger.info(f"Added input datasets to run {run_urn}")
|
|
|
|
def add_output_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
|
|
"""Add output datasets to a run"""
|
|
mcp = self._create_mcp(
|
|
entity_urn=run_urn,
|
|
entity_type="dataProcessInstance",
|
|
aspect_name="dataProcessInstanceOutput",
|
|
aspect=DataProcessInstanceOutput(
|
|
outputEdges=[
|
|
EdgeClass(destinationUrn=str(dataset_urn))
|
|
for dataset_urn in dataset_urns
|
|
],
|
|
outputs=[],
|
|
),
|
|
)
|
|
self._emit_mcps(mcp)
|
|
logger.info(f"Added output datasets to run {run_urn}")
|