mirror of
https://github.com/datahub-project/datahub.git
synced 2025-08-17 13:45:54 +00:00
135 lines
4.6 KiB
Python
135 lines
4.6 KiB
Python
import argparse
|
|
|
|
from dh_ai_client import DatahubAIClient
|
|
|
|
import datahub.metadata.schema_classes as models
|
|
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
|
|
|
|
if __name__ == "__main__":
    # Example usage: walks through creating a model group, a model, an
    # experiment, a training run and two datasets with DatahubAIClient, then
    # links them together.
    parser = argparse.ArgumentParser()
    parser.add_argument("--token", required=False, help="DataHub access token")
    parser.add_argument(
        "--server_url",
        required=False,
        default="http://localhost:8080",
        help="DataHub server URL (defaults to http://localhost:8080)",
    )
    args = parser.parse_args()

    client = DatahubAIClient(token=args.token, server_url=args.server_url)

    # Create model group
    model_group_urn = client.create_model_group(
        group_id="airline_forecast_models_group",
        properties=models.MLModelGroupPropertiesClass(
            name="Airline Forecast Models Group",
            description="Group of models for airline passenger forecasting",
            # NOTE(review): timestamps in this sample look like epoch
            # milliseconds -- confirm against the DataHub schema.
            created=models.TimeStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
        ),
    )

    # Creating a model with property classes
    model_urn = client.create_model(
        model_id="arima_model",
        properties=models.MLModelPropertiesClass(
            name="ARIMA Model",
            description="ARIMA model for airline passenger forecasting",
            customProperties={"team": "forecasting"},
            trainingMetrics=[
                models.MLMetricClass(name="accuracy", value="0.9"),
                models.MLMetricClass(name="precision", value="0.8"),
            ],
            hyperParams=[
                models.MLHyperParamClass(name="learning_rate", value="0.01"),
                models.MLHyperParamClass(name="batch_size", value="32"),
            ],
            # Well-formed URL: the scheme must be followed by "//".
            externalUrl="https://localhost:5000",
            created=models.TimeStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            lastModified=models.TimeStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            tags=["forecasting", "arima"],
        ),
        version="1.0",
        alias="champion",
    )

    # Creating an experiment with property class
    experiment_urn = client.create_experiment(
        experiment_id="airline_forecast_experiment",
        properties=models.ContainerPropertiesClass(
            name="Airline Forecast Experiment",
            description="Experiment to forecast airline passenger numbers",
            customProperties={"team": "forecasting"},
            created=models.TimeStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            lastModified=models.TimeStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
        ),
    )

    # Create a training run. This sample records the run with a FAILURE
    # result and an end timestamp 1000 units after the start timestamp.
    run_urn = client.create_training_run(
        run_id="simple_training_run",
        properties=models.DataProcessInstancePropertiesClass(
            name="Simple Training Run",
            # Unlike the entities above, this aspect is given an
            # AuditStampClass rather than a TimeStampClass.
            created=models.AuditStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            customProperties={"team": "forecasting"},
        ),
        training_run_properties=models.MLTrainingRunPropertiesClass(
            id="simple_training_run",
            outputUrls=["s3://my-bucket/output"],
            trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
            hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")],
            # Well-formed URL: the scheme must be followed by "//".
            externalUrl="https://localhost:5000",
        ),
        run_result=RunResultType.FAILURE,
        start_timestamp=1628580000000,
        end_timestamp=1628580001000,
    )

    # Create datasets
    input_dataset_urn = client.create_dataset(
        platform="snowflake",
        name="iris_input",
    )

    output_dataset_urn = client.create_dataset(
        platform="snowflake",
        name="iris_output",
    )

    # Add run to experiment
    client.add_run_to_experiment(run_urn=run_urn, experiment_urn=experiment_urn)

    # Add model to model group
    client.add_model_to_model_group(model_urn=model_urn, group_urn=model_group_urn)

    # Add run to model
    client.add_run_to_model(
        model_urn=model_urn,
        run_urn=run_urn,
    )

    # Add run to model group
    client.add_run_to_model_group(
        model_group_urn=model_group_urn,
        run_urn=run_urn,
    )

    # Add input and output datasets to run; the dataset URNs are passed as
    # strings.
    client.add_input_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(input_dataset_urn)]
    )

    client.add_output_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(output_dataset_urn)]
    )
|