{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved. \n",
"\n",
"Licensed under the MIT License.\n",
"\n",
"# AutoVW: ChaCha for Online AutoML with Vowpal Wabbit\n",
"\n",
"\n",
"## 1. Introduction\n",
"\n",
"\n",
"In this notebook, we use a real-world regression dataset to showcase AutoVW, an online AutoML solution based on the following work:\n",
"\n",
"*ChaCha for online AutoML. Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.*\n",
"\n",
"AutoVW is implemented in FLAML, which requires `Python>=3.7`. To run this notebook, install FLAML with the `notebook` and `vw` options:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install flaml[notebook,vw]==1.1.2"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2. Online regression with AutoVW\n",
"### Load data from OpenML and preprocess\n",
"\n",
"Download [NewFuelCar](https://www.openml.org/d/41506) from OpenML."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(36203, 17) (36203,)\n"
]
}
],
"source": [
"import openml\n",
"# did = 42183\n",
"did = 41506\n",
"ds = openml.datasets.get_dataset(did)\n",
"target_attribute = ds.default_target_attribute\n",
"data = ds.get_data(target=target_attribute, dataset_format='array')\n",
"X, y = data[0], data[1]\n",
"print(X.shape, y.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convert the OpenML dataset into Vowpal Wabbit examples: sequentially group the features into at most 10 namespaces, then convert each original data example into Vowpal Wabbit format."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openml example: 8.170000076293945 [1.0000e+01 7.0000e+00 3.0000e+00 4.0000e+00 nan 6.3300e+00\n",
" 1.3600e-01 7.3300e+00 7.0100e+00 6.9800e+00 3.0000e-03 7.0000e+00\n",
" 9.7000e+00 1.2300e+01 1.0217e+03 0.0000e+00 5.8000e+01]\n",
"vw example: 8.170000076293945 |a 0:10.000000 1:7.000000|b 2:3.000000 3:4.000000|c 4:nan 5:6.330000|d 6:0.136000 7:7.330000|e 8:7.010000 9:6.980000|f 10:0.003000 11:7.000000|g 12:9.700000 13:12.300000|h 14:1021.700012 15:0.000000|i 16:58.000000\n"
]
}
],
"source": [
"import numpy as np\n",
"import string\n",
"NS_LIST = list(string.ascii_lowercase) + list(string.ascii_uppercase)\n",
"max_ns_num = 10 # the maximum number of namespaces\n",
"original_dim = X.shape[1]\n",
"max_size_per_group = int(np.ceil(original_dim / float(max_ns_num)))\n",
"# sequential grouping\n",
"group_indexes = []\n",
"for i in range(max_ns_num):\n",
"    indexes = [ind for ind in range(i * max_size_per_group,\n",
"                                    min((i + 1) * max_size_per_group, original_dim))]\n",
"    if len(indexes) > 0:\n",
"        group_indexes.append(indexes)\n",
"\n",
"vw_examples = []\n",
"for i in range(X.shape[0]):\n",
"    ns_content = []\n",
"    for zz in range(len(group_indexes)):\n",
"        ns_features = ' '.join('{}:{:.6f}'.format(ind, X[i][ind]) for ind in group_indexes[zz])\n",
"        ns_content.append(ns_features)\n",
"    ns_line = '{} |{}'.format(str(y[i]), '|'.join('{} {}'.format(NS_LIST[j], ns_content[j]) for j in range(len(group_indexes))))\n",
"    vw_examples.append(ns_line)\n",
"print('openml example:', y[0], X[0])\n",
"print('vw example:', vw_examples[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Set up the online learning loop\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"def online_learning_loop(iter_num, vw_examples, vw_alg):\n",
"    \"\"\"Run the online learning loop and return the list of per-step squared losses.\n",
"    \"\"\"\n",
"    print('Online learning for', iter_num, 'steps...')\n",
"    loss_list = []\n",
"    for i in range(iter_num):\n",
"        vw_x = vw_examples[i]\n",
"        # the label is the part of the VW example before the first '|'\n",
"        y_true = float(vw_examples[i].split('|')[0])\n",
"        # predict step\n",
"        y_pred = vw_alg.predict(vw_x)\n",
"        # learn step\n",
"        vw_alg.learn(vw_x)\n",
"        # calculate one step loss\n",
"        loss = mean_squared_error([y_true], [y_pred])\n",
"        loss_list.append(loss)\n",
"    return loss_list\n",
"\n",
"max_iter_num = 10000 # or len(vw_examples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vanilla Vowpal Wabbit (VW)\n",
"Create and run a vanilla Vowpal Wabbit learner."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Online learning for 10000 steps...\n",
"Final progressive validation loss of vanilla vw: 15.18087237487917\n"
]
}
],
"source": [
"from vowpalwabbit import pyvw\n",
"# create a vanilla VW instance\n",
"vanilla_vw = pyvw.vw('--quiet')\n",
"\n",
"# online learning with vanilla VW\n",
"loss_list_vanilla = online_learning_loop(max_iter_num, vw_examples, vanilla_vw)\n",
"print('Final progressive validation loss of vanilla vw:', sum(loss_list_vanilla)/len(loss_list_vanilla))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### AutoVW which tunes namespace interactions\n",
"Create and run an AutoVW instance that tunes namespace interactions. Each AutoVW instance keeps up to `max_live_model_num` VW models running concurrently at each step of the online learning loop, each with its own hyperparameter configuration that is tuned online."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Seed namespaces (singletons and interactions): ['g', 'a', 'h', 'b', 'c', 'i', 'd', 'e', 'f']\n",
"Created challengers from champion ||\n",
"New challenger size 37, ['|ah|', '|eg|', '|gi|', '|ag|', '|de|', '|ei|', '|eh|', '|fg|', '|cf|', '|hi|', '|bf|', '|cd|', '|ai|', '|ef|', '|cg|', '|ch|', '|ad|', '|bc|', '|gh|', '|bh|', '|ci|', '|fh|', '|bg|', '|be|', '|bd|', '|fi|', '|bi|', '|df|', '|ac|', '|ae|', '|dg|', '|af|', '|di|', '|ce|', '|dh|', '|ab|', '||']\n",
"Online learning for 10000 steps...\n",
"Seed namespaces (singletons and interactions): ['ce', 'g', 'a', 'h', 'b', 'c', 'i', 'd', 'e', 'f']\n",
"Created challengers from champion |ce|\n",
"New challenger size 43, ['|be_ce|', '|bce_ce|', '|ce_ei|', '|ce_ceg|', '|ce_fh|', '|ce_gh|', '|ce_cef|', '|cd_ce|', '|ce_cg|', '|cde_ce|', '|ce_cf|', '|bd_ce|', '|ae_ce|', '|ce_gi|', '|ce_ci|', '|ab_ce|', '|ce_fg|', '|ce_di|', '|bi_ce|', '|ce_de|', '|ce_eg|', '|ce_dg|', '|ce_hi|', '|ai_ce|', '|ag_ce|', '|ac_ce|', '|bh_ce|', '|ce_ch|', '|ce|', '|ace_ce|', '|ah_ce|', '|af_ce|', '|bc_ce|', '|ce_dh|', '|ce_ef|', '|ad_ce|', '|ce_df|', '|ce_cei|', '|ce_eh|', '|bg_ce|', '|ce_ceh|', '|bf_ce|', '|ce_fi|']\n",
"Final progressive validation loss of autovw: 8.718817421944529\n"
]
}
],
"source": [
"# import the AutoVW class from the flaml package\n",
"from flaml import AutoVW\n",
"\n",
"# create an AutoVW instance for tuning namespace interactions\n",
"# The search_space argument configures both the hyperparameters to tune (e.g., 'interactions')\n",
"# and the fixed arguments of the online learner (e.g., 'quiet').\n",
"autovw_ni = AutoVW(max_live_model_num=5, search_space={'interactions': AutoVW.AUTOMATIC, 'quiet': ''})\n",
"\n",
"# online learning with AutoVW\n",
"loss_list_autovw_ni = online_learning_loop(max_iter_num, vw_examples, autovw_ni)\n",
"print('Final progressive validation loss of autovw:', sum(loss_list_autovw_ni)/len(loss_list_autovw_ni))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Online performance comparison between vanilla VW and AutoVW"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from itertools import accumulate\n",
"\n",
"def plot_progressive_loss(obj_list, alias, result_interval=1):\n",
"    \"\"\"Show real-time progressive validation loss\n",
"    \"\"\"\n",
"    # running mean of the per-step losses: avg_list[k] averages the first k + 1 losses\n",
"    avg_list = [total / n for n, total in enumerate(accumulate(obj_list), start=1)]\n",
"    warm_starting_point = 10  # skip the first few steps of the curve\n",
"    plt.plot(range(warm_starting_point, len(avg_list)), avg_list[warm_starting_point:], label=alias)\n",
"    plt.xlabel('# of data samples')\n",
"    plt.ylabel('Progressive validation loss')\n",
"    plt.yscale('log')\n",
"    plt.legend(loc='upper right')\n",
"plt.figure(figsize=(8, 6))\n",
"plot_progressive_loss(loss_list_vanilla, 'VanillaVW')\n",
"plot_progressive_loss(loss_list_autovw_ni, 'AutoVW:NI')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### AutoVW which tunes both namespace interactions and learning rate\n",
"Create and run an AutoVW instance which tunes both namespace interactions and learning rate."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Seed namespaces (singletons and interactions): ['g', 'a', 'h', 'b', 'c', 'i', 'd', 'e', 'f']\n",
"No low-cost partial config given to the search algorithm. For cost-frugal search, consider providing low-cost values for cost-related hps via 'low_cost_partial_config'.\n",
"Created challengers from champion ||0.5|\n",
"New challenger size 39, ['|gi|0.5|', '|af|0.5|', '|df|0.5|', '|gh|0.5|', '|ae|0.5|', '|di|0.5|', '|be|0.5|', '|ac|0.5|', '|hi|0.5|', '|de|0.5|', '|ef|0.5|', '|bc|0.5|', '|cf|0.5|', '|dg|0.5|', '|fg|0.5|', '|bh|0.5|', '|ei|0.5|', '|ce|0.5|', '|bf|0.5|', '|ah|0.5|', '|ad|0.5|', '|bg|0.5|', '|bd|0.5|', '|ab|0.5|', '|bi|0.5|', '|eg|0.5|', '|ai|0.5|', '|eh|0.5|', '|dh|0.5|', '|cd|0.5|', '|fi|0.5|', '|ci|0.5|', '|ag|0.5|', '|fh|0.5|', '|ch|0.5|', '|cg|0.5|', '||0.05358867312681484|', '||1.0|', '||0.5|']\n",
"Online learning for 10000 steps...\n",
"Seed namespaces (singletons and interactions): ['g', 'a', 'h', 'b', 'c', 'i', 'd', 'e', 'f']\n",
"No low-cost partial config given to the search algorithm. For cost-frugal search, consider providing low-cost values for cost-related hps via 'low_cost_partial_config'.\n",
"Created challengers from champion ||1.0|\n",
"New challenger size 50, ['|gi|0.5|', '|af|0.5|', '|df|0.5|', '|gh|0.5|', '|ae|0.5|', '|di|0.5|', '|be|0.5|', '|ac|0.5|', '|hi|0.5|', '|de|0.5|', '|ef|0.5|', '|bc|0.5|', '|dh|1.0|', '|ah|1.0|', '|cd|1.0|', '|bh|1.0|', '|bi|1.0|', '|ab|1.0|', '|gi|1.0|', '|bg|1.0|', '|bd|1.0|', '|eh|1.0|', '|af|1.0|', '|hi|1.0|', '|cf|1.0|', '|ei|1.0|', '|ef|1.0|', '|ai|1.0|', '|ch|1.0|', '|gh|1.0|', '|fg|1.0|', '|ad|1.0|', '|ci|1.0|', '|bc|1.0|', '|ag|1.0|', '|df|1.0|', '|dg|1.0|', '|de|1.0|', '|di|1.0|', '|cg|1.0|', '|be|1.0|', '|eg|1.0|', '|ce|1.0|', '|fi|1.0|', '|ae|1.0|', '|bf|1.0|', '|fh|1.0|', '|ac|1.0|', '||0.10717734625362937|', '||0.3273795141019504|']\n",
"Final progressive validation loss of autovw_nilr: 7.611900319489723\n"
]
}
],
"source": [
"from flaml.tune import loguniform\n",
"# create another AutoVW instance for tuning namespace interactions and learning rate\n",
"# set up the search space and init config\n",
"search_space_nilr = {'interactions': AutoVW.AUTOMATIC, 'learning_rate': loguniform(lower=2e-10, upper=1.0), 'quiet': ''}\n",
"init_config_nilr = {'interactions': set(), 'learning_rate': 0.5}\n",
"# create an AutoVW instance\n",
"autovw_nilr = AutoVW(max_live_model_num=5, search_space=search_space_nilr, init_config=init_config_nilr)\n",
"\n",
"# online learning with AutoVW\n",
"loss_list_autovw_nilr = online_learning_loop(max_iter_num, vw_examples, autovw_nilr)\n",
"print('Final progressive validation loss of autovw_nilr:', sum(loss_list_autovw_nilr)/len(loss_list_autovw_nilr))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Online performance comparison between vanilla VW and two AutoVW instances\n",
"Compare the online progressive validation loss from the vanilla VW and two AutoVW instances."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8, 6))\n",
"plot_progressive_loss(loss_list_vanilla, 'VanillaVW')\n",
"plot_progressive_loss(loss_list_autovw_ni, 'AutoVW:NI')\n",
"plot_progressive_loss(loss_list_autovw_nilr, 'AutoVW:NI+LR')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### AutoVW based on customized VW arguments\n",
"You can easily create an AutoVW instance based on customized VW arguments (for now, only arguments compatible with the supervised regression task are well supported). The customized arguments are passed to AutoVW through `init_config` and `search_space`."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Seed namespaces (singletons and interactions): ['g', 'a', 'h', 'b', 'c', 'i', 'd', 'e', 'f']\n",
"Created challengers from champion |supervised||classic|\n",
"New challenger size 37, ['|supervised|fg|classic|', '|supervised|dh|classic|', '|supervised|ef|classic|', '|supervised|ei|classic|', '|supervised|di|classic|', '|supervised|ch|classic|', '|supervised|bh|classic|', '|supervised|cf|classic|', '|supervised|ae|classic|', '|supervised|bc|classic|', '|supervised|ci|classic|', '|supervised|eg|classic|', '|supervised|ag|classic|', '|supervised|be|classic|', '|supervised|bd|classic|', '|supervised|ce|classic|', '|supervised|af|classic|', '|supervised|ad|classic|', '|supervised|ab|classic|', '|supervised|dg|classic|', '|supervised|gh|classic|', '|supervised|bg|classic|', '|supervised|fh|classic|', '|supervised|gi|classic|', '|supervised|cg|classic|', '|supervised|cd|classic|', '|supervised|ai|classic|', '|supervised|ac|classic|', '|supervised|bi|classic|', '|supervised|eh|classic|', '|supervised|fi|classic|', '|supervised|de|classic|', '|supervised|hi|classic|', '|supervised|bf|classic|', '|supervised|df|classic|', '|supervised|ah|classic|', '|supervised||classic|']\n",
"Online learning for 10000 steps...\n",
"Seed namespaces (singletons and interactions): ['df', 'g', 'a', 'h', 'b', 'c', 'i', 'd', 'e', 'f']\n",
"Created challengers from champion |supervised|df|classic|\n",
"New challenger size 43, ['|supervised|ce_df|classic|', '|supervised|df_gi|classic|', '|supervised|df_fi|classic|', '|supervised|bd_df|classic|', '|supervised|ab_df|classic|', '|supervised|bi_df|classic|', '|supervised|df_ei|classic|', '|supervised|bh_df|classic|', '|supervised|cd_df|classic|', '|supervised|df_dfg|classic|', '|supervised|def_df|classic|', '|supervised|bdf_df|classic|', '|supervised|ag_df|classic|', '|supervised|cg_df|classic|', '|supervised|df_dg|classic|', '|supervised|af_df|classic|', '|supervised|ci_df|classic|', '|supervised|df_dh|classic|', '|supervised|ah_df|classic|', '|supervised|df|classic|', '|supervised|df_di|classic|', '|supervised|ad_df|classic|', '|supervised|df_ef|classic|', '|supervised|ae_df|classic|', '|supervised|ai_df|classic|', '|supervised|be_df|classic|', '|supervised|df_eg|classic|', '|supervised|ch_df|classic|', '|supervised|ac_df|classic|', '|supervised|df_gh|classic|', '|supervised|df_fg|classic|', '|supervised|bc_df|classic|', '|supervised|df_dfh|classic|', '|supervised|df_fh|classic|', '|supervised|df_dfi|classic|', '|supervised|de_df|classic|', '|supervised|bf_df|classic|', '|supervised|bg_df|classic|', '|supervised|df_hi|classic|', '|supervised|cdf_df|classic|', '|supervised|df_eh|classic|', '|supervised|cf_df|classic|', '|supervised|adf_df|classic|']\n",
"Average final loss of the AutoVW (tuning namespaces) based on customized vw arguments: 8.828759490602918\n"
]
2021-06-02 22:08:24 -04:00
}
],
2021-12-16 17:11:33 -08:00
"source": [
"''' create an AutoVW instance with ustomized VW arguments'''\n",
"# parse the customized VW arguments\n",
"fixed_vw_hp_config = {'alg': 'supervised', 'loss_function': 'classic', 'quiet': ''}\n",
"search_space = fixed_vw_hp_config.copy()\n",
"search_space.update({'interactions': AutoVW.AUTOMATIC,})\n",
"\n",
"autovw_custom = AutoVW(max_live_model_num=5, search_space=search_space) \n",
"loss_list_custom = online_learning_loop(max_iter_num, vw_examples, autovw_custom)\n",
"print('Average final loss of the AutoVW (tuning namespaces) based on customized vw arguments:', sum(loss_list_custom)/len(loss_list_custom))\n"
]
2021-06-02 22:08:24 -04:00
},
{
"cell_type": "code",
"execution_count": null,
2021-12-16 17:11:33 -08:00
"metadata": {},
2021-06-02 22:08:24 -04:00
"outputs": [],
2021-12-16 17:11:33 -08:00
"source": []
2021-06-02 22:08:24 -04:00
}
],
"metadata": {
2021-12-16 17:11:33 -08:00
"interpreter": {
"hash": "4502d015faca2560a557f35a41b6dd402f7fdfc08e843ae17a9c41947939f10c"
},
2021-06-02 22:08:24 -04:00
"kernelspec": {
2021-12-16 17:11:33 -08:00
"display_name": "Python 3.8.10 64-bit ('py38': conda)",
"name": "python3"
2021-06-02 22:08:24 -04:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-07-05 21:17:26 -04:00
"version": "3.8.10"
2021-06-02 22:08:24 -04:00
}
},
"nbformat": 4,
"nbformat_minor": 2
2021-12-16 17:11:33 -08:00
}