mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-13 18:25:42 +00:00

* add pipeline tuner component and dependencies. * clean code. * do not need force rerun. * replace the resources. * update metrics retrieving. * Update test/pipeline_tuning_example/requirements.txt * Update test/pipeline_tuning_example/train/env.yaml * Update test/pipeline_tuning_example/tuner/env.yaml * Update test/pipeline_tuning_example/tuner/tuner_func.py * Update test/pipeline_tuning_example/data_prep/env.yaml * fix issues found by lint with flake8. * add documentation * add data. * do not need AML resource for local run. * AML -> AzureML * clean code. * Update website/docs/Examples/Tune-AzureML pipeline.md * rename and add pip install. * update figure name. * align docs with code. * remove extra line.
73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
import logging
|
|
from azureml.core import Workspace
|
|
from azure.ml.component import (
|
|
Component,
|
|
dsl,
|
|
)
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
# Absolute directory containing this script. Component spec files are resolved
# relative to it (e.g. LOCAL_DIR / "tuner/component_spec.yaml" in remote_run),
# so lookups do not depend on the directory the script is launched from.
LOCAL_DIR = Path(__file__).parent.absolute()
|
|
|
|
|
|
def remote_run(cli_args=None, compute_target="cpucluster"):
    """Build the tuning pipeline and submit it to an Azure ML workspace.

    Args:
        cli_args: Namespace providing ``subscription_id``, ``resource_group``
            and ``workspace`` attributes. When omitted, the module-level
            ``args`` assigned by the ``__main__`` block is used, so the
            historical call ``remote_run()`` keeps working.
        compute_target: Name used as the pipeline's ``default_compute_target``.
            Defaults to ``"cpucluster"``.

    Returns:
        The object returned by ``pipeline.submit(...)``.
    """
    if cli_args is None:
        # Backward compatibility: fall back to the global parsed in __main__.
        cli_args = args

    # Connect to the Azure ML workspace.
    ws = Workspace(
        subscription_id=cli_args.subscription_id,
        resource_group=cli_args.resource_group,
        workspace_name=cli_args.workspace,
    )

    # Load the tuner component from the YAML spec stored next to this script.
    pipeline_tuning_func = Component.from_yaml(
        ws, yaml_file=LOCAL_DIR / "tuner/component_spec.yaml"
    )

    # Build a single-step pipeline that runs the tuner component.
    @dsl.pipeline(
        name="pipeline_tuning",
        default_compute_target=compute_target,
    )
    def sample_pipeline():
        pipeline_tuning_func()

    pipeline = sample_pipeline()

    run = pipeline.submit(regenerate_outputs=False)
    return run
|
|
|
|
|
|
def local_run(concurrent_run=2):
    """Run the pipeline tuner on this machine instead of on Azure ML.

    Args:
        concurrent_run: Forwarded unchanged to
            ``tuner_func.tune_pipeline(concurrent_run=...)``. Defaults to 2.
    """
    # Resolve the logger here instead of relying on the global ``logger``,
    # which only the __main__ block defines; otherwise importing this module
    # and calling local_run() raises NameError.
    logging.getLogger(__name__).info("Run tuner locally.")

    # Imported lazily so the remote path never needs the local ``tuner``
    # package to be importable.
    from tuner import tuner_func

    tuner_func.tune_pipeline(concurrent_run=concurrent_run)
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse command-line arguments. The workspace coordinates are only needed
    # for a remote (Azure ML) run, so they are optional at the parser level
    # and validated below once the run mode is known.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--subscription_id", type=str, help="your_subscription_id", required=False,
    )
    parser.add_argument(
        "--resource_group", type=str, help="your_resource_group", required=False)
    parser.add_argument(
        "--workspace", type=str, help="your_workspace", required=False)

    # --remote / --local both write ``args.remote``. They are registered on a
    # mutually exclusive group so argparse rejects passing both. The default
    # is None rather than True: argparse ignores values equal to the default
    # when checking conflicts, so a True default would hide "--remote --local".
    run_mode = parser.add_mutually_exclusive_group(required=False)
    run_mode.add_argument('--remote', dest='remote', action='store_true',
                          default=None)
    run_mode.add_argument('--local', dest='remote', action='store_false',
                          default=None)

    args = parser.parse_args()

    # Neither flag given: remote is the default mode.
    if args.remote is None:
        args.remote = True

    # Fail fast with a usage message instead of handing None to Workspace(...).
    if args.remote:
        missing = [
            "--" + opt
            for opt in ("subscription_id", "resource_group", "workspace")
            if not getattr(args, opt)
        ]
        if missing:
            parser.error(
                "remote run requires " + ", ".join(missing)
                + " (or pass --local to run the tuner locally)"
            )

    # Without configuration the root logger only emits WARNING and above, so
    # the INFO message logged by local_run() would be silently dropped.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.remote:
        remote_run()
    else:
        local_run()
|