mirror of
				https://github.com/datahub-project/datahub.git
				synced 2025-10-31 10:49:00 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			452 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			452 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| import time
 | |
| from typing import Any, Dict, List, Optional, Union
 | |
| 
 | |
| import datahub.metadata.schema_classes as models
 | |
| from datahub.api.entities.dataset.dataset import Dataset
 | |
| from datahub.emitter.mcp import MetadataChangeProposalWrapper
 | |
| from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
 | |
| from datahub.ingestion.source.common.subtypes import MLAssetSubTypes
 | |
| from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import (
 | |
|     DataProcessInstanceInput,
 | |
|     DataProcessInstanceOutput,
 | |
| )
 | |
| from datahub.metadata.schema_classes import (
 | |
|     ChangeTypeClass,
 | |
|     DataProcessInstanceRunResultClass,
 | |
|     DataProcessRunStatusClass,
 | |
|     EdgeClass,
 | |
| )
 | |
| from datahub.metadata.urns import (
 | |
|     ContainerUrn,
 | |
|     DataPlatformUrn,
 | |
|     MlModelGroupUrn,
 | |
|     MlModelUrn,
 | |
|     VersionSetUrn,
 | |
| )
 | |
| 
 | |
# Configure the root logger at import time with INFO as its level.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by DatahubAIClient for its info and warning messages.
logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
| class DatahubAIClient:
 | |
|     """Client for creating and managing MLflow metadata in DataHub."""
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         token: Optional[str] = None,
 | |
|         server_url: str = "http://localhost:8080",
 | |
|         platform: str = "mlflow",
 | |
|     ) -> None:
 | |
|         """Initialize the DataHub AI client.
 | |
| 
 | |
|         Args:
 | |
|             token: DataHub access token
 | |
|             server_url: DataHub server URL (defaults to http://localhost:8080)
 | |
|             platform: Platform name (defaults to mlflow)
 | |
|         """
 | |
|         self.token = token
 | |
|         self.server_url = server_url
 | |
|         self.platform = platform
 | |
|         self.graph = DataHubGraph(
 | |
|             DatahubClientConfig(
 | |
|                 server=server_url,
 | |
|                 token=token,
 | |
|                 extra_headers={"Authorization": f"Bearer {token}"},
 | |
|             )
 | |
|         )
 | |
| 
 | |
|     def _create_timestamp(
 | |
|         self, timestamp: Optional[int] = None
 | |
|     ) -> models.TimeStampClass:
 | |
|         """Helper to create timestamp with current time if not provided"""
 | |
|         return models.TimeStampClass(
 | |
|             time=timestamp or int(time.time() * 1000), actor="urn:li:corpuser:datahub"
 | |
|         )
 | |
| 
 | |
|     def _emit_mcps(self, mcps: Union[MetadataChangeProposalWrapper, List[Any]]) -> None:
 | |
|         """Helper to emit MCPs with proper connection handling"""
 | |
|         if not isinstance(mcps, list):
 | |
|             mcps = [mcps]
 | |
|         with self.graph:
 | |
|             for mcp in mcps:
 | |
|                 self.graph.emit(mcp)
 | |
| 
 | |
|     def _get_aspect(
 | |
|         self, entity_urn: str, aspect_type: Any, default_constructor: Any = None
 | |
|     ) -> Any:
 | |
|         """Helper to safely get an aspect with fallback"""
 | |
|         try:
 | |
|             return self.graph.get_aspect(entity_urn=entity_urn, aspect_type=aspect_type)
 | |
|         except Exception as e:
 | |
|             logger.warning(f"Could not fetch aspect for {entity_urn}: {e}")
 | |
|             return default_constructor() if default_constructor else None
 | |
| 
 | |
|     def _create_properties_class(
 | |
|         self, props_class: Any, props_dict: Optional[Dict[str, Any]] = None
 | |
|     ) -> Any:
 | |
|         """Helper to create properties class with provided values"""
 | |
|         if props_dict is None:
 | |
|             props_dict = {}
 | |
| 
 | |
|         filtered_props = {k: v for k, v in props_dict.items() if v is not None}
 | |
| 
 | |
|         if hasattr(props_class, "created"):
 | |
|             filtered_props.setdefault("created", self._create_timestamp())
 | |
|         if hasattr(props_class, "lastModified"):
 | |
|             filtered_props.setdefault("lastModified", self._create_timestamp())
 | |
| 
 | |
|         return props_class(**filtered_props)
 | |
| 
 | |
|     def _update_list_property(
 | |
|         self, existing_list: Optional[List[str]], new_item: str
 | |
|     ) -> List[str]:
 | |
|         """Helper to update a list property while maintaining uniqueness"""
 | |
|         items = set(existing_list if existing_list else [])
 | |
|         items.add(new_item)
 | |
|         return list(items)
 | |
| 
 | |
|     def _create_mcp(
 | |
|         self,
 | |
|         entity_urn: str,
 | |
|         aspect: Any,
 | |
|         entity_type: Optional[str] = None,
 | |
|         aspect_name: Optional[str] = None,
 | |
|         change_type: str = ChangeTypeClass.UPSERT,
 | |
|     ) -> MetadataChangeProposalWrapper:
 | |
|         """Helper to create an MCP with standard parameters"""
 | |
|         mcp_args = {"entityUrn": entity_urn, "aspect": aspect}
 | |
|         if entity_type:
 | |
|             mcp_args["entityType"] = entity_type
 | |
|         if aspect_name:
 | |
|             mcp_args["aspectName"] = aspect_name
 | |
|         mcp_args["changeType"] = change_type
 | |
|         return MetadataChangeProposalWrapper(**mcp_args)
 | |
| 
 | |
|     def _update_entity_properties(
 | |
|         self,
 | |
|         entity_urn: str,
 | |
|         aspect_type: Any,
 | |
|         updates: Dict[str, Any],
 | |
|         entity_type: str,
 | |
|         skip_properties: Optional[List[str]] = None,
 | |
|     ) -> None:
 | |
|         """Helper to update entity properties while preserving existing ones"""
 | |
|         existing_props = self._get_aspect(entity_urn, aspect_type, aspect_type)
 | |
|         skip_list = [] if skip_properties is None else skip_properties
 | |
|         props = self._copy_existing_properties(existing_props, skip_list) or {}
 | |
| 
 | |
|         for key, value in updates.items():
 | |
|             if isinstance(value, str) and hasattr(existing_props, key):
 | |
|                 existing_value = getattr(existing_props, key, [])
 | |
|                 props[key] = self._update_list_property(existing_value, value)
 | |
|             else:
 | |
|                 props[key] = value
 | |
| 
 | |
|         updated_props = self._create_properties_class(aspect_type, props)
 | |
|         mcp = self._create_mcp(
 | |
|             entity_urn, updated_props, entity_type, f"{entity_type}Properties"
 | |
|         )
 | |
|         self._emit_mcps(mcp)
 | |
| 
 | |
|     def _copy_existing_properties(
 | |
|         self, existing_props: Any, skip_properties: Optional[List[str]] = None
 | |
|     ) -> Dict[str, Any]:
 | |
|         """Helper to copy existing properties while skipping specified ones"""
 | |
|         skip_list = [] if skip_properties is None else skip_properties
 | |
| 
 | |
|         internal_props = {
 | |
|             "ASPECT_INFO",
 | |
|             "ASPECT_NAME",
 | |
|             "ASPECT_TYPE",
 | |
|             "RECORD_SCHEMA",
 | |
|         }
 | |
|         skip_list.extend(internal_props)
 | |
| 
 | |
|         props: Dict[str, Any] = {}
 | |
|         if existing_props:
 | |
|             for prop in dir(existing_props):
 | |
|                 if (
 | |
|                     prop.startswith("_")
 | |
|                     or callable(getattr(existing_props, prop))
 | |
|                     or prop in skip_list
 | |
|                 ):
 | |
|                     continue
 | |
| 
 | |
|                 value = getattr(existing_props, prop)
 | |
|                 if value is not None:
 | |
|                     props[prop] = value
 | |
| 
 | |
|             if hasattr(existing_props, "created"):
 | |
|                 props.setdefault("created", self._create_timestamp())
 | |
|             if hasattr(existing_props, "lastModified"):
 | |
|                 props.setdefault("lastModified", self._create_timestamp())
 | |
| 
 | |
|         return props
 | |
| 
 | |
|     def _create_run_event(
 | |
|         self,
 | |
|         status: str,
 | |
|         timestamp: int,
 | |
|         result: Optional[str] = None,
 | |
|         duration_millis: Optional[int] = None,
 | |
|     ) -> models.DataProcessInstanceRunEventClass:
 | |
|         """Helper to create run event with common parameters."""
 | |
|         event_args: Dict[str, Any] = {
 | |
|             "timestampMillis": timestamp,
 | |
|             "status": status,
 | |
|             "attempt": 1,
 | |
|         }
 | |
| 
 | |
|         if result:
 | |
|             event_args["result"] = DataProcessInstanceRunResultClass(
 | |
|                 type=result, nativeResultType=str(result)
 | |
|             )
 | |
|         if duration_millis:
 | |
|             event_args["durationMillis"] = duration_millis
 | |
| 
 | |
|         return models.DataProcessInstanceRunEventClass(**event_args)
 | |
| 
 | |
|     def create_model_group(
 | |
|         self,
 | |
|         group_id: str,
 | |
|         properties: Optional[models.MLModelGroupPropertiesClass] = None,
 | |
|         **kwargs: Any,
 | |
|     ) -> str:
 | |
|         """Create an ML model group with either property class or kwargs."""
 | |
|         model_group_urn = MlModelGroupUrn(platform=self.platform, name=group_id)
 | |
| 
 | |
|         if properties is None:
 | |
|             properties = self._create_properties_class(
 | |
|                 models.MLModelGroupPropertiesClass, kwargs
 | |
|             )
 | |
| 
 | |
|         mcp = self._create_mcp(
 | |
|             str(model_group_urn), properties, "mlModelGroup", "mlModelGroupProperties"
 | |
|         )
 | |
|         self._emit_mcps(mcp)
 | |
|         logger.info(f"Created model group: {model_group_urn}")
 | |
|         return str(model_group_urn)
 | |
| 
 | |
|     def create_model(
 | |
|         self,
 | |
|         model_id: str,
 | |
|         version: str,
 | |
|         alias: Optional[str] = None,
 | |
|         properties: Optional[models.MLModelPropertiesClass] = None,
 | |
|         **kwargs: Any,
 | |
|     ) -> str:
 | |
|         """Create an ML model with either property classes or kwargs."""
 | |
|         model_urn = MlModelUrn(platform=self.platform, name=model_id)
 | |
|         version_set_urn = VersionSetUrn(
 | |
|             id=f"mlmodel_{model_id}_versions", entity_type="mlModel"
 | |
|         )
 | |
| 
 | |
|         # Handle model properties
 | |
|         if properties is None:
 | |
|             # If no properties provided, create from kwargs
 | |
|             properties = self._create_properties_class(
 | |
|                 models.MLModelPropertiesClass, kwargs
 | |
|             )
 | |
| 
 | |
|         # Ensure version is set in model properties
 | |
|         version_tag = models.VersionTagClass(versionTag=str(version))
 | |
|         properties.version = version_tag
 | |
| 
 | |
|         # Create version properties
 | |
|         version_props = {
 | |
|             "version": version_tag,
 | |
|             "versionSet": str(version_set_urn),
 | |
|             "sortId": str(version_tag).zfill(10),
 | |
|         }
 | |
| 
 | |
|         # Add alias if provided
 | |
|         if alias:
 | |
|             version_props["aliases"] = [models.VersionTagClass(versionTag=alias)]
 | |
| 
 | |
|         version_properties = self._create_properties_class(
 | |
|             models.VersionPropertiesClass, version_props
 | |
|         )
 | |
| 
 | |
|         mcps = [
 | |
|             self._create_mcp(
 | |
|                 str(model_urn), properties, "mlModel", "mlModelProperties"
 | |
|             ),
 | |
|             self._create_mcp(
 | |
|                 str(model_urn), version_properties, "mlModel", "versionProperties"
 | |
|             ),
 | |
|         ]
 | |
|         self._emit_mcps(mcps)
 | |
|         logger.info(f"Created model: {model_urn}")
 | |
|         return str(model_urn)
 | |
| 
 | |
|     def create_experiment(
 | |
|         self,
 | |
|         experiment_id: str,
 | |
|         properties: Optional[models.ContainerPropertiesClass] = None,
 | |
|         **kwargs: Any,
 | |
|     ) -> str:
 | |
|         """Create an ML experiment with either property class or kwargs."""
 | |
|         container_urn = ContainerUrn(guid=experiment_id)
 | |
|         platform_urn = DataPlatformUrn(platform_name=self.platform)
 | |
| 
 | |
|         if properties is None:
 | |
|             properties = self._create_properties_class(
 | |
|                 models.ContainerPropertiesClass, kwargs
 | |
|             )
 | |
| 
 | |
|         container_subtype = models.SubTypesClass(
 | |
|             typeNames=[MLAssetSubTypes.MLFLOW_EXPERIMENT]
 | |
|         )
 | |
|         browse_path = models.BrowsePathsV2Class(path=[])
 | |
|         platform_instance = models.DataPlatformInstanceClass(platform=str(platform_urn))
 | |
| 
 | |
|         mcps = MetadataChangeProposalWrapper.construct_many(
 | |
|             entityUrn=str(container_urn),
 | |
|             aspects=[container_subtype, properties, browse_path, platform_instance],
 | |
|         )
 | |
|         self._emit_mcps(mcps)
 | |
|         logger.info(f"Created experiment: {container_urn}")
 | |
|         return str(container_urn)
 | |
| 
 | |
|     def create_training_run(
 | |
|         self,
 | |
|         run_id: str,
 | |
|         properties: Optional[models.DataProcessInstancePropertiesClass] = None,
 | |
|         training_run_properties: Optional[models.MLTrainingRunPropertiesClass] = None,
 | |
|         run_result: Optional[str] = None,
 | |
|         start_timestamp: Optional[int] = None,
 | |
|         end_timestamp: Optional[int] = None,
 | |
|         **kwargs: Any,
 | |
|     ) -> str:
 | |
|         """Create a training run with properties and events."""
 | |
|         dpi_urn = f"urn:li:dataProcessInstance:{run_id}"
 | |
| 
 | |
|         # Create basic properties and aspects
 | |
|         aspects: List[Any] = []
 | |
| 
 | |
|         # Only add properties if they are provided
 | |
|         if properties is not None:
 | |
|             aspects.append(properties)
 | |
| 
 | |
|         # Always add the subtype
 | |
|         aspects.append(
 | |
|             models.SubTypesClass(typeNames=[MLAssetSubTypes.MLFLOW_TRAINING_RUN])
 | |
|         )
 | |
| 
 | |
|         # Add training run properties if provided
 | |
|         if training_run_properties:
 | |
|             aspects.append(training_run_properties)
 | |
| 
 | |
|         # Handle run events
 | |
|         current_time = int(time.time() * 1000)
 | |
|         start_ts = start_timestamp or current_time
 | |
|         end_ts = end_timestamp or current_time
 | |
| 
 | |
|         # Create events
 | |
|         aspects.append(
 | |
|             self._create_run_event(
 | |
|                 status=DataProcessRunStatusClass.STARTED, timestamp=start_ts
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         if run_result:
 | |
|             aspects.append(
 | |
|                 self._create_run_event(
 | |
|                     status=DataProcessRunStatusClass.COMPLETE,
 | |
|                     timestamp=end_ts,
 | |
|                     result=run_result,
 | |
|                     duration_millis=end_ts - start_ts,
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|         # Create and emit MCPs
 | |
|         mcps = [self._create_mcp(dpi_urn, aspect) for aspect in aspects]
 | |
|         self._emit_mcps(mcps)
 | |
|         logger.info(f"Created training run: {dpi_urn}")
 | |
|         return dpi_urn
 | |
| 
 | |
|     def create_dataset(self, name: str, platform: str, **kwargs: Any) -> str:
 | |
|         """Create a dataset with flexible properties."""
 | |
|         dataset = Dataset(id=name, platform=platform, name=name, **kwargs)
 | |
|         mcps = list(dataset.generate_mcp())
 | |
|         self._emit_mcps(mcps)
 | |
|         if dataset.urn is None:
 | |
|             raise ValueError(f"Failed to create dataset URN for {name}")
 | |
|         return dataset.urn
 | |
| 
 | |
|     def add_run_to_model(self, model_urn: str, run_urn: str) -> None:
 | |
|         """Add a run to a model while preserving existing properties."""
 | |
|         self._update_entity_properties(
 | |
|             entity_urn=model_urn,
 | |
|             aspect_type=models.MLModelPropertiesClass,
 | |
|             updates={"trainingJobs": run_urn},
 | |
|             entity_type="mlModel",
 | |
|             skip_properties=["trainingJobs"],
 | |
|         )
 | |
|         logger.info(f"Added run {run_urn} to model {model_urn}")
 | |
| 
 | |
|     def add_run_to_model_group(self, model_group_urn: str, run_urn: str) -> None:
 | |
|         """Add a run to a model group while preserving existing properties."""
 | |
|         self._update_entity_properties(
 | |
|             entity_urn=model_group_urn,
 | |
|             aspect_type=models.MLModelGroupPropertiesClass,
 | |
|             updates={"trainingJobs": run_urn},
 | |
|             entity_type="mlModelGroup",
 | |
|             skip_properties=["trainingJobs"],
 | |
|         )
 | |
|         logger.info(f"Added run {run_urn} to model group {model_group_urn}")
 | |
| 
 | |
|     def add_model_to_model_group(self, model_urn: str, group_urn: str) -> None:
 | |
|         """Add a model to a group while preserving existing properties"""
 | |
|         self._update_entity_properties(
 | |
|             entity_urn=model_urn,
 | |
|             aspect_type=models.MLModelPropertiesClass,
 | |
|             updates={"groups": group_urn},
 | |
|             entity_type="mlModel",
 | |
|             skip_properties=["groups"],
 | |
|         )
 | |
|         logger.info(f"Added model {model_urn} to group {group_urn}")
 | |
| 
 | |
|     def add_run_to_experiment(self, run_urn: str, experiment_urn: str) -> None:
 | |
|         """Add a run to an experiment"""
 | |
|         mcp = self._create_mcp(
 | |
|             entity_urn=run_urn, aspect=models.ContainerClass(container=experiment_urn)
 | |
|         )
 | |
|         self._emit_mcps(mcp)
 | |
|         logger.info(f"Added run {run_urn} to experiment {experiment_urn}")
 | |
| 
 | |
|     def add_input_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
 | |
|         """Add input datasets to a run"""
 | |
|         mcp = self._create_mcp(
 | |
|             entity_urn=run_urn,
 | |
|             entity_type="dataProcessInstance",
 | |
|             aspect_name="dataProcessInstanceInput",
 | |
|             aspect=DataProcessInstanceInput(
 | |
|                 inputs=[],
 | |
|                 inputEdges=[
 | |
|                     EdgeClass(destinationUrn=str(dataset_urn))
 | |
|                     for dataset_urn in dataset_urns
 | |
|                 ],
 | |
|             ),
 | |
|         )
 | |
|         self._emit_mcps(mcp)
 | |
|         logger.info(f"Added input datasets to run {run_urn}")
 | |
| 
 | |
|     def add_output_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
 | |
|         """Add output datasets to a run"""
 | |
|         mcp = self._create_mcp(
 | |
|             entity_urn=run_urn,
 | |
|             entity_type="dataProcessInstance",
 | |
|             aspect_name="dataProcessInstanceOutput",
 | |
|             aspect=DataProcessInstanceOutput(
 | |
|                 outputEdges=[
 | |
|                     EdgeClass(destinationUrn=str(dataset_urn))
 | |
|                     for dataset_urn in dataset_urns
 | |
|                 ],
 | |
|                 outputs=[],
 | |
|             ),
 | |
|         )
 | |
|         self._emit_mcps(mcp)
 | |
|         logger.info(f"Added output datasets to run {run_urn}")
 | 
