# Mirror of https://github.com/microsoft/autogen.git (synced 2025-09-25 16:16:37 +00:00)
# 362 lines, 11 KiB, Python
"""Require: pip install torchvision ray flaml[blendsearch]
|
|
"""
|
|
import os
import time
import numpy as np

import logging

# The log directory has to exist before the file handler opens its target file.
os.makedirs("logs", exist_ok=True)

# Module logger: INFO and above is appended to logs/tune_pytorch_cifar10.log.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.FileHandler("logs/tune_pytorch_cifar10.log"))
|
|
|
|
|
|
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import random_split
    import torchvision
    import torchvision.transforms as transforms

    # __net_begin__
    class Net(nn.Module):
        """Two conv/pool stages followed by three fully connected layers.

        ``l1`` and ``l2`` are the output widths of the first and second fully
        connected layers; the last layer always emits 10 values per sample.
        """

        def __init__(self, l1=120, l2=84):
            super().__init__()
            # Layers are created in the same order as before so parameter
            # ordering and random initialisation stay unchanged.
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, l1)
            self.fc2 = nn.Linear(l1, l2)
            self.fc3 = nn.Linear(l2, 10)

        def forward(self, x):
            """Map a batch of images to one 10-way score vector per sample."""
            stage1 = self.pool(F.relu(self.conv1(x)))
            stage2 = self.pool(F.relu(self.conv2(stage1)))
            # Flatten to the input width expected by fc1.
            flat = stage2.view(-1, 16 * 5 * 5)
            hidden = F.relu(self.fc1(flat))
            hidden = F.relu(self.fc2(hidden))
            return self.fc3(hidden)

    # __net_end__
except ImportError:
    print("skip test_pytorch because torchvision cannot be imported.")
|
|
|
|
|
|
# __load_data_begin__
|
|
def load_data(data_dir="test/data"):
    """Build the CIFAR-10 train and test datasets rooted at ``data_dir``.

    Both datasets are created with ``download=True`` and share one transform:
    conversion to a tensor followed by per-channel normalisation with mean 0.5
    and standard deviation 0.5.

    Returns:
        A ``(trainset, testset)`` tuple.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )

    # Everything except the train/test switch is shared by the two datasets.
    shared_kwargs = dict(root=data_dir, download=True, transform=transform)
    trainset = torchvision.datasets.CIFAR10(train=True, **shared_kwargs)
    testset = torchvision.datasets.CIFAR10(train=False, **shared_kwargs)
    return trainset, testset
|
|
|
|
|
|
# __load_data_end__
|
|
|
|
|
|
# __train_begin__
|
|
def train_cifar(config, checkpoint_dir=None, data_dir=None):
    """Train ``Net`` on CIFAR-10 for one tuning trial, reporting after every epoch.

    The training set returned by ``load_data`` is split 80/20 into a training
    and a validation subset. After each epoch the model and optimizer state are
    saved as a checkpoint and the mean validation loss and the validation
    accuracy are reported through ``ray.tune``.

    Args:
        config: Trial hyperparameters. Reads ``l1`` and ``l2`` (base-2 exponents
            of the two hidden layer widths), ``lr`` (SGD learning rate),
            ``batch_size`` (base-2 exponent of the batch size) and
            ``num_epochs`` (rounded to the nearest integer).
        checkpoint_dir: Directory containing a ``checkpoint`` file to restore
            model and optimizer state from; falsy to start from scratch.
        data_dir: Dataset root directory forwarded to ``load_data``.
    """
    if "l1" not in config:
        # Record the unexpected config in the log before it is used below.
        logger.warning(config)
    net = Net(2 ** config["l1"], 2 ** config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
    # should be restored.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # Only the training split is needed here; it is divided 80/20 below.
    trainset, _ = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    batch_size = int(2 ** config["batch_size"])
    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

    from ray import tune

    num_epochs = int(round(config["num_epochs"]))
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                # Reset the step counter together with the loss sum. Resetting
                # only the sum made every later report divide a 2000-batch sum
                # by the total step count, understating the loss.
                running_loss = 0.0
                epoch_steps = 0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # .item() keeps the accumulated loss a plain Python float, so
                # the reported metric is not a numpy scalar.
                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
|
|
|
|
|
|
# __train_end__
|
|
|
|
|
|
# __test_acc_begin__
|
|
def _test_accuracy(net, device="cpu"):
    """Return the fraction of CIFAR-10 test samples that ``net`` predicts correctly.

    The test set comes from ``load_data()`` with its default directory and is
    evaluated in unshuffled batches of 4 with gradients disabled.
    """
    _, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    hits = 0
    seen = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            # Index 1 of torch.max(..., 1) is the arg-max class per sample.
            predicted = torch.max(net(images).data, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()

    return hits / seen
|
|
|
|
|
|
# __test_acc_end__
|
|
|
|
|
|
# __main_begin__
|
|
def cifar10_main(
    method="BlendSearch", num_samples=10, max_num_epochs=100, gpus_per_trial=1
):
    """Tune ``train_cifar`` on CIFAR-10 with the chosen search method and log results.

    The search runs under a fixed 600 second time budget. Afterwards the trial
    count, elapsed time and the best trial's config, validation loss and
    validation accuracy are logged. The best trial's checkpoint is then loaded
    into a fresh ``Net`` and its test-set accuracy is logged.

    Args:
        method: Search method name; one of "BlendSearch", "ASHA", "BOHB",
            "Optuna", "CFO" or "Nevergrad".
        num_samples: Forwarded to ``tune.run`` as ``num_samples``.
        max_num_epochs: Upper end of the ``num_epochs`` search range; also used
            as the scheduler's ``max_t`` / ``max_resource``.
        gpus_per_trial: GPU share requested for each trial.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ("BlendSearch", "ASHA", "BOHB", "Optuna", "CFO", "Nevergrad")
    if method not in supported_methods:
        # Fail fast. An unknown name used to get through the dataset download
        # and then crash with an UnboundLocalError on `algo`.
        raise ValueError(
            f"Unknown method {method!r}; expected one of {supported_methods}"
        )

    data_dir = os.path.abspath("test/data")
    load_data(data_dir)  # Download data for all trials before starting the run
    if method == "BlendSearch":
        from flaml import tune
    else:
        from ray import tune
    if method == "BOHB":
        config = {
            "l1": tune.randint(2, 8),
            "l2": tune.randint(2, 8),
            "lr": tune.loguniform(1e-4, 1e-1),
            "num_epochs": tune.qloguniform(1, max_num_epochs, q=1),
            "batch_size": tune.randint(1, 4),
        }
    else:
        config = {
            "l1": tune.randint(2, 9),
            "l2": tune.randint(2, 9),
            "lr": tune.loguniform(1e-4, 1e-1),
            "num_epochs": tune.loguniform(1, max_num_epochs),
            "batch_size": tune.randint(1, 5),
        }
    import ray

    time_budget_s = 600
    np.random.seed(7654321)
    start_time = time.time()
    if method == "BlendSearch":
        # `tune` is flaml's module in this branch; the trainable is still
        # wrapped with ray's own `with_parameters`.
        result = tune.run(
            ray.tune.with_parameters(train_cifar, data_dir=data_dir),
            config=config,
            metric="loss",
            mode="min",
            low_cost_partial_config={"num_epochs": 1},
            max_resource=max_num_epochs,
            min_resource=1,
            scheduler="asha",
            resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
            local_dir="logs/",
            num_samples=num_samples,
            time_budget_s=time_budget_s,
            use_ray=True,
        )
    else:
        if "ASHA" == method:
            # No search algorithm is passed; only the ASHA scheduler is used.
            algo = None
        elif "BOHB" == method:
            from ray.tune.schedulers import HyperBandForBOHB
            from ray.tune.suggest.bohb import TuneBOHB

            algo = TuneBOHB()
            scheduler = HyperBandForBOHB(max_t=max_num_epochs)
        elif "Optuna" == method:
            from ray.tune.suggest.optuna import OptunaSearch

            algo = OptunaSearch(seed=10)
        elif "CFO" == method:
            from flaml import CFO

            algo = CFO(
                low_cost_partial_config={
                    "num_epochs": 1,
                }
            )
        elif "Nevergrad" == method:
            from ray.tune.suggest.nevergrad import NevergradSearch
            import nevergrad as ng

            algo = NevergradSearch(optimizer=ng.optimizers.OnePlusOne)
        if method != "BOHB":
            # BOHB set its own scheduler above; every other method uses ASHA.
            from ray.tune.schedulers import ASHAScheduler

            scheduler = ASHAScheduler(max_t=max_num_epochs, grace_period=1)
        result = tune.run(
            tune.with_parameters(train_cifar, data_dir=data_dir),
            resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
            config=config,
            metric="loss",
            mode="min",
            num_samples=num_samples,
            time_budget_s=time_budget_s,
            scheduler=scheduler,
            search_alg=algo,
        )
    ray.shutdown()

    logger.info(f"method={method}")
    logger.info(f"#trials={len(result.trials)}")
    logger.info(f"time={time.time()-start_time}")
    best_trial = result.get_best_trial("loss", "min", "all")
    logger.info("Best trial config: {}".format(best_trial.config))
    logger.info(
        "Best trial final validation loss: {}".format(
            best_trial.metric_analysis["loss"]["min"]
        )
    )
    logger.info(
        "Best trial final validation accuracy: {}".format(
            best_trial.metric_analysis["accuracy"]["max"]
        )
    )

    # Rebuild the network with the best trial's layer widths before loading
    # its checkpointed weights.
    best_trained_model = Net(2 ** best_trial.config["l1"], 2 ** best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    # The checkpoint location is read from `dir_or_data` when that attribute
    # exists and is truthy, otherwise from `value`.
    checkpoint_value = (
        getattr(best_trial.checkpoint, "dir_or_data", None)
        or best_trial.checkpoint.value
    )
    checkpoint_path = os.path.join(checkpoint_value, "checkpoint")

    model_state, optimizer_state = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(model_state)

    test_acc = _test_accuracy(best_trained_model, device)
    logger.info("Best trial test set accuracy: {}".format(test_acc))
|
|
|
|
|
|
# __main_end__
|
|
|
|
|
|
# Settings passed to cifar10_main by the `_test_cifar10_*` helpers in this file.
num_samples = 500
gpus_per_trial = 0.5  # on GPU server
|
|
|
|
|
|
def _test_cifar10_bs():
    """Run ``cifar10_main`` with its default method and the module-level settings."""
    cifar10_main(gpus_per_trial=gpus_per_trial, num_samples=num_samples)
|
|
|
|
|
|
def _test_cifar10_cfo():
    """Run ``cifar10_main`` with the "CFO" method and the module-level settings."""
    cifar10_main(
        method="CFO", gpus_per_trial=gpus_per_trial, num_samples=num_samples
    )
|
|
|
|
|
|
def _test_cifar10_optuna():
    """Run ``cifar10_main`` with the "Optuna" method and the module-level settings."""
    cifar10_main(
        method="Optuna", gpus_per_trial=gpus_per_trial, num_samples=num_samples
    )
|
|
|
|
|
|
def _test_cifar10_asha():
    """Run ``cifar10_main`` with the "ASHA" method and the module-level settings."""
    cifar10_main(
        method="ASHA", gpus_per_trial=gpus_per_trial, num_samples=num_samples
    )
|
|
|
|
|
|
def _test_cifar10_bohb():
    """Run ``cifar10_main`` with the "BOHB" method and the module-level settings."""
    cifar10_main(
        method="BOHB", gpus_per_trial=gpus_per_trial, num_samples=num_samples
    )
|
|
|
|
|
|
def _test_cifar10_nevergrad():
    """Run ``cifar10_main`` with the "Nevergrad" method and the module-level settings."""
    cifar10_main(
        method="Nevergrad", gpus_per_trial=gpus_per_trial, num_samples=num_samples
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: run the BlendSearch variant only.
    _test_cifar10_bs()
|