MLModel Entity
Overview
The MLModel entity represents a machine learning model in DataHub. ML models are trained on data and deployed to production environments. The entity captures metadata for the model: training metrics, hyperparameters, model groups, training jobs, downstream jobs, and deployments.
URN Structure
MLModel URNs follow this pattern:
urn:li:mlModel:(urn:li:dataPlatform:{platform},{model_name},{environment})
Components:
- platform: The ML platform (e.g., tensorflow, pytorch, sklearn, sagemaker)
- model_name: Unique identifier for the model
- environment: Fabric type (PROD, DEV, STAGING, TEST, etc.)
Examples:
urn:li:mlModel:(urn:li:dataPlatform:tensorflow,user_churn_predictor,PROD)
urn:li:mlModel:(urn:li:dataPlatform:pytorch,recommendation_model_v2,STAGING)
urn:li:mlModel:(urn:li:dataPlatform:sklearn,fraud_detector,PROD)
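To make the pattern concrete, the following is a minimal sketch of building and parsing an MLModel URN string using only the Java standard library. The class `MlModelUrnSketch` and its methods are hypothetical and are not part of the DataHub SDK. The sketch also assumes that platform and model names contain no commas or parentheses and that the environment is written in upper case.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical helper (not a DataHub SDK class): builds and parses MLModel URN strings. */
final class MlModelUrnSketch {
    // Mirrors urn:li:mlModel:(urn:li:dataPlatform:{platform},{model_name},{environment}).
    // Assumption: names contain no commas or parentheses; environment is upper case.
    private static final Pattern URN = Pattern.compile(
        "^urn:li:mlModel:\\(urn:li:dataPlatform:([^,()]+),([^,()]+),([A-Z_]+)\\)$");

    private MlModelUrnSketch() {}

    /** Assembles the URN string from its three components. */
    static String build(String platform, String modelName, String environment) {
        return "urn:li:mlModel:(urn:li:dataPlatform:" + platform + ","
            + modelName + "," + environment + ")";
    }

    /** Returns {platform, model_name, environment}; throws if the string does not match. */
    static String[] parse(String urn) {
        Matcher m = URN.matcher(urn);
        if (!m.matches()) {
            throw new IllegalArgumentException("Not an MLModel URN: " + urn);
        }
        return new String[] {m.group(1), m.group(2), m.group(3)};
    }
}
```

Round-tripping a URN through `build` and `parse` is a cheap way to catch malformed identifiers before they are sent to DataHub.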
ML-Specific Concepts
Training Metrics
Metrics collected during model training that measure performance:
- Classification: accuracy, precision, recall, f1_score, auc_roc, auc_pr
- Regression: mse, mae, rmse, r2_score
- Loss metrics: log_loss, cross_entropy
- Custom metrics: training_time, validation_accuracy
Hyperparameters
Configuration parameters used during model training:
- Learning configuration: learning_rate, batch_size, epochs
- Architecture: hidden_layers, hidden_units, dropout_rate
- Optimization: optimizer, learning_rate_decay, momentum
- Regularization: l1_regularization, l2_regularization
Model Groups
Collections of related models (e.g., different versions of the same model family). A model can belong to one group, enabling version tracking and A/B testing scenarios.
Training Jobs
Data processing jobs or pipelines that produced this model. Linking them creates lineage from the training data to the model.
Downstream Jobs
Jobs that use this model for inference, scoring, or predictions. Linking them creates lineage from the model to downstream applications.
Deployments
Production environments where the model is deployed (e.g., SageMaker endpoints, Kubernetes services, REST APIs).
Creating an ML Model
Basic Example
MLModel model = MLModel.builder()
.platform("tensorflow")
.name("user_churn_predictor")
.env("PROD")
.displayName("User Churn Prediction Model")
.description("Model predicting user churn probability")
.build();
// Add training metrics
model.addTrainingMetric("accuracy", "0.94")
.addTrainingMetric("f1_score", "0.92")
.addTrainingMetric("auc_roc", "0.96");
// Add hyperparameters
model.addHyperParam("learning_rate", "0.01")
.addHyperParam("max_depth", "6")
.addHyperParam("n_estimators", "100");
// Add standard metadata
model.addTag("production")
.addOwner("urn:li:corpuser:ml_team", OwnershipType.TECHNICAL_OWNER)
.setDomain("urn:li:domain:MachineLearning");
// Save to DataHub
client.entities().upsert(model);
Builder Options
MLModel model = MLModel.builder()
.platform("pytorch") // Required: ML platform
.name("recommendation_model") // Required: Model identifier
.env("PROD") // Optional: Default is PROD
.displayName("Product Recommender") // Optional: Human-readable name
.description("Collaborative filtering model") // Optional
.externalUrl("https://mlflow.company.com/models/123") // Optional
.build();
ML-Specific Operations
Training Metrics
// Add individual metrics
model.addTrainingMetric("accuracy", "0.947")
.addTrainingMetric("precision", "0.934")
.addTrainingMetric("recall", "0.921");
// Set all metrics at once
MLMetric metric1 = new MLMetric();
metric1.setName("f1_score");
metric1.setValue("0.927");
MLMetric metric2 = new MLMetric();
metric2.setName("auc_roc");
metric2.setValue("0.965");
model.setTrainingMetrics(List.of(metric1, metric2));
// Get metrics
List<MLMetric> metrics = model.getTrainingMetrics();
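In the examples above, metric values are passed as strings. Training code usually holds them as numbers, so a small conversion step can help keep formatting consistent. The following is a sketch only: `MetricFormattingSketch` and `toMetricStrings` are hypothetical names, not SDK APIs, and the fixed number of decimals is an assumption about how you want values displayed.

```java
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;

/** Hypothetical helper (not a DataHub SDK class): turns numeric metrics into string values. */
final class MetricFormattingSketch {
    private MetricFormattingSketch() {}

    /**
     * Formats each metric with a fixed number of decimals, preserving insertion order.
     * Locale.ROOT keeps the decimal separator a dot regardless of the JVM default locale.
     */
    static Map<String, String> toMetricStrings(Map<String, Double> raw, int decimals) {
        Map<String, String> out = new LinkedHashMap<>();
        for (Map.Entry<String, Double> entry : raw.entrySet()) {
            double value = entry.getValue();
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                throw new IllegalArgumentException("Metric is not finite: " + entry.getKey());
            }
            out.put(entry.getKey(), String.format(Locale.ROOT, "%." + decimals + "f", value));
        }
        return out;
    }
}
```

Each entry of the returned map could then be passed to `model.addTrainingMetric(name, value)`.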
Hyperparameters
// Add individual hyperparameters
model.addHyperParam("learning_rate", "0.001")
.addHyperParam("batch_size", "64")
.addHyperParam("epochs", "100");
// Set all hyperparameters at once
MLHyperParam param1 = new MLHyperParam();
param1.setName("dropout_rate");
param1.setValue("0.3");
MLHyperParam param2 = new MLHyperParam();
param2.setName("optimizer");
param2.setValue("adam");
model.setHyperParams(List.of(param1, param2));
// Get hyperparameters
List<MLHyperParam> params = model.getHyperParams();
Model Groups
// Set model group (creates relationship)
model.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:tensorflow,churn_models,PROD)");
// Get model group
String group = model.getModelGroup();
Training Jobs (Lineage)
// Add training jobs
model.addTrainingJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,ml_training_dag,prod),train_model)")
.addTrainingJob("urn:li:dataProcessInstance:(urn:li:dataFlow:(airflow,ml_training_dag,prod),2025-10-15T08:00:00Z)");
// Remove training job
model.removeTrainingJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,ml_training_dag,prod),train_model)");
// Get training jobs
List<String> jobs = model.getTrainingJobs();
Downstream Jobs (Lineage)
// Add downstream jobs
model.addDownstreamJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,scoring_dag,prod),score_customers)")
.addDownstreamJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,inference_dag,prod),predict)");
// Remove downstream job
model.removeDownstreamJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,scoring_dag,prod),score_customers)");
// Get downstream jobs
List<String> jobs = model.getDownstreamJobs();
Deployments
// Add deployments
model.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,model-staging)")
.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,model-production)");
// Remove deployment
model.removeDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,model-staging)");
// Get deployments
List<String> deployments = model.getDeployments();
Standard Property Operations
Display Name and Description
// Set display name
model.setDisplayName("Customer Lifetime Value Model");
// Set description
model.setDescription("Deep learning model predicting CLV based on purchase history");
// Set external URL
model.setExternalUrl("https://mlflow.company.com/experiments/42/runs/abc123");
// Get properties
String name = model.getDisplayName();
String desc = model.getDescription();
String url = model.getExternalUrl();
Custom Properties
// Add individual properties
model.addCustomProperty("framework", "TensorFlow 2.14")
.addCustomProperty("model_version", "2.1.0")
.addCustomProperty("training_date", "2025-10-15");
// Set all properties at once
Map<String, String> props = new HashMap<>();
props.put("deployment_date", "2025-10-20");
props.put("inference_latency_ms", "15");
model.setCustomProperties(props);
// Get properties
Map<String, String> customProps = model.getCustomProperties();
Standard Metadata Operations
Tags
// Add tags (with or without urn:li:tag: prefix)
model.addTag("production")
.addTag("urn:li:tag:ml-model")
.addTag("deep-learning");
// Remove tag
model.removeTag("production");
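The comment above notes that tags are accepted with or without the `urn:li:tag:` prefix. How the SDK normalizes them internally is not shown here. The sketch below illustrates one plausible normalization under that assumption, using a hypothetical `TagUrnSketch` class that is not part of the SDK.

```java
/** Hypothetical helper (not a DataHub SDK class): normalizes a tag name to a tag URN. */
final class TagUrnSketch {
    private static final String PREFIX = "urn:li:tag:";

    private TagUrnSketch() {}

    /** Returns the input unchanged if it already carries the prefix, else prepends it. */
    static String toTagUrn(String tag) {
        if (tag == null || tag.trim().isEmpty()) {
            throw new IllegalArgumentException("Tag must not be empty");
        }
        String trimmed = tag.trim();
        return trimmed.startsWith(PREFIX) ? trimmed : PREFIX + trimmed;
    }
}
```

Under this assumption, `"production"` and `"urn:li:tag:production"` refer to the same tag, so adding one form and removing the other would affect the same entry.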
Owners
// Add owners with different types
model.addOwner("urn:li:corpuser:ml_platform_team", OwnershipType.TECHNICAL_OWNER)
.addOwner("urn:li:corpuser:data_science_team", OwnershipType.DATA_STEWARD);
// Remove owner
model.removeOwner("urn:li:corpuser:ml_platform_team");
Glossary Terms
// Add glossary terms
model.addTerm("urn:li:glossaryTerm:MachineLearning.Model")
.addTerm("urn:li:glossaryTerm:CustomerAnalytics.Prediction");
// Remove term
model.removeTerm("urn:li:glossaryTerm:MachineLearning.Model");
Domain
// Set domain
model.setDomain("urn:li:domain:MachineLearning");
// Remove specific domain
model.removeDomain("urn:li:domain:MachineLearning");
// Or clear all domains
model.clearDomains();
Common Patterns
Complete ML Model Workflow
// 1. Create model with basic metadata
MLModel model = MLModel.builder()
.platform("tensorflow")
.name("customer_ltv_predictor")
.env("PROD")
.displayName("Customer Lifetime Value Prediction Model")
.description("Deep learning model predicting customer lifetime value")
.externalUrl("https://mlflow.company.com/experiments/42")
.build();
// 2. Add comprehensive training metrics
model.addTrainingMetric("accuracy", "0.947")
.addTrainingMetric("precision", "0.934")
.addTrainingMetric("recall", "0.921")
.addTrainingMetric("f1_score", "0.927")
.addTrainingMetric("auc_roc", "0.965")
.addTrainingMetric("training_time_minutes", "142.5");
// 3. Add comprehensive hyperparameters
model.addHyperParam("learning_rate", "0.001")
.addHyperParam("batch_size", "64")
.addHyperParam("epochs", "100")
.addHyperParam("optimizer", "adam")
.addHyperParam("dropout_rate", "0.3")
.addHyperParam("hidden_layers", "3");
// 4. Set model group for version tracking
model.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:tensorflow,ltv_models,PROD)");
// 5. Add training lineage
model.addTrainingJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,ml_training_dag,prod),train_ltv)")
.addTrainingJob("urn:li:dataProcessInstance:(urn:li:dataFlow:(airflow,ml_training_dag,prod),2025-10-15T08:00:00Z)");
// 6. Add downstream lineage
model.addDownstreamJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,customer_scoring,prod),score)")
.addDownstreamJob("urn:li:dataJob:(urn:li:dataFlow:(airflow,campaign_targeting,prod),target)");
// 7. Add deployment information
model.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,ltv-staging)")
.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,ltv-production)");
// 8. Add organizational metadata
model.addTag("production")
.addTag("deep-learning")
.addTag("business-critical")
.addOwner("urn:li:corpuser:ml_platform", OwnershipType.TECHNICAL_OWNER)
.addOwner("urn:li:corpuser:data_science", OwnershipType.DATA_STEWARD)
.addTerm("urn:li:glossaryTerm:MachineLearning.Model")
.setDomain("urn:li:domain:MachineLearning");
// 9. Add custom properties
model.addCustomProperty("framework", "TensorFlow 2.14")
.addCustomProperty("model_version", "2.1.0")
.addCustomProperty("training_date", "2025-10-15")
.addCustomProperty("deployment_date", "2025-10-20")
.addCustomProperty("inference_latency_ms", "15");
// 10. Save to DataHub
client.entities().upsert(model);
Model Training to Deployment Flow
// Step 1: Create model after training
MLModel model = MLModel.builder()
.platform("pytorch")
.name("fraud_detector_v2")
.env("DEV")
.build();
// Step 2: Add training results
model.addTrainingMetric("accuracy", "0.97")
.addTrainingMetric("precision", "0.95")
.addHyperParam("learning_rate", "0.001")
.addHyperParam("batch_size", "128")
.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:pytorch,fraud_models,DEV)");
client.entities().upsert(model);
// Step 3: Promote to staging
MLModel stagingModel = MLModel.builder()
.platform("pytorch")
.name("fraud_detector_v2")
.env("STAGING")
.build();
stagingModel.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:pytorch,fraud_models,STAGING)")
.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,fraud-staging)");
client.entities().upsert(stagingModel);
// Step 4: Deploy to production
MLModel prodModel = MLModel.builder()
.platform("pytorch")
.name("fraud_detector_v2")
.env("PROD")
.build();
prodModel.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:pytorch,fraud_models,PROD)")
.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,fraud-production)")
.addTag("production")
.addOwner("urn:li:corpuser:fraud_ml_team", OwnershipType.TECHNICAL_OWNER)
.setDomain("urn:li:domain:FraudPrevention");
client.entities().upsert(prodModel);
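In the flow above, each promotion step builds a model with the same platform and name but a different environment. Because the environment is part of the URN, each step addresses a separate entity rather than updating the previous one. The following sketch makes that visible with plain string handling. `PromotionUrnSketch` is a hypothetical class, not an SDK API.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper (not a DataHub SDK class): lists the MLModel URN for each environment. */
final class PromotionUrnSketch {
    private PromotionUrnSketch() {}

    /** Maps each environment to the URN the same platform/name pair would get there. */
    static Map<String, String> urnsByEnvironment(
            String platform, String modelName, List<String> environments) {
        Map<String, String> urns = new LinkedHashMap<>();
        for (String env : environments) {
            urns.put(env, String.format(
                "urn:li:mlModel:(urn:li:dataPlatform:%s,%s,%s)", platform, modelName, env));
        }
        return urns;
    }
}
```

For `fraud_detector_v2` on `pytorch`, this yields three distinct URNs for DEV, STAGING, and PROD. Metadata written to the DEV entity, such as training metrics, would therefore need to be added again on the STAGING and PROD entities if it should appear there.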
A/B Testing Scenario
// Model A (current champion)
MLModel modelA = MLModel.builder()
.platform("tensorflow")
.name("recommendation_model_a")
.env("PROD")
.displayName("Recommendation Model A (Champion)")
.build();
modelA.addTrainingMetric("accuracy", "0.92")
.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:tensorflow,recommendation_models,PROD)")
.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,recommend-prod-80pct)")
.addCustomProperty("traffic_percentage", "80");
// Model B (challenger)
MLModel modelB = MLModel.builder()
.platform("tensorflow")
.name("recommendation_model_b")
.env("PROD")
.displayName("Recommendation Model B (Challenger)")
.build();
modelB.addTrainingMetric("accuracy", "0.94")
.setModelGroup("urn:li:mlModelGroup:(urn:li:dataPlatform:tensorflow,recommendation_models,PROD)")
.addDeployment("urn:li:mlModelDeployment:(urn:li:dataPlatform:sagemaker,recommend-prod-20pct)")
.addCustomProperty("traffic_percentage", "20")
.addCustomProperty("experiment_id", "ab_test_2025_10");
client.entities().upsert(modelA);
client.entities().upsert(modelB);
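The scenario above records each model's traffic share as a string in the `traffic_percentage` custom property. A simple consistency check before upserting can catch splits that do not add up. This is a sketch under that assumption only: `TrafficSplitSketch` is a hypothetical class, and DataHub itself does not interpret or validate this property.

```java
import java.util.Map;

/** Hypothetical helper (not a DataHub SDK class): checks an A/B traffic split. */
final class TrafficSplitSketch {
    private TrafficSplitSketch() {}

    /** Sums the integer percentages stored as strings, keyed by model name. */
    static int totalTrafficPercent(Map<String, String> trafficByModel) {
        int total = 0;
        for (Map.Entry<String, String> entry : trafficByModel.entrySet()) {
            int percent = Integer.parseInt(entry.getValue().trim());
            if (percent < 0 || percent > 100) {
                throw new IllegalArgumentException(
                    "Percentage out of range for " + entry.getKey() + ": " + percent);
            }
            total += percent;
        }
        return total;
    }

    /** True when the shares cover exactly 100 percent of traffic. */
    static boolean isCompleteSplit(Map<String, String> trafficByModel) {
        return totalTrafficPercent(trafficByModel) == 100;
    }
}
```

With the values from the example, `recommendation_model_a` at 80 and `recommendation_model_b` at 20 form a complete split.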
Best Practices
- Use descriptive names: Model names should clearly indicate purpose (e.g., user_churn_predictor_v2, fraud_detection_xgboost)
- Track comprehensive metrics: Include both training and validation metrics for transparency
- Document hyperparameters: Record all hyperparameters used for reproducibility
- Maintain lineage: Always link training jobs and downstream consumers
- Use model groups: Group related models together for easier versioning
- Tag appropriately: Use tags like production, experimental, deprecated
- Set ownership: Assign technical owners (ML engineers) and data stewards
- Add deployment info: Track where models are deployed for operational monitoring
- Use custom properties: Store framework versions, training dates, performance benchmarks
- Link to external systems: Use externalUrl to link to MLflow, SageMaker, or other ML platforms
See Also
- Dataset Entity - For training data lineage
- DataJob Entity - For training job metadata
- SDK V2 Overview - General SDK concepts