# datahub/metadata-ingestion/examples/ai/dh_sdk_client_sample.py
#
# Repository file view: 106 lines, 3.2 KiB, Python.

from datetime import datetime
from datahub.emitter.mcp_builder import (
ContainerKey,
)
from datahub.ingestion.source.common.subtypes import MLAssetSubTypes
from datahub.metadata.urns import (
CorpUserUrn,
GlossaryTermUrn,
TagUrn,
)
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.mlmodel import MLModel
from datahub.sdk.mlmodelgroup import MLModelGroup
# Identifiers (used as entity ids) paired with their human-readable display
# names for the sample run, experiment, model and model group.
run_id, run_name = "simple_training_run", "Simple Training Run"
experiment_id, experiment_name = (
    "airline_forecast_experiment",
    "Airline Forecast Experiment",
)
model_id, model_name = "arima_model", "ARIMA Model"
model_group_id, model_group_name = (
    "airline_forecast_models_group",
    "Airline Forecast Models Group",
)
if __name__ == "__main__":
    # Sample flow: build an MLflow model group, a model, an experiment
    # container and two Snowflake datasets, then upsert them through the
    # DataHub SDK client.
    # NOTE(review): from_env() presumably picks up the server connection
    # settings from the environment -- confirm before running.
    client = DataHubClient.from_env()
    upsert = client.entities.upsert

    # Model group that the model is attached to further down.
    model_group = MLModelGroup(
        platform="mlflow",
        id=model_group_id,
        name=model_group_name,
        description="Group of models for airline passenger forecasting",
        external_url="https://www.linkedin.com/in/datahub",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
    )

    # The model itself. Its group is deliberately not passed to the
    # constructor; it is attached afterwards with set_model_group().
    model = MLModel(
        platform="mlflow",
        id=model_id,
        name=model_name,
        description="ARIMA model for airline passenger forecasting",
        external_url="https://www.linkedin.com/in/datahub",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
        version="1",
        aliases=["champion"],
        hyper_params={"learning_rate": "0.01"},
        training_metrics={"accuracy": "0.9"},
    )

    # Experiment, modelled as a container with the MLflow experiment subtype.
    # NOTE(review): confirm that ContainerKey accepts a `name` field; if it
    # were ignored, the key would not depend on experiment_id.
    experiment_key = ContainerKey(
        platform="mlflow",
        name=experiment_id,
    )
    # Fixed timestamp used for both the created and last-modified fields.
    experiment_timestamp = datetime(2025, 4, 9, 22, 30)
    experiment = Container(
        container_key=experiment_key,
        display_name=experiment_name,
        description="Experiment to forecast airline passenger numbers",
        extra_properties={"team": "forecasting"},
        created=experiment_timestamp,
        last_modified=experiment_timestamp,
        subtype=MLAssetSubTypes.MLFLOW_EXPERIMENT,
    )
    upsert(experiment)

    # Input and output datasets; each is upserted right after it is built.
    input_dataset = Dataset(name="iris_input", platform="snowflake")
    upsert(input_dataset)

    output_dataset = Dataset(name="iris_output", platform="snowflake")
    upsert(output_dataset)

    # Post-construction edits on the model: attach the group, add an alias,
    # a glossary term and a tag, and replace the version given above.
    model.set_model_group(model_group.urn)
    model.add_version_alias("challenger")
    model.add_term(GlossaryTermUrn("marketing"))
    model.add_tag(TagUrn("marketing"))
    model.set_version("2")

    # The model is written first, then the model group.
    upsert(model)
    upsert(model_group)