# Mirror of https://github.com/datahub-project/datahub.git
# Synced 2025-07-04 15:50:14 +00:00
# Python example script, ~401 lines / 13 KiB
import argparse
|
||
|
from datetime import datetime
|
||
|
|
||
|
from dh_ai_client import DatahubAIClient
|
||
|
|
||
|
from datahub.emitter.mcp_builder import (
|
||
|
ContainerKey,
|
||
|
)
|
||
|
from datahub.ingestion.source.common.subtypes import MLAssetSubTypes
|
||
|
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
|
||
|
from datahub.metadata.schema_classes import (
|
||
|
AuditStampClass,
|
||
|
DataProcessInstancePropertiesClass,
|
||
|
MLHyperParamClass,
|
||
|
MLMetricClass,
|
||
|
MLTrainingRunPropertiesClass,
|
||
|
)
|
||
|
from datahub.metadata.urns import (
|
||
|
CorpUserUrn,
|
||
|
DataProcessInstanceUrn,
|
||
|
GlossaryTermUrn,
|
||
|
TagUrn,
|
||
|
)
|
||
|
from datahub.sdk.container import Container
|
||
|
from datahub.sdk.dataset import Dataset
|
||
|
from datahub.sdk.mlmodel import MLModel
|
||
|
from datahub.sdk.mlmodelgroup import MLModelGroup
|
||
|
|
||
|
# ---------------------------------------------------------------------------
# Script configuration: CLI arguments, DataHub client, and entity identifiers.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--token", required=False, help="DataHub access token")
parser.add_argument(
    "--server_url",
    required=False,
    default="http://localhost:8080",
    help="DataHub server URL (defaults to http://localhost:8080)",
)
args = parser.parse_args()

# Initialize client
client = DatahubAIClient(token=args.token, server_url=args.server_url)

# A shared prefix keeps every ID unique so repeated runs don't conflict.
prefix = "test"

# --- Basic entities: IDs paired with their display names -------------------
basic_model_group_id = f"{prefix}_basic_group"
basic_model_group_name = f"{prefix} Basic Group"

basic_model_id = f"{prefix}_basic_model"
basic_model_name = f"{prefix} Basic Model"

basic_experiment_id = f"{prefix}_basic_experiment"
basic_experiment_name = f"{prefix} Basic Experiment"

basic_run_id = f"{prefix}_basic_run"
basic_run_name = f"{prefix} Basic Run"

basic_dataset_id = f"{prefix}_basic_dataset"
basic_dataset_name = f"{prefix} Basic Dataset"

# --- Advanced entities: IDs paired with their display names ----------------
advanced_model_group_id = f"{prefix}_airline_forecast_models_group"
advanced_model_group_name = f"{prefix} Airline Forecast Models Group"

advanced_model_id = f"{prefix}_arima_model"
advanced_model_name = f"{prefix} ARIMA Model"

advanced_experiment_id = f"{prefix}_airline_forecast_experiment"
advanced_experiment_name = f"{prefix} Airline Forecast Experiment"

advanced_run_id = f"{prefix}_simple_training_run"
advanced_run_name = f"{prefix} Simple Training Run"

advanced_input_dataset_id = f"{prefix}_iris_input"
advanced_input_dataset_name = f"{prefix} Iris Training Input Data"

advanced_output_dataset_id = f"{prefix}_iris_output"
advanced_output_dataset_name = f"{prefix} Iris Model Output Data"
|
||
|
def create_basic_model_group():
    """Create and emit a minimal MLModelGroup on the mlflow platform."""
    print("Creating basic model group...")
    group = MLModelGroup(
        id=basic_model_group_id,
        platform="mlflow",
        name=basic_model_group_name,
    )
    # Persist the entity to DataHub, then hand it back for relationship wiring.
    client._emit_mcps(group.as_mcps())
    return group
|
||
|
|
||
|
|
||
|
def create_advanced_model_group():
    """Create and emit a fully-described MLModelGroup (owners, tags, terms)."""
    print("Creating advanced model group...")
    group = MLModelGroup(
        id=advanced_model_group_id,
        platform="mlflow",
        name=advanced_model_group_name,
        description="Group of models for airline passenger forecasting",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        external_url="https://www.linkedin.com/in/datahub",
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
    )
    client._emit_mcps(group.as_mcps())
    return group
|
||
|
|
||
|
|
||
|
def create_basic_model():
    """Create and emit a minimal MLModel on the mlflow platform."""
    print("Creating basic model...")
    model = MLModel(
        id=basic_model_id,
        platform="mlflow",
        name=basic_model_name,
    )
    # Emit to DataHub and return the model so callers can attach relationships.
    client._emit_mcps(model.as_mcps())
    return model
|
||
|
|
||
|
|
||
|
def create_advanced_model():
    """Create and emit an MLModel populated with metadata, metrics and params."""
    print("Creating advanced model...")
    model = MLModel(
        id=advanced_model_id,
        platform="mlflow",
        name=advanced_model_name,
        description="ARIMA model for airline passenger forecasting",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        external_url="https://www.linkedin.com/in/datahub",
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
        # Versioning metadata: version "1" with a "champion" alias.
        version="1",
        aliases=["champion"],
        hyper_params={"learning_rate": "0.01"},
        training_metrics={"accuracy": "0.9"},
    )
    client._emit_mcps(model.as_mcps())
    return model
|
||
|
|
||
|
|
||
|
def create_basic_experiment():
    """Create and emit a minimal experiment Container on mlflow."""
    print("Creating basic experiment...")
    key = ContainerKey(platform="mlflow", name=basic_experiment_id)
    experiment = Container(
        container_key=key,
        display_name=basic_experiment_name,
    )
    client._emit_mcps(experiment.as_mcps())
    return experiment
|
||
|
|
||
|
|
||
|
def create_advanced_experiment():
    """Create and emit an experiment Container with full metadata."""
    print("Creating advanced experiment...")
    key = ContainerKey(platform="mlflow", name=advanced_experiment_id)
    experiment = Container(
        container_key=key,
        display_name=advanced_experiment_name,
        description="Experiment to forecast airline passenger numbers",
        extra_properties={"team": "forecasting"},
        created=datetime(2025, 4, 9, 22, 30),
        last_modified=datetime(2025, 4, 9, 22, 30),
        # Mark this container as an MLflow experiment rather than a generic one.
        subtype=MLAssetSubTypes.MLFLOW_EXPERIMENT,
    )
    client._emit_mcps(experiment.as_mcps())
    return experiment
|
||
|
|
||
|
|
||
|
def create_basic_training_run():
    """Create a minimal training run via the client and return its URN."""
    print("Creating basic training run...")
    return client.create_training_run(
        run_id=basic_run_id,
        run_name=basic_run_name,
    )
|
||
|
|
||
|
|
||
|
def create_advanced_training_run():
    """Create a fully-specified training run and return its URN.

    Populates process-instance properties (name, audit stamp, custom
    properties), training-run properties (metrics, hyper-parameters, output
    and external URLs), an explicit FAILURE result, and start/end timestamps.

    Returns:
        The URN returned by ``client.create_training_run``.
    """
    print("Creating advanced training run...")
    advanced_run_urn = client.create_training_run(
        run_id=advanced_run_id,
        properties=DataProcessInstancePropertiesClass(
            name=advanced_run_name,
            created=AuditStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            customProperties={"team": "forecasting"},
        ),
        training_run_properties=MLTrainingRunPropertiesClass(
            id=advanced_run_id,
            outputUrls=["s3://my-bucket/output"],
            trainingMetrics=[MLMetricClass(name="accuracy", value="0.9")],
            hyperParams=[MLHyperParamClass(name="learning_rate", value="0.01")],
            # Fixed: was "https:localhost:5000", which is missing the "//"
            # authority separator and therefore not a valid absolute URL.
            externalUrl="https://localhost:5000",
        ),
        run_result=RunResultType.FAILURE,
        start_timestamp=1628580000000,
        end_timestamp=1628580001000,
    )
    return advanced_run_urn
|
||
|
|
||
|
|
||
|
def create_basic_dataset():
    """Create and emit a minimal snowflake Dataset."""
    print("Creating basic dataset...")
    dataset = Dataset(
        platform="snowflake",
        name=basic_dataset_id,
        display_name=basic_dataset_name,
    )
    client._emit_mcps(dataset.as_mcps())
    return dataset
|
||
|
|
||
|
|
||
|
def create_advanced_datasets():
    """Create and emit the input and output Iris datasets.

    Returns a ``(input_dataset, output_dataset)`` tuple.
    """
    print("Creating advanced datasets...")

    # Raw training input.
    input_dataset = Dataset(
        platform="snowflake",
        name=advanced_input_dataset_id,
        description="Raw Iris dataset used for training ML models",
        schema=[("id", "number"), ("name", "string"), ("species", "string")],
        display_name=advanced_input_dataset_name,
        tags=["urn:li:tag:ml_data", "urn:li:tag:iris"],
        terms=["urn:li:glossaryTerm:raw_data"],
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        custom_properties={
            "data_source": "UCI Repository",
            "records": "150",
            "features": "4",
        },
    )
    client._emit_mcps(input_dataset.as_mcps())

    # Model predictions output.
    output_dataset = Dataset(
        platform="snowflake",
        name=advanced_output_dataset_id,
        description="Processed Iris dataset with model predictions",
        schema=[("id", "number"), ("name", "string"), ("species", "string")],
        display_name=advanced_output_dataset_name,
        tags=["urn:li:tag:ml_data", "urn:li:tag:predictions"],
        terms=["urn:li:glossaryTerm:model_output"],
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        custom_properties={
            "model_version": "1.0",
            "records": "150",
            "accuracy": "0.95",
        },
    )
    client._emit_mcps(output_dataset.as_mcps())

    return input_dataset, output_dataset
|
||
|
|
||
|
|
||
|
# Split relationship functions into individual top-level functions
|
||
|
def add_model_to_model_group(model, model_group):
    """Attach ``model`` to ``model_group`` and re-emit the model."""
    print("Adding model to model group...")
    model.set_model_group(model_group.urn)
    # Re-emit so the new group membership is persisted.
    client._emit_mcps(model.as_mcps())
|
||
|
|
||
|
|
||
|
def add_run_to_experiment(run_urn, experiment):
    """Link a training run to its parent experiment container."""
    print("Adding run to experiment...")
    experiment_urn = str(experiment.urn)
    client.add_run_to_experiment(run_urn=run_urn, experiment_urn=experiment_urn)
|
||
|
|
||
|
|
||
|
def add_run_to_model(model, run_id):
    """Register the training run identified by ``run_id`` on ``model``."""
    print("Adding run to model...")
    run_urn = DataProcessInstanceUrn(run_id)
    model.add_training_job(run_urn)
    client._emit_mcps(model.as_mcps())
|
||
|
|
||
|
|
||
|
def add_run_to_model_group(model_group, run_id):
    """Register the training run identified by ``run_id`` on ``model_group``."""
    print("Adding run to model group...")
    run_urn = DataProcessInstanceUrn(run_id)
    model_group.add_training_job(run_urn)
    client._emit_mcps(model_group.as_mcps())
|
||
|
|
||
|
|
||
|
def add_input_dataset_to_run(run_urn, input_dataset):
    """Record ``input_dataset`` as an input of the given run."""
    print("Adding input dataset to run...")
    dataset_urns = [str(input_dataset.urn)]
    client.add_input_datasets_to_run(run_urn=run_urn, dataset_urns=dataset_urns)
|
||
|
|
||
|
|
||
|
def add_output_dataset_to_run(run_urn, output_dataset):
    """Record ``output_dataset`` as an output of the given run."""
    print("Adding output dataset to run...")
    dataset_urns = [str(output_dataset.urn)]
    client.add_output_datasets_to_run(run_urn=run_urn, dataset_urns=dataset_urns)
|
||
|
|
||
|
|
||
|
def update_model_properties(model):
    """Mutate an existing MLModel in place, then re-emit it.

    Bumps the version, tags/terms it with "marketing", and adds the
    "challenger" version alias.
    """
    print("Updating model properties...")
    model.set_version("2")
    model.add_tag(TagUrn("marketing"))
    model.add_term(GlossaryTermUrn("marketing"))
    model.add_version_alias("challenger")
    # Persist all of the above changes in one emission.
    client._emit_mcps(model.as_mcps())
|
||
|
|
||
|
|
||
|
def update_model_group_properties(model_group):
    """Mutate an existing MLModelGroup in place, then re-emit it."""
    print("Updating model group properties...")
    model_group.set_description("Updated description for airline forecast models")
    model_group.add_tag(TagUrn("production"))
    model_group.add_term(GlossaryTermUrn("time-series"))
    new_properties = {
        "team": "forecasting",
        "business_unit": "operations",
        "status": "active",
    }
    model_group.set_custom_properties(new_properties)
    # Persist all of the above changes in one emission.
    client._emit_mcps(model_group.as_mcps())
|
||
|
|
||
|
|
||
|
def update_experiment_properties():
    """Rebuild the advanced experiment container and emit updated metadata."""
    print("Updating experiment properties...")

    # Re-create a Container object pointing at the already-ingested experiment.
    experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=advanced_experiment_id),
        display_name=advanced_experiment_name,
    )

    experiment.set_description(
        "Updated experiment for forecasting passenger numbers"
    )
    experiment.add_tag(TagUrn("time-series"))
    experiment.add_term(GlossaryTermUrn("forecasting"))
    experiment.set_custom_properties(
        {"team": "forecasting", "priority": "high", "status": "active"}
    )

    # Persist the updated aspects.
    client._emit_mcps(experiment.as_mcps())
|
||
|
|
||
|
|
||
|
def main():
    """Drive the demo: create entities, wire relationships, update properties.

    Individual steps can be commented out independently.
    """
    print("Creating AI assets...")

    # Minimal entities (created for their side effects only).
    create_basic_model_group()
    create_basic_model()
    create_basic_experiment()
    create_basic_training_run()
    create_basic_dataset()

    # Fully-described entities, kept for relationship wiring below.
    group = create_advanced_model_group()
    model = create_advanced_model()
    experiment = create_advanced_experiment()
    run_urn = create_advanced_training_run()
    input_dataset, output_dataset = create_advanced_datasets()

    # Relationships — each call is independent of the others.
    add_model_to_model_group(model, group)
    add_run_to_experiment(run_urn, experiment)
    add_run_to_model(model, advanced_run_id)
    add_run_to_model_group(group, advanced_run_id)
    add_input_dataset_to_run(run_urn, input_dataset)
    add_output_dataset_to_run(run_urn, output_dataset)

    # Property updates — each call is independent of the others.
    update_model_properties(model)
    update_model_group_properties(group)
    update_experiment_properties()

    print("All done! AI entities created successfully.")


if __name__ == "__main__":
    main()
|