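# datahub/metadata-ingestion/examples/ai/dh_ai_client_sample.py
#
# Example script showing how to use DatahubAIClient to register ML metadata
# (model group, model, experiment, training run and datasets) in DataHub.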

import argparse
from dh_ai_client import DatahubAIClient
import datahub.metadata.schema_classes as models
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
# Fixed audit metadata shared by every entity this sample creates, so the
# example produces the same values on every run.
SAMPLE_TIMESTAMP_MS = 1628580000000  # epoch milliseconds
SAMPLE_ACTOR = "urn:li:corpuser:datahub"
# Placeholder link attached to both the model and the training run.
SAMPLE_EXTERNAL_URL = "http://localhost:5000"


def parse_args(argv=None):
    """Parse the sample's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``token`` (``None`` when omitted) and
        ``server_url`` (defaults to ``http://localhost:8080``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--token", required=False, help="DataHub access token")
    parser.add_argument(
        "--server_url",
        required=False,
        default="http://localhost:8080",
        help="DataHub server URL (defaults to http://localhost:8080)",
    )
    return parser.parse_args(argv)


def _time_stamp():
    """Return a fresh TimeStampClass for the fixed sample time and actor."""
    return models.TimeStampClass(time=SAMPLE_TIMESTAMP_MS, actor=SAMPLE_ACTOR)


def _audit_stamp():
    """Return a fresh AuditStampClass for the fixed sample time and actor."""
    return models.AuditStampClass(time=SAMPLE_TIMESTAMP_MS, actor=SAMPLE_ACTOR)


def main(argv=None):
    """Register a sample set of ML entities in DataHub and link them together.

    Creates a model group, a model, an experiment, a training run and two
    Snowflake datasets through ``DatahubAIClient``, then records the
    relationships between them.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    client = DatahubAIClient(token=args.token, server_url=args.server_url)

    # Create model group
    model_group_urn = client.create_model_group(
        group_id="airline_forecast_models_group",
        properties=models.MLModelGroupPropertiesClass(
            name="Airline Forecast Models Group",
            description="Group of models for airline passenger forecasting",
            created=_time_stamp(),
        ),
    )

    # Create a model with property classes
    model_urn = client.create_model(
        model_id="arima_model",
        properties=models.MLModelPropertiesClass(
            name="ARIMA Model",
            description="ARIMA model for airline passenger forecasting",
            customProperties={"team": "forecasting"},
            trainingMetrics=[
                models.MLMetricClass(name="accuracy", value="0.9"),
                models.MLMetricClass(name="precision", value="0.8"),
            ],
            hyperParams=[
                models.MLHyperParamClass(name="learning_rate", value="0.01"),
                models.MLHyperParamClass(name="batch_size", value="32"),
            ],
            externalUrl=SAMPLE_EXTERNAL_URL,
            created=_time_stamp(),
            lastModified=_time_stamp(),
            tags=["forecasting", "arima"],
        ),
        version="1.0",
        alias="champion",
    )

    # Create an experiment with a property class
    experiment_urn = client.create_experiment(
        experiment_id="airline_forecast_experiment",
        properties=models.ContainerPropertiesClass(
            name="Airline Forecast Experiment",
            description="Experiment to forecast airline passenger numbers",
            customProperties={"team": "forecasting"},
            created=_time_stamp(),
            lastModified=_time_stamp(),
        ),
    )

    # Create a training run. Note that DataProcessInstanceProperties takes an
    # AuditStamp for `created`, unlike the TimeStamp used above.
    run_urn = client.create_training_run(
        run_id="simple_training_run",
        properties=models.DataProcessInstancePropertiesClass(
            name="Simple Training Run",
            created=_audit_stamp(),
            customProperties={"team": "forecasting"},
        ),
        training_run_properties=models.MLTrainingRunPropertiesClass(
            id="simple_training_run",
            outputUrls=["s3://my-bucket/output"],
            trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
            hyperParams=[
                models.MLHyperParamClass(name="learning_rate", value="0.01")
            ],
            externalUrl=SAMPLE_EXTERNAL_URL,
        ),
        # The sample records this run as failed.
        run_result=RunResultType.FAILURE,
        start_timestamp=SAMPLE_TIMESTAMP_MS,
        end_timestamp=SAMPLE_TIMESTAMP_MS + 1000,  # one second after start
    )

    # Create input/output datasets
    input_dataset_urn = client.create_dataset(
        platform="snowflake",
        name="iris_input",
    )
    output_dataset_urn = client.create_dataset(
        platform="snowflake",
        name="iris_output",
    )

    # Add run to experiment
    client.add_run_to_experiment(run_urn=run_urn, experiment_urn=experiment_urn)
    # Add model to model group
    client.add_model_to_model_group(model_urn=model_urn, group_urn=model_group_urn)
    # Add run to model
    client.add_run_to_model(
        model_urn=model_urn,
        run_urn=run_urn,
    )
    # Add run to model group
    client.add_run_to_model_group(
        model_group_urn=model_group_urn,
        run_urn=run_urn,
    )
    # Add input and output datasets to run
    client.add_input_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(input_dataset_urn)]
    )
    client.add_output_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(output_dataset_urn)]
    )


if __name__ == "__main__":
    main()