mirror of https://github.com/microsoft/autogen.git
synced 2025-07-26 18:31:36 +00:00
352 lines · 11 KiB · Python
import unittest
import os
import time

import logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.FileHandler('test/tune_pytorch_cifar10.log'))
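
# Assumed prerequisites for this test (inferred from the imports used below):
#   pip install torch torchvision "ray[tune]" flaml[blendsearch]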


# __load_data_begin__
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
# __load_data_end__
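# NOTE: load_data is redefined later in this file with data_dir="test/data";
# that later definition is the one the test actually uses.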


import numpy as np
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import random_split
    import torchvision
    import torchvision.transforms as transforms


    # __net_begin__
    # The standard CIFAR-10 tutorial CNN: two conv+pool stages followed by
    # three fully connected layers; the widths of the two hidden FC layers
    # (l1, l2) are the hyperparameters being tuned.
    class Net(nn.Module):
        def __init__(self, l1=120, l2=84):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, l1)
            self.fc2 = nn.Linear(l1, l2)
            self.fc3 = nn.Linear(l2, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    # __net_end__
except ImportError:
    print("skip test_pytorch because torchvision cannot be imported.")


# __load_data_begin__
def load_data(data_dir="test/data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
# __load_data_end__


# __train_begin__
def train_cifar(config, checkpoint_dir=None, data_dir=None):
    if "l1" not in config:
        logger.warning(config)
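    # l1, l2 (and batch_size below) are searched as exponents, so the actual
    # layer widths and batch size used are powers of two.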
    net = Net(2 ** config["l1"], 2 ** config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
    # should be restored.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, testset = load_data(data_dir)
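
    # Hold out 20% of the training set for validation.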
    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(2 ** config["batch_size"]),
        shuffle=True,
        num_workers=4)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(2 ** config["batch_size"]),
        shuffle=True,
        num_workers=4)

    from ray import tune
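    # Imported inside the trainable (rather than at module import time),
    # presumably so the module can be imported without ray installed;
    # tune.checkpoint_dir and tune.report below come from this import.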

    for epoch in range(int(round(config["num_epochs"]))):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save(
                (net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
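        # Reporting once per epoch lets the scheduler / search algorithm see
        # intermediate results and stop unpromising trials early.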
    print("Finished Training")
# __train_end__


# __test_acc_begin__
def _test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total
# __test_acc_end__


# __main_begin__
def cifar10_main(method='BlendSearch', num_samples=10, max_num_epochs=100,
                 gpus_per_trial=2):
    data_dir = os.path.abspath("test/data")
    load_data(data_dir)  # Download data for all trials before starting the run
    if method == 'BlendSearch':
        from flaml import tune
    else:
        from ray import tune
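    # BlendSearch goes through flaml's tune API; every other method calls
    # ray.tune directly (flaml searchers such as CFO can still plug into it).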
    if method in ['BlendSearch', 'BOHB', 'Optuna']:
        config = {
            "l1": tune.randint(2, 8),
            "l2": tune.randint(2, 8),
            "lr": tune.loguniform(1e-4, 1e-1),
            "num_epochs": tune.qloguniform(1, max_num_epochs, q=1),
            "batch_size": tune.randint(1, 4)  # tune.choice([2, 4, 8, 16])
        }
    else:
        config = {
            "l1": tune.randint(2, 9),
            "l2": tune.randint(2, 9),
            "lr": tune.loguniform(1e-4, 1e-1),
            "num_epochs": tune.qloguniform(1, max_num_epochs + 1, q=1),
            "batch_size": tune.randint(1, 5)  # tune.choice([2, 4, 8, 16])
        }
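    # The two branches use slightly different integer bounds, presumably to
    # compensate for differing randint/qloguniform upper-bound conventions
    # between flaml.tune and ray.tune.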
    import ray
    time_budget_s = 3600
    start_time = time.time()
    if method == 'BlendSearch':
        result = tune.run(
            ray.tune.with_parameters(train_cifar, data_dir=data_dir),
            init_config={
                "l1": 2,
                "l2": 2,
                "num_epochs": 1,
                "batch_size": 4,
            },
            metric="loss",
            mode="min",
            max_resource=max_num_epochs,
            min_resource=1,
            report_intermediate_result=True,
            resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
            config=config,
            local_dir='logs/',
            num_samples=num_samples,
            time_budget_s=time_budget_s,
            use_ray=True)
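        # init_config, min_resource/max_resource, report_intermediate_result and
        # use_ray are flaml.tune.run-specific arguments; with use_ray=True the
        # trials are still executed through Ray for parallelism.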
    else:
        if 'ASHA' == method:
            algo = None
        elif 'BOHB' == method:
            from ray.tune.schedulers import HyperBandForBOHB
            from ray.tune.suggest.bohb import TuneBOHB
            algo = TuneBOHB()
            scheduler = HyperBandForBOHB(max_t=max_num_epochs)
        elif 'Optuna' == method:
            from ray.tune.suggest.optuna import OptunaSearch
            algo = OptunaSearch()
        elif 'CFO' == method:
            from flaml import CFO
            algo = CFO(points_to_evaluate=[{
                "l1": 2,
                "l2": 2,
                "num_epochs": 1,
                "batch_size": 4,
            }])
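            # points_to_evaluate gives CFO the same low-cost starting point that
            # init_config provides in the BlendSearch branch above.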
        elif 'Nevergrad' == method:
            from ray.tune.suggest.nevergrad import NevergradSearch
            import nevergrad as ng
            algo = NevergradSearch(optimizer=ng.optimizers.OnePlusOne)
        if method != 'BOHB':
            from ray.tune.schedulers import ASHAScheduler
            scheduler = ASHAScheduler(
                max_t=max_num_epochs,
                grace_period=1)
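        # Every non-BOHB ray.tune run uses an ASHA scheduler for early stopping;
        # BOHB pairs with HyperBandForBOHB above instead.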
        result = tune.run(
            tune.with_parameters(train_cifar, data_dir=data_dir),
            resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
            config=config,
            metric="loss",
            mode="min",
            num_samples=num_samples, time_budget_s=time_budget_s,
            scheduler=scheduler, search_alg=algo
        )
    ray.shutdown()

    logger.info(f"method={method}")
    logger.info(f"n_samples={num_samples}")
    logger.info(f"time={time.time() - start_time}")
    best_trial = result.get_best_trial("loss", "min", "all")
    logger.info("Best trial config: {}".format(best_trial.config))
    logger.info("Best trial final validation loss: {}".format(
        best_trial.metric_analysis["loss"]["min"]))
    logger.info("Best trial final validation accuracy: {}".format(
        best_trial.metric_analysis["accuracy"]["max"]))

    best_trained_model = Net(2 ** best_trial.config["l1"],
                             2 ** best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")

    model_state, optimizer_state = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(model_state)

    test_acc = _test_accuracy(best_trained_model, device)
    logger.info("Best trial test set accuracy: {}".format(test_acc))
# __main_end__


gpus_per_trial = 0  # .5
num_samples = 500
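
# The test functions below are prefixed with an underscore, so neither
# unittest.main() nor pytest collects them automatically; they are presumably
# meant to be invoked manually.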


def _test_cifar10_bs():
    cifar10_main(num_samples=num_samples, gpus_per_trial=gpus_per_trial)


def _test_cifar10_cfo():
    cifar10_main('CFO',
                 num_samples=num_samples, gpus_per_trial=gpus_per_trial)


def _test_cifar10_optuna():
    cifar10_main('Optuna',
                 num_samples=num_samples, gpus_per_trial=gpus_per_trial)


def _test_cifar10_asha():
    cifar10_main('ASHA',
                 num_samples=num_samples, gpus_per_trial=gpus_per_trial)


def _test_cifar10_bohb():
    cifar10_main('BOHB',
                 num_samples=num_samples, gpus_per_trial=gpus_per_trial)


def _test_cifar10_nevergrad():
    cifar10_main('Nevergrad',
                 num_samples=num_samples, gpus_per_trial=gpus_per_trial)


if __name__ == "__main__":
    unittest.main()