autogen/test/pipeline_tuning_example/submit_train_pipeline.py
Jirka Borovec a701cd82f8
set black with 120 line length (#975)
* set black with 120 line length

* apply pre-commit

* apply black
2023-04-10 19:50:40 +00:00

126 lines
3.5 KiB
Python

from dataclasses import dataclass
from pathlib import Path
import azureml.core
from azureml.core import Workspace, Dataset, Run
from azure.ml.component import (
Component,
dsl,
)
import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import to_absolute_path
@dataclass
class AMLConfig:
    """Azure ML workspace connection settings (hydra structured config)."""

    subscription_id: str
    resource_group: str
    # NOTE(review): declared as `workspace`, but the call site in
    # build_and_submit_aml_pipeline reads `aml_config.workspace_name` —
    # field/usage mismatch; confirm which name the YAML configs use.
    workspace: str
@dataclass
class TrainConfig:
    """Experiment, data, and hyperparameter settings for the training run."""

    exp_name: str
    data_path: str
    test_train_ratio: float
    learning_rate: float
    n_estimators: int
@dataclass
class PipelineConfig:
    """Top-level schema composed by hydra: workspace + training settings."""

    aml_config: AMLConfig
    train_config: TrainConfig
# Directory containing this script; used to locate the local data folder
# and the component YAML definitions relative to the file, not the CWD.
LOCAL_DIR = Path(__file__).parent.absolute()
# Folder name in the default datastore that the local data is uploaded to.
TARGET_DATA_DIR = "classification_data"

# Register the structured-config schema with hydra so the YAML config
# (configs/train_config) is validated against PipelineConfig.
cs = ConfigStore.instance()
cs.store(name="config", node=PipelineConfig)
@hydra.main(config_path="configs", config_name="train_config")
def main(config: PipelineConfig):
    """CLI entry point: hydra composes configs/train_config and passes it in."""
    build_and_submit_aml_pipeline(config)
def build_and_submit_aml_pipeline(config):
    """Build the training pipeline and submit it to the Azure ML workspace.

    This function can be called from Python, while ``main`` is meant for CLI
    only (calling ``main`` from Python errors due to the ``hydra.main``
    decorator).

    Args:
        config: Either a composed ``PipelineConfig`` or a list of hydra
            override strings; a list is composed here from
            ``configs/train_config``.

    Returns:
        The submitted pipeline run.
    """
    # Allow callers to pass raw hydra overrides instead of a composed config.
    if isinstance(config, list):
        with hydra.initialize(config_path="configs"):
            config = hydra.compose(config_name="train_config", overrides=config)

    ################################################
    # connect to your Azure ML workspace
    ################################################
    # An _OfflineRun context means we are running locally: build the
    # Workspace from the config. Otherwise reuse the submitting run's
    # workspace.
    if isinstance(Run.get_context(), azureml.core.run._OfflineRun):
        ws = Workspace(
            subscription_id=config.aml_config.subscription_id,
            resource_group=config.aml_config.resource_group,
            # Fix: AMLConfig declares this field as `workspace`, not
            # `workspace_name` — the old access raised AttributeError
            # on every local (offline) run.
            workspace_name=config.aml_config.workspace,
        )
    else:
        ws = Run.get_context().experiment.workspace

    ################################################
    # load input datasets:
    ################################################
    datastore = ws.get_default_datastore()
    Dataset.File.upload_directory(
        src_dir=to_absolute_path(LOCAL_DIR / "data"),
        target=(datastore, TARGET_DATA_DIR),
        overwrite=True,
    )
    dataset = Dataset.File.from_files(path=(datastore, TARGET_DATA_DIR))

    ################################################
    # load component functions
    ################################################
    data_prep_component = Component.from_yaml(ws, yaml_file=LOCAL_DIR / "data_prep/data_prep.yaml")
    train_component = Component.from_yaml(ws, yaml_file=LOCAL_DIR / "train/train.yaml")

    ################################################
    # build pipeline
    ################################################
    # TODO: update the pipeline
    @dsl.pipeline(
        default_compute_target="cpucluster",
    )
    def train_pipeline():
        # Split the uploaded dataset, then train on the resulting splits.
        data_prep_job = data_prep_component(
            data=dataset,
            test_train_ratio=config.train_config.test_train_ratio,
        )
        train_component(
            train_data=data_prep_job.outputs.train_data,
            test_data=data_prep_job.outputs.test_data,
            learning_rate=config.train_config.learning_rate,
            n_estimators=config.train_config.n_estimators,
        )

    pipeline = train_pipeline()

    # Tag the run with its hyperparameters for easy lookup in the studio UI.
    tags = {
        "n_estimators": str(config.train_config.n_estimators),
        "learning_rate": str(config.train_config.learning_rate),
    }
    # submit the pipeline
    run = pipeline.submit(tags=tags, regenerate_outputs=False)
    return run
# Script entry point: hydra parses CLI overrides and invokes main.
if __name__ == "__main__":
    main()