{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Pytorch model tuning example on CIFAR10\n", "This notebook uses flaml to tune a pytorch model on CIFAR10. It is modified based on [this example](https://docs.ray.io/en/master/tune/examples/cifar10_pytorch.html).\n", "\n", "**Requirements.** This notebook requires:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%pip install torchvision flaml[blendsearch,ray]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Network Specification" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import random_split\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "\n", "\n", "class Net(nn.Module):\n", "\n", " def __init__(self, l1=120, l2=84):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(3, 6, 5)\n", " self.pool = nn.MaxPool2d(2, 2)\n", " self.conv2 = nn.Conv2d(6, 16, 5)\n", " self.fc1 = nn.Linear(16 * 5 * 5, l1)\n", " self.fc2 = nn.Linear(l1, l2)\n", " self.fc3 = nn.Linear(l2, 10)\n", "\n", " def forward(self, x):\n", " x = self.pool(F.relu(self.conv1(x)))\n", " x = self.pool(F.relu(self.conv2(x)))\n", " x = x.view(-1, 16 * 5 * 5)\n", " x = F.relu(self.fc1(x))\n", " x = F.relu(self.fc2(x))\n", " x = self.fc3(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def load_data(data_dir=\"data\"):\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", " ])\n", "\n", " trainset = torchvision.datasets.CIFAR10(\n", " root=data_dir, train=True, download=True, transform=transform)\n", "\n", " testset = torchvision.datasets.CIFAR10(\n", " root=data_dir, train=False, download=True, transform=transform)\n", "\n", " return trainset, testset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ray import tune\n", "\n", "def train_cifar(config, checkpoint_dir=None, data_dir=None):\n", " if \"l1\" not in config:\n", " logger.warning(config)\n", " net = Net(2**config[\"l1\"], 2**config[\"l2\"])\n", "\n", " device = \"cpu\"\n", " if torch.cuda.is_available():\n", " device = \"cuda:0\"\n", " if torch.cuda.device_count() > 1:\n", " net = nn.DataParallel(net)\n", " net.to(device)\n", "\n", " criterion = nn.CrossEntropyLoss()\n", " optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n", "\n", " # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint\n", " # should be restored.\n", " if checkpoint_dir:\n", " checkpoint = os.path.join(checkpoint_dir, \"checkpoint\")\n", " model_state, optimizer_state = torch.load(checkpoint)\n", " net.load_state_dict(model_state)\n", " optimizer.load_state_dict(optimizer_state)\n", "\n", " trainset, testset = load_data(data_dir)\n", "\n", " test_abs = int(len(trainset) * 0.8)\n", " train_subset, val_subset = random_split(\n", " trainset, [test_abs, len(trainset) - test_abs])\n", "\n", " trainloader = torch.utils.data.DataLoader(\n", " train_subset,\n", " batch_size=int(2**config[\"batch_size\"]),\n", " shuffle=True,\n", " num_workers=4)\n", " valloader = 
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import logging\n",
  "import os\n",
  "\n",
  "from ray import tune\n",
  "\n",
  "logger = logging.getLogger(__name__)\n",
  "\n",
  "\n",
  "def train_cifar(config, checkpoint_dir=None, data_dir=None):\n",
  "    if \"l1\" not in config:\n",
  "        logger.warning(config)\n",
  "    net = Net(2**config[\"l1\"], 2**config[\"l2\"])\n",
  "\n",
  "    device = \"cpu\"\n",
  "    if torch.cuda.is_available():\n",
  "        device = \"cuda:0\"\n",
  "        if torch.cuda.device_count() > 1:\n",
  "            net = nn.DataParallel(net)\n",
  "    net.to(device)\n",
  "\n",
  "    criterion = nn.CrossEntropyLoss()\n",
  "    optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n",
  "\n",
  "    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint\n",
  "    # should be restored.\n",
  "    if checkpoint_dir:\n",
  "        checkpoint = os.path.join(checkpoint_dir, \"checkpoint\")\n",
  "        model_state, optimizer_state = torch.load(checkpoint)\n",
  "        net.load_state_dict(model_state)\n",
  "        optimizer.load_state_dict(optimizer_state)\n",
  "\n",
  "    trainset, testset = load_data(data_dir)\n",
  "\n",
  "    # Hold out 20% of the training set for validation.\n",
  "    test_abs = int(len(trainset) * 0.8)\n",
  "    train_subset, val_subset = random_split(\n",
  "        trainset, [test_abs, len(trainset) - test_abs])\n",
  "\n",
  "    trainloader = torch.utils.data.DataLoader(\n",
  "        train_subset,\n",
  "        batch_size=int(2**config[\"batch_size\"]),\n",
  "        shuffle=True,\n",
  "        num_workers=4)\n",
  "    valloader = torch.utils.data.DataLoader(\n",
  "        val_subset,\n",
  "        batch_size=int(2**config[\"batch_size\"]),\n",
  "        shuffle=True,\n",
  "        num_workers=4)\n",
  "\n",
  "    for epoch in range(int(round(config[\"num_epochs\"]))):  # loop over the dataset multiple times\n",
  "        running_loss = 0.0\n",
  "        epoch_steps = 0\n",
  "        for i, data in enumerate(trainloader, 0):\n",
  "            # get the inputs; data is a list of [inputs, labels]\n",
  "            inputs, labels = data\n",
  "            inputs, labels = inputs.to(device), labels.to(device)\n",
  "\n",
  "            # zero the parameter gradients\n",
  "            optimizer.zero_grad()\n",
  "\n",
  "            # forward + backward + optimize\n",
  "            outputs = net(inputs)\n",
  "            loss = criterion(outputs, labels)\n",
  "            loss.backward()\n",
  "            optimizer.step()\n",
  "\n",
  "            # print statistics\n",
  "            running_loss += loss.item()\n",
  "            epoch_steps += 1\n",
  "            if i % 2000 == 1999:  # print every 2000 mini-batches\n",
  "                print(\"[%d, %5d] loss: %.3f\" % (epoch + 1, i + 1,\n",
  "                                                running_loss / epoch_steps))\n",
  "                running_loss = 0.0\n",
  "\n",
  "        # Validation loss\n",
  "        val_loss = 0.0\n",
  "        val_steps = 0\n",
  "        total = 0\n",
  "        correct = 0\n",
  "        for i, data in enumerate(valloader, 0):\n",
  "            with torch.no_grad():\n",
  "                inputs, labels = data\n",
  "                inputs, labels = inputs.to(device), labels.to(device)\n",
  "\n",
  "                outputs = net(inputs)\n",
  "                _, predicted = torch.max(outputs.data, 1)\n",
  "                total += labels.size(0)\n",
  "                correct += (predicted == labels).sum().item()\n",
  "\n",
  "                loss = criterion(outputs, labels)\n",
  "                val_loss += loss.cpu().numpy()\n",
  "                val_steps += 1\n",
  "\n",
  "        # Here we save a checkpoint. It is automatically registered with\n",
  "        # Ray Tune and will potentially be passed as the `checkpoint_dir`\n",
  "        # parameter in future iterations.\n",
  "        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:\n",
  "            path = os.path.join(checkpoint_dir, \"checkpoint\")\n",
  "            torch.save(\n",
  "                (net.state_dict(), optimizer.state_dict()), path)\n",
  "\n",
  "        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)\n",
  "    print(\"Finished Training\")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Test Accuracy" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def _test_accuracy(net, device=\"cpu\"):\n",
  "    trainset, testset = load_data()\n",
  "\n",
  "    testloader = torch.utils.data.DataLoader(\n",
  "        testset, batch_size=4, shuffle=False, num_workers=2)\n",
  "\n",
  "    correct = 0\n",
  "    total = 0\n",
  "    with torch.no_grad():\n",
  "        for data in testloader:\n",
  "            images, labels = data\n",
  "            images, labels = images.to(device), labels.to(device)\n",
  "            outputs = net(images)\n",
  "            _, predicted = torch.max(outputs.data, 1)\n",
  "            total += labels.size(0)\n",
  "            correct += (predicted == labels).sum().item()\n",
  "\n",
  "    return correct / total"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter Optimization" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import numpy as np\n",
  "import flaml\n",
  "import os\n",
  "\n",
  "data_dir = os.path.abspath(\"data\")\n",
  "load_data(data_dir)  # Download data for all trials before starting the run"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Search space" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "max_num_epoch = 100\n",
  "config = {\n",
  "    \"l1\": tune.randint(2, 9),  # log transformed with base 2\n",
  "    \"l2\": tune.randint(2, 9),  # log transformed with base 2\n",
  "    \"lr\": tune.loguniform(1e-4, 1e-1),\n",
  "    \"num_epochs\": tune.loguniform(1, max_num_epoch),\n",
  "    \"batch_size\": tune.randint(1, 5)  # log transformed with base 2\n",
  "}"
 ] },
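 { "cell_type": "markdown", "metadata": {}, "source": [ "Since `l1`, `l2`, and `batch_size` are sampled as base-2 exponents, a sampled configuration maps onto the actual model and training settings as sketched below. The cell is purely illustrative and uses a hypothetical sample; `train_cifar` applies the same `2**x` transformation internally." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Illustration only: how a sampled configuration is interpreted.\n",
  "# The values below are a hypothetical sample, not a tuning result.\n",
  "sample_config = {\"l1\": 7, \"l2\": 5, \"lr\": 0.01, \"num_epochs\": 3.2, \"batch_size\": 4}\n",
  "print(\"fc1 width :\", 2**sample_config[\"l1\"])          # 128 hidden units\n",
  "print(\"fc2 width :\", 2**sample_config[\"l2\"])          # 32 hidden units\n",
  "print(\"batch size:\", 2**sample_config[\"batch_size\"])  # 16 images per mini-batch\n",
  "print(\"epochs    :\", int(round(sample_config[\"num_epochs\"])))  # 3 passes over the training split"
 ] },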
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "time_budget_s = 3600  # time budget in seconds\n",
  "gpus_per_trial = 0.5  # number of GPUs per trial; 0.5 means two training jobs can share one GPU\n",
  "num_samples = 500  # maximum number of trials\n",
  "np.random.seed(7654321)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Launch the tuning" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import time\n",
  "start_time = time.time()\n",
  "result = flaml.tune.run(\n",
  "    tune.with_parameters(train_cifar, data_dir=data_dir),\n",
  "    config=config,\n",
  "    metric=\"loss\",\n",
  "    mode=\"min\",\n",
  "    low_cost_partial_config={\"num_epochs\": 1},\n",
  "    max_resource=max_num_epoch,\n",
  "    min_resource=1,\n",
  "    scheduler=\"asha\",  # ASHA requires train_cifar to report intermediate results via tune.report\n",
  "    resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
  "    local_dir=\"logs/\",\n",
  "    num_samples=num_samples,\n",
  "    time_budget_s=time_budget_s,\n",
  "    use_ray=True)\n"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "print(f\"#trials={len(result.trials)}\")\n",
  "print(f\"time={time.time()-start_time}\")\n",
  "best_trial = result.get_best_trial(\"loss\", \"min\", \"all\")\n",
  "print(\"Best trial config: {}\".format(best_trial.config))\n",
  "print(\"Best trial final validation loss: {}\".format(\n",
  "    best_trial.metric_analysis[\"loss\"][\"min\"]))\n",
  "print(\"Best trial final validation accuracy: {}\".format(\n",
  "    best_trial.metric_analysis[\"accuracy\"][\"max\"]))\n",
  "\n",
  "best_trained_model = Net(2**best_trial.config[\"l1\"],\n",
  "                         2**best_trial.config[\"l2\"])\n",
  "device = \"cpu\"\n",
  "if torch.cuda.is_available():\n",
  "    device = \"cuda:0\"\n",
  "    if gpus_per_trial > 1:\n",
  "        best_trained_model = nn.DataParallel(best_trained_model)\n",
  "best_trained_model.to(device)\n",
  "\n",
  "# The attribute holding the checkpoint location differs across Ray versions\n",
  "# (\"dir_or_data\" in newer releases, \"value\" in older ones).\n",
  "checkpoint_value = (\n",
  "    getattr(best_trial.checkpoint, \"dir_or_data\", None)\n",
  "    or best_trial.checkpoint.value\n",
  ")\n",
  "checkpoint_path = os.path.join(checkpoint_value, \"checkpoint\")\n",
  "\n",
  "model_state, optimizer_state = torch.load(checkpoint_path)\n",
  "best_trained_model.load_state_dict(model_state)\n",
  "\n",
  "test_acc = _test_accuracy(best_trained_model, device)\n",
  "print(\"Best trial test set accuracy: {}\".format(test_acc))"
 ] }
 ], "metadata": { "kernelspec": { "display_name": "Python 3.11.0 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" }, "metadata": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } }, "vscode": { "interpreter": { "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" } } }, "nbformat": 4, "nbformat_minor": 4 }