2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Pytorch model tuning example on CIFAR10\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "This notebook uses flaml to tune a pytorch model on CIFAR10. It is modified based on [this example](https://docs.ray.io/en/master/tune/examples/cifar10_pytorch.html).\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "**Requirements.** This notebook requires:"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": []
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "!pip install torchvision flaml[blendsearch,ray];"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Network Specification"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.nn.functional as F\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.optim as optim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from torch.utils.data import random_split\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torchvision\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torchvision.transforms as transforms\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class Net(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, l1=120, l2=84):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super(Net, self).__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.pool = nn.MaxPool2d(2, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.fc1 = nn.Linear(16 * 5 * 5, l1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.fc2 = nn.Linear(l1, l2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.fc3 = nn.Linear(l2, 10)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.pool(F.relu(self.conv1(x)))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.pool(F.relu(self.conv2(x)))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = x.view(-1, 16 * 5 * 5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = F.relu(self.fc1(x))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = F.relu(self.fc2(x))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.fc3(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return x"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Data"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def load_data(data_dir=\"data\"):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    transform = transforms.Compose([\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        transforms.ToTensor(),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    ])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    trainset = torchvision.datasets.CIFAR10(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        root=data_dir, train=True, download=True, transform=transform)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    testset = torchvision.datasets.CIFAR10(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        root=data_dir, train=False, download=True, transform=transform)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return trainset, testset"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Training"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from ray import tune\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def train_cifar(config, checkpoint_dir=None, data_dir=None):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    if \"l1\" not in config:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        logger.warning(config)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    net = Net(2**config[\"l1\"], 2**config[\"l2\"])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    device = \"cpu\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    if torch.cuda.is_available():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        device = \"cuda:0\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if torch.cuda.device_count() > 1:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            net = nn.DataParallel(net)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    net.to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    criterion = nn.CrossEntropyLoss()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # should be restored.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    if checkpoint_dir:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        checkpoint = os.path.join(checkpoint_dir, \"checkpoint\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        model_state, optimizer_state = torch.load(checkpoint)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        net.load_state_dict(model_state)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        optimizer.load_state_dict(optimizer_state)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    trainset, testset = load_data(data_dir)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    test_abs = int(len(trainset) * 0.8)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_subset, val_subset = random_split(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        trainset, [test_abs, len(trainset) - test_abs])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    trainloader = torch.utils.data.DataLoader(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        train_subset,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size=int(2**config[\"batch_size\"]),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        shuffle=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        num_workers=4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    valloader = torch.utils.data.DataLoader(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        val_subset,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size=int(2**config[\"batch_size\"]),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        shuffle=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        num_workers=4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for epoch in range(int(round(config[\"num_epochs\"]))):  # loop over the dataset multiple times\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        running_loss = 0.0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        epoch_steps = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for i, data in enumerate(trainloader, 0):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # get the inputs; data is a list of [inputs, labels]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            inputs, labels = data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            inputs, labels = inputs.to(device), labels.to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # zero the parameter gradients\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            optimizer.zero_grad()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # forward + backward + optimize\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            outputs = net(inputs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            loss = criterion(outputs, labels)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            loss.backward()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            optimizer.step()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # print statistics\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            running_loss += loss.item()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            epoch_steps += 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            if i % 2000 == 1999:  # print every 2000 mini-batches\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                print(\"[%d, %5d] loss: %.3f\" % (epoch + 1, i + 1,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                                                running_loss / epoch_steps))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                running_loss = 0.0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Validation loss\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        val_loss = 0.0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        val_steps = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        total = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        correct = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for i, data in enumerate(valloader, 0):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                inputs, labels = data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                inputs, labels = inputs.to(device), labels.to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                outputs = net(inputs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                _, predicted = torch.max(outputs.data, 1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                total += labels.size(0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                correct += (predicted == labels).sum().item()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                loss = criterion(outputs, labels)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                val_loss += loss.cpu().numpy()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                val_steps += 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Here we save a checkpoint. It is automatically registered with\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Ray Tune and will potentially be passed as the `checkpoint_dir`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # parameter in future iterations.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            path = os.path.join(checkpoint_dir, \"checkpoint\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            torch.save(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                (net.state_dict(), optimizer.state_dict()), path)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    print(\"Finished Training\")"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Test Accuracy"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def _test_accuracy(net, device=\"cpu\"):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    trainset, testset = load_data()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    testloader = torch.utils.data.DataLoader(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        testset, batch_size=4, shuffle=False, num_workers=2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    correct = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    total = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for data in testloader:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            images, labels = data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            images, labels = images.to(device), labels.to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            outputs = net(images)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            _, predicted = torch.max(outputs.data, 1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            total += labels.size(0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            correct += (predicted == labels).sum().item()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return correct / total"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Hyperparameter Optimization"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import numpy as np\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import flaml\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import os\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "data_dir = os.path.abspath(\"data\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "load_data(data_dir)  # Download data for all trials before starting the run"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### Search space"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "max_num_epoch = 100\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "config = {\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"l1\": tune.randint(2, 9),   # log transformed with base 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"l2\": tune.randint(2, 9),   # log transformed with base 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"lr\": tune.loguniform(1e-4, 1e-1),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"num_epochs\": tune.loguniform(1, max_num_epoch),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"batch_size\": tune.randint(1, 5)    # log transformed with base 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "}"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "time_budget_s = 600     # time budget in seconds\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "gpus_per_trial = 0.5    # number of gpus for each trial; 0.5 means two training jobs can share one gpu\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "num_samples = 500       # maximal number of trials\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "np.random.seed(7654321)"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### Launch the tuning"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import time\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "start_time = time.time()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "result = flaml.tune.run(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    tune.with_parameters(train_cifar, data_dir=data_dir),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    config=config,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    metric=\"loss\",\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    mode=\"min\",\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    low_cost_partial_config={\"num_epochs\": 1},\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    max_resource=max_num_epoch,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    min_resource=1,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    report_intermediate_result=True,  # only set to True when intermediate results are reported by tune.report\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    local_dir='logs/',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_samples=num_samples,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    time_budget_s=time_budget_s,\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    use_ray=True)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "#trials=44\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "time=1193.913584947586\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Best trial config: {'l1': 8, 'l2': 8, 'lr': 0.0008818671030627281, 'num_epochs': 55.9513429004283, 'batch_size': 3}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Best trial final validation loss: 1.0694482081472874\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Best trial final validation accuracy: 0.6389\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Files already downloaded and verified\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Files already downloaded and verified\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Best trial test set accuracy: 0.6294\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"#trials={len(result.trials)}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"time={time.time()-start_time}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "best_trial = result.get_best_trial(\"loss\", \"min\", \"all\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Best trial config: {}\".format(best_trial.config))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Best trial final validation loss: {}\".format(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    best_trial.metric_analysis[\"loss\"][\"min\"]))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Best trial final validation accuracy: {}\".format(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    best_trial.metric_analysis[\"accuracy\"][\"max\"]))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "best_trained_model = Net(2**best_trial.config[\"l1\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                         2**best_trial.config[\"l2\"])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "device = \"cpu\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "if torch.cuda.is_available():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    device = \"cuda:0\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    if gpus_per_trial > 1:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        best_trained_model = nn.DataParallel(best_trained_model)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "best_trained_model.to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "checkpoint_path = os.path.join(best_trial.checkpoint.value, \"checkpoint\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model_state, optimizer_state = torch.load(checkpoint_path)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "best_trained_model.load_state_dict(model_state)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_acc = _test_accuracy(best_trained_model, device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Best trial test set accuracy: {}\".format(test_acc))"
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  "interpreter": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "display_name": "Python 3.8.10 64-bit ('venv': venv)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.7.12"
							 
						 
					
						
							
								
									
										
										
										
											2021-09-10 16:39:16 -07:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "interpreter": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 4
							 
						 
					
						
							
								
									
										
										
										
											2021-11-12 22:29:33 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								}