2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2021-02-22 22:10:41 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "This notebook uses flaml to finetune a transformer model from Huggingface transformers library.\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-13 10:43:11 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "**Requirements.** This notebook has additional requirements:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 1,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2021-02-13 10:43:11 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2022-06-24 04:45:42 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# %pip install torch transformers datasets ipywidgets flaml[blendsearch,ray]"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Tokenizer"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 1,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from transformers import AutoTokenizer"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 2,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "MODEL_CHECKPOINT = \"distilbert-base-uncased\""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 3,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 4,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "{'input_ids': [101, 2023, 2003, 1037, 3231, 102], 'attention_mask': [1, 1, 1, 1, 1, 1]}"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "execution_count": 4,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "execute_result"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tokenizer(\"this is a test\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 5,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "TASK = \"cola\""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 6,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import datasets"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 7,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stderr",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Reusing dataset glue (/home/ec2-user/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "raw_dataset = datasets.load_dataset(\"glue\", TASK)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# define tokenization function used to process data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "COLUMN_NAME = \"sentence\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def tokenize(examples):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return tokenizer(examples[COLUMN_NAME], truncation=True)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 9,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "application/vnd.jupyter.widget-view+json": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "model_id": "0dcf9ca8ce024a2b832606a6a3219b17",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "version_major": 2,
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "version_minor": 0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "application/vnd.jupyter.widget-view+json": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "model_id": "c58845729f0a4261830ad679891e7c77",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "version_major": 2,
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "version_minor": 0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "application/vnd.jupyter.widget-view+json": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "model_id": "9716d177a40748008cc6089e3d52a1d5",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "version_major": 2,
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "version_minor": 0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "encoded_dataset = raw_dataset.map(tokenize, batched=True)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       " 'idx': 0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       " 'input_ids': [101,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2256,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2814,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2180,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  1005,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  1056,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  4965,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2023,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  4106,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  1010,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2292,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2894,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  1996,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2279,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2028,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  2057,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  16599,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  1012,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  102],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       " 'label': 1,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       " 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "execute_result"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "encoded_dataset[\"train\"][0]"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Model"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 11,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from transformers import AutoModelForSequenceClassification"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 12,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stderr",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "NUM_LABELS = 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=NUM_LABELS)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "DistilBertForSequenceClassification(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  (distilbert): DistilBertModel(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    (embeddings): Embeddings(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      (position_embeddings): Embedding(512, 768)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    (transformer): Transformer(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      (layer): ModuleList(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        (0): TransformerBlock(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (attention): MultiHeadSelfAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (ffn): FFN(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        (1): TransformerBlock(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (attention): MultiHeadSelfAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (ffn): FFN(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        (2): TransformerBlock(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (attention): MultiHeadSelfAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (ffn): FFN(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        (3): TransformerBlock(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (attention): MultiHeadSelfAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (ffn): FFN(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        (4): TransformerBlock(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (attention): MultiHeadSelfAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (ffn): FFN(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        (5): TransformerBlock(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (attention): MultiHeadSelfAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (ffn): FFN(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (dropout): Dropout(p=0.1, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  (dropout): Dropout(p=0.2, inplace=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "execute_result"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Metric"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 14,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "metric = datasets.load_metric(\"glue\", TASK)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 15,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "Metric(name: \"glue\", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: \"\"\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "Compute GLUE evaluation metric associated to each GLUE dataset.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "Args:\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "    predictions: list of predictions to score.\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								       "        Each translation should be tokenized into a list of tokens.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    references: list of lists of references for each translation.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        Each reference should be tokenized into a list of tokens.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "Returns: depending on the GLUE subset, one or several of:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    \"accuracy\": Accuracy\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "    \"f1\": F1 score\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								       "    \"pearson\": Pearson Correlation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    \"spearmanr\": Spearman Correlation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    \"matthews_correlation\": Matthew Correlation\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "Examples:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of [\"mnli\", \"mnli_mismatched\", \"mnli_matched\", \"qnli\", \"rte\", \"wnli\", \"hans\"]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> references = [0, 1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> predictions = [0, 1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> print(results)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    {'accuracy': 1.0}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> glue_metric = datasets.load_metric('glue', 'mrpc')  # 'mrpc' or 'qqp'\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> references = [0, 1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> predictions = [0, 1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> print(results)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    {'accuracy': 1.0, 'f1': 1.0}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> glue_metric = datasets.load_metric('glue', 'stsb')\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> references = [0., 1., 2., 3., 4., 5.]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> predictions = [0., 1., 2., 3., 4., 5.]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> print({\"pearson\": round(results[\"pearson\"], 2), \"spearmanr\": round(results[\"spearmanr\"], 2)})\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    {'pearson': 1.0, 'spearmanr': 1.0}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> glue_metric = datasets.load_metric('glue', 'cola')\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> references = [0, 1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> predictions = [0, 1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    >>> print(results)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    {'matthews_correlation': 1.0}\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								       "\"\"\", stored examples: 0)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "execution_count": 15,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "execute_result"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "metric"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 16,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import numpy as np\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "def compute_metrics(eval_pred):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    predictions, labels = eval_pred\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    predictions = np.argmax(predictions, axis=1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return metric.compute(predictions=predictions, references=labels)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Training (aka Finetuning)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 17,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from transformers import Trainer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from transformers import TrainingArguments"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 18,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "args = TrainingArguments(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    output_dir='output',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    do_eval=True,\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-04-08 09:29:55 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 19,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "trainer = Trainer(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    model=model,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    args=args,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_dataset=encoded_dataset[\"train\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    eval_dataset=encoded_dataset[\"validation\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    tokenizer=tokenizer,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    compute_metrics=compute_metrics,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 20,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/html": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    <div>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        <style>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            /* Turns off some styling */\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            progress {\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "                /* gets rid of default border in Firefox and Opera. */\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "                border: none;\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "                background-size: auto;\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "            }\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        </style>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <progress value='1591' max='3207' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      [1591/3207 1:03:06 < 1:04:11, 0.42 it/s, Epoch 1.49/3]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    </div>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    <table border=\"1\" class=\"dataframe\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  <thead>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    <tr style=\"text-align: left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <th>Step</th>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <th>Training Loss</th>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    </tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  </thead>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  <tbody>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    <tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <td>500</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <td>0.571000</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    </tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    <tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <td>1000</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <td>0.515400</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    </tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    <tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <td>1500</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "      <td>0.356100</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "    </tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "  </tbody>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "</table><p>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       "<IPython.core.display.HTML object>"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "trainer.train()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Hyperparameter Optimization\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "`flaml.tune` is a module for economical hyperparameter tuning. It frees users from manually tuning many hyperparameters for a software, such as machine learning training procedures. \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "The API is compatible with ray tune.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### Step 1. Define training method\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "We define a function `train_distilbert(config: dict)` that accepts a hyperparameter configuration dict `config`. The specific configs will be generated by flaml's search algorithm in a given search space.\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import flaml\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def train_distilbert(config: dict):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Load CoLA dataset and apply tokenizer\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    cola_raw = datasets.load_dataset(\"glue\", TASK)\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    cola_encoded = cola_raw.map(tokenize, batched=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_dataset, eval_dataset = cola_encoded[\"train\"], cola_encoded[\"validation\"]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    model = AutoModelForSequenceClassification.from_pretrained(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        MODEL_CHECKPOINT, num_labels=NUM_LABELS\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    metric = datasets.load_metric(\"glue\", TASK)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def compute_metrics(eval_pred):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        predictions, labels = eval_pred\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        predictions = np.argmax(predictions, axis=1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return metric.compute(predictions=predictions, references=labels)\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    training_args = TrainingArguments(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        output_dir='.',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        do_eval=False,\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        disable_tqdm=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        logging_steps=20000,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        save_total_limit=0,\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        **config,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    trainer = Trainer(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        model,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        training_args,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        train_dataset=train_dataset,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        eval_dataset=eval_dataset,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        tokenizer=tokenizer,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        compute_metrics=compute_metrics,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # train model\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    trainer.train()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # evaluate model\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    eval_output = trainer.evaluate()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # report the metric to optimize\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    flaml.tune.report(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        loss=eval_output[\"eval_loss\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        matthews_correlation=eval_output[\"eval_matthews_correlation\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-04-08 09:29:55 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    )"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### Step 2. Define the search\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "We are now ready to define our search. This includes:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The `search_space` for our hyperparameters\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The metric and the mode ('max' or 'min') for optimization\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The constraints (`num_cpus`, `num_gpus`, `num_samples`, and `time_budget_s`)"
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2021-02-22 22:10:41 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "max_num_epoch = 64\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "search_space = {\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # You can mix constants with search space objects.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \"num_train_epochs\": flaml.tune.loguniform(1, max_num_epoch),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \"learning_rate\": flaml.tune.loguniform(1e-6, 1e-4),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \"adam_epsilon\": flaml.tune.loguniform(1e-9, 1e-7),\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-22 22:10:41 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        \"adam_beta1\": flaml.tune.uniform(0.8, 0.99),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \"adam_beta2\": flaml.tune.loguniform(98e-2, 9999e-4),\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-04-08 09:29:55 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "}"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# optimization objective\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "HP_METRIC, MODE = \"matthews_correlation\", \"max\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# resources\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-22 22:10:41 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "num_cpus = 4\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "num_gpus = 4\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# constraints\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "num_samples = -1    # number of trials, -1 means unlimited\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "time_budget_s = 3600    # time budget in seconds"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### Step 3. Launch with `flaml.tune.run`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-13 10:43:11 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "We are now ready to launch the tuning using `flaml.tune.run`:"
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stderr",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "/home/ec2-user/miniconda3/envs/myflaml/lib/python3.8/site-packages/ray/_private/services.py:238: UserWarning: Not all Ray Dashboard dependencies were found. To use the dashboard please install Ray using `pip install ray[default]`. To disable this message, set RAY_DISABLE_IMPORT_WARNING env var to '1'.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  warnings.warn(warning_message)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "2021-12-01 23:35:54,348\tWARNING function_runner.py:558 -- Function checkpointing is disabled. This may result in unexpected behavior when using checkpointing features or certain schedulers. To enable, set the train function arguments to be `func(config, checkpoint_dir=None)`.\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Tuning started...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/html": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "== Status ==<br>Memory usage on this node: 4.3/7.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 4.0/4 CPUs, 4.0/4 GPUs, 0.0/2.34 GiB heap, 0.0/1.17 GiB objects<br>Result logdir: /home/ec2-user/FLAML/notebook/logs/train_distilbert_2021-12-01_23-35-54<br>Number of trials: 1/infinite (1 RUNNING)<br><br>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<IPython.core.display.HTML object>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/html": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "== Status ==<br>Memory usage on this node: 4.5/7.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 4.0/4 CPUs, 4.0/4 GPUs, 0.0/2.34 GiB heap, 0.0/1.17 GiB objects<br>Result logdir: /home/ec2-user/FLAML/notebook/logs/train_distilbert_2021-12-01_23-35-54<br>Number of trials: 2/infinite (1 PENDING, 1 RUNNING)<br><br>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<IPython.core.display.HTML object>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/html": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "== Status ==<br>Memory usage on this node: 4.6/7.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 4.0/4 CPUs, 4.0/4 GPUs, 0.0/2.34 GiB heap, 0.0/1.17 GiB objects<br>Result logdir: /home/ec2-user/FLAML/notebook/logs/train_distilbert_2021-12-01_23-35-54<br>Number of trials: 2/infinite (1 PENDING, 1 RUNNING)<br><br>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<IPython.core.display.HTML object>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stderr",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m Reusing dataset glue (/home/ec2-user/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "  0%|          | 0/9 [00:00<?, ?ba/s]\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      " 22%|██▏       | 2/9 [00:00<00:00, 19.41ba/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " 56%|█████▌    | 5/9 [00:00<00:00, 20.98ba/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " 89%|████████▉ | 8/9 [00:00<00:00, 21.75ba/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "100%|██████████| 9/9 [00:00<00:00, 24.49ba/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "100%|██████████| 2/2 [00:00<00:00, 42.79ba/s]\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "  0%|          | 0/2 [00:00<?, ?ba/s]\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "100%|██████████| 2/2 [00:00<00:00, 41.48ba/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m \t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m \t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m To disable this warning, you can either:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m \t- Avoid using `tokenizers` before the fork if possible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\u001b[2m\u001b[36m(pid=11344)\u001b[0m \t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import time\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import ray\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "start_time = time.time()\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "ray.shutdown()\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "ray.init(num_cpus=num_cpus, num_gpus=num_gpus)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Tuning started...\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "analysis = flaml.tune.run(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_distilbert,\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    search_alg=flaml.CFO(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        space=search_space,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        metric=HP_METRIC,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        mode=MODE,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        low_cost_partial_config={\"num_train_epochs\": 1}),\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-06 17:03:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    # uncomment the following if scheduler = 'asha',\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    # max_resource=max_num_epoch, min_resource=1,\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    resources_per_trial={\"gpu\": num_gpus, \"cpu\": num_cpus},\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    local_dir='logs/',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_samples=num_samples,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    time_budget_s=time_budget_s,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    use_ray=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "ray.shutdown()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "n_trials=22\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "time=3999.769361972809\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Best model eval matthews_correlation: 0.5699\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Best model parameters: {'num_train_epochs': 15.580684188655825, 'learning_rate': 1.2851507818900338e-05, 'adam_epsilon': 8.134982521948352e-08, 'adam_beta1': 0.99, 'adam_beta2': 0.9971094424784387}\n"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "best_trial = analysis.get_best_trial(HP_METRIC, MODE, \"all\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "metric = best_trial.metric_analysis[HP_METRIC][MODE]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"n_trials={len(analysis.trials)}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"time={time.time()-start_time}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Best model eval {HP_METRIC}: {metric:.4f}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Best model parameters: {best_trial.config}\")\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Next Steps\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Notice that we only reported the metric with `flaml.tune.report` at the end of full training loop. It is possible to enable reporting of intermediate performance - allowing early stopping - as follows:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Huggingface provides _Callbacks_ which can be used to insert the `flaml.tune.report` call inside the training loop\n",
							 
						 
					
						
							
								
									
										
										
										
											2021-02-28 12:43:43 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Make sure to set `do_eval=True` in the `TrainingArguments` provided to `Trainer` and adjust the evaluation frequency accordingly"
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  "interpreter": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "hash": "1cfcceddaeccda27c3cce104660d474924e2ba82887c0e8e481b6ede3743c483"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "display_name": "Python 3.8.5 64-bit",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.8.12"
							 
						 
					
						
							
								
									
										
										
										
											2021-05-08 02:50:50 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "interpreter": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   }
							 
						 
					
						
							
								
									
										
										
										
											2021-02-06 16:24:38 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 4
							 
						 
					
						
							
								
									
										
										
										
											2021-12-04 21:52:20 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								}