434 lines
17 KiB
Python
Raw Normal View History

Add ChaCha (#92) * pickle the AutoML object * get best model per estimator * test deberta * stateless API * pickle the AutoML object * get best model per estimator * test deberta * stateless API * prevent divide by zero * test roberta * BlendSearchTuner * sync * version number * update gitignore * delta time * reindex columns when dropping int-indexed columns * add seed * add seed in Args * merge * init upload of ChaCha * remove redundancy * add back catboost * improve AutoVW API * set min_resource_lease in VWOnlineTrial * docstr * rename * docstr * add docstr * improve API and documentation * fix name * docstr * naming * remove max_resource in scheduler * add TODO in flow2 * remove redundancy in rearcher * add input type * adapt code from ray.tune * move files * naming * documentation * fix import error * fix format issues * remove cb in worse than test * improve _generate_all_comb * remove ray tune * naming * VowpalWabbitTrial * import error * import error * merge test code * scheduler import * fix import * remove * import, minor bug and version * Float or Categorical * fix default * add test_autovw.py * add vowpalwabbit and openml * lint * reorg * lint * indent * add autovw notebook * update notebook * update log msg and autovw notebook * update autovw notebook * update autovw notebook * add available strings for model_select_policy * string for metric * Update vw format in flaml/onlineml/trial.py Co-authored-by: olgavrou <olgavrou@gmail.com> * make init_config optional * add _setup_trial_runner and update notebook * space Co-authored-by: Chi Wang (MSR) <chiw@microsoft.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Qingyun Wu <qiw@microsoft.com> Co-authored-by: olgavrou <olgavrou@gmail.com>
2021-06-02 22:08:24 -04:00
import numpy as np
import logging
import time
import math
import copy
import collections
from typing import Optional
Add ChaCha (#92) * pickle the AutoML object * get best model per estimator * test deberta * stateless API * pickle the AutoML object * get best model per estimator * test deberta * stateless API * prevent divide by zero * test roberta * BlendSearchTuner * sync * version number * update gitignore * delta time * reindex columns when dropping int-indexed columns * add seed * add seed in Args * merge * init upload of ChaCha * remove redundancy * add back catboost * improve AutoVW API * set min_resource_lease in VWOnlineTrial * docstr * rename * docstr * add docstr * improve API and documentation * fix name * docstr * naming * remove max_resource in scheduler * add TODO in flow2 * remove redundancy in rearcher * add input type * adapt code from ray.tune * move files * naming * documentation * fix import error * fix format issues * remove cb in worse than test * improve _generate_all_comb * remove ray tune * naming * VowpalWabbitTrial * import error * import error * merge test code * scheduler import * fix import * remove * import, minor bug and version * Float or Categorical * fix default * add test_autovw.py * add vowpalwabbit and openml * lint * reorg * lint * indent * add autovw notebook * update notebook * update log msg and autovw notebook * update autovw notebook * update autovw notebook * add available strings for model_select_policy * string for metric * Update vw format in flaml/onlineml/trial.py Co-authored-by: olgavrou <olgavrou@gmail.com> * make init_config optional * add _setup_trial_runner and update notebook * space Co-authored-by: Chi Wang (MSR) <chiw@microsoft.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Qingyun Wu <qiw@microsoft.com> Co-authored-by: olgavrou <olgavrou@gmail.com>
2021-06-02 22:08:24 -04:00
from sklearn.metrics import mean_squared_error, mean_absolute_error
from flaml.tune import Trial
Add ChaCha (#92) * pickle the AutoML object * get best model per estimator * test deberta * stateless API * pickle the AutoML object * get best model per estimator * test deberta * stateless API * prevent divide by zero * test roberta * BlendSearchTuner * sync * version number * update gitignore * delta time * reindex columns when dropping int-indexed columns * add seed * add seed in Args * merge * init upload of ChaCha * remove redundancy * add back catboost * improve AutoVW API * set min_resource_lease in VWOnlineTrial * docstr * rename * docstr * add docstr * improve API and documentation * fix name * docstr * naming * remove max_resource in scheduler * add TODO in flow2 * remove redundancy in rearcher * add input type * adapt code from ray.tune * move files * naming * documentation * fix import error * fix format issues * remove cb in worse than test * improve _generate_all_comb * remove ray tune * naming * VowpalWabbitTrial * import error * import error * merge test code * scheduler import * fix import * remove * import, minor bug and version * Float or Categorical * fix default * add test_autovw.py * add vowpalwabbit and openml * lint * reorg * lint * indent * add autovw notebook * update notebook * update log msg and autovw notebook * update autovw notebook * update autovw notebook * add available strings for model_select_policy * string for metric * Update vw format in flaml/onlineml/trial.py Co-authored-by: olgavrou <olgavrou@gmail.com> * make init_config optional * add _setup_trial_runner and update notebook * space Co-authored-by: Chi Wang (MSR) <chiw@microsoft.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Qingyun Wu <qiw@microsoft.com> Co-authored-by: olgavrou <olgavrou@gmail.com>
2021-06-02 22:08:24 -04:00
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def get_ns_feature_dim_from_vw_example(vw_example) -> dict:
    """Get a dictionary of feature dimensionality for each namespace singleton.

    The example is split on '|'. The text before the first '|' (the label
    part) is ignored. In every remaining section, the text between the '|'
    and the first space is used as the namespace name, and every following
    non-empty space-separated token is counted as one feature.

    NOTE:
        Assumption: assume the vw_example takes one of the following formats
        depending on whether the example includes the feature names
            format 1: 'y |ns1 feature1:feature_value1 feature2:feature_value2 |ns2 feature3:feature_value3 feature4:feature_value4'
            format 2: 'y |ns1 feature_value1 feature_value2 |ns2 feature_value3 feature_value4'
        The output of both cases are {'ns1': 2, 'ns2': 2}

        If a space directly follows the '|', the namespace name is the empty
        string and every token of that section is counted as a feature.

    For more information about the input format of vw example, please refer to
    https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format

    Args:
        vw_example (str): one example in VW text format.

    Returns:
        dict: maps each namespace name to its number of features. When the
        same name occurs more than once, the last occurrence wins.
    """
    ns_feature_dim = {}
    # element 0 of the split is the label part, every later one is a namespace
    for section in vw_example.split('|')[1:]:
        tokens = section.split(' ')
        ns = tokens[0]
        # Skip the empty strings that split(' ') yields for trailing or
        # repeated spaces, so that they are never counted as features.
        ns_feature_dim[ns] = sum(1 for token in tokens[1:] if token)
    # getLogger(__name__) returns the same logger object on every call
    logging.getLogger(__name__).debug(
        'name space feature dimension %s', ns_feature_dim)
    return ns_feature_dim
class OnlineResult:
    """Class for managing the result statistics of a trial.

    Attributes:
        observation_count: the total number of observations
        resource_used: the total amount of resource consumed so far

    Methods:
        update_result(new_loss, new_resource_used, data_dimension)
            Update result
        get_score(score_name)
            Get the score according to the input score_name
    """

    prob_delta = 0.1  # appears as n / prob_delta inside the log term of the bound
    LOSS_MIN = 0.0  # floor applied to lower confidence bounds
    LOSS_MAX = np.inf  # cap applied to upper confidence bounds
    CB_COEF = 0.05  # 0.001 for mse

    def __init__(self, result_type_name: str, cb_coef: Optional[float] = None,
                 init_loss: Optional[float] = 0.0, init_cb: Optional[float] = 100.0,
                 mode: Optional[str] = 'min', sliding_window_size: Optional[int] = 100):
        """
        Args:
            result_type_name (str): The name of the result type
            cb_coef (float): coefficient of the confidence bound; CB_COEF is
                used when it is None
            init_loss (float): value reported by loss_avg and loss_avg_recent
                before any loss has been recorded
            init_cb (float): confidence-bound width used before the first update
            mode (str): stored on the instance; not read inside this class
            sliding_window_size (int): number of most recent losses kept for
                loss_avg_recent
        """
        self._result_type_name = result_type_name  # for example 'mse' or 'mae'
        self._mode = mode
        self._init_loss = init_loss
        # statistics needed for alg
        self.observation_count = 0
        self.resource_used = 0.0
        self._loss_avg = 0.0
        self._loss_cb = init_cb  # a large number (TODO: this can be changed)
        self._cb_coef = cb_coef if cb_coef is not None else self.CB_COEF
        # optional statistics
        self._sliding_window_size = sliding_window_size
        self._loss_queue = collections.deque(maxlen=self._sliding_window_size)

    def update_result(self, new_loss, new_resource_used, data_dimension,
                      bound_of_range=1.0, new_observation_count=1.0):
        """Update result statistics.

        Args:
            new_loss: loss contributed by the new observation(s); it is
                divided by the updated observation count when folded into
                the running average
            new_resource_used: resource consumed, added to resource_used
            data_dimension: feature dimensionality used by the confidence bound
            bound_of_range: scale factor applied to the confidence bound
            new_observation_count: number of observations being added
        """
        self.resource_used += new_resource_used
        total_count = self.observation_count + new_observation_count
        if total_count == 0:
            # No observation exists yet, so neither the average nor the
            # confidence bound is defined (both divide by the count). Only
            # the resource is recorded and the initial statistics are kept.
            return
        # keep the running average instead of sum of loss to avoid overflow
        self._loss_avg = (self._loss_avg * (self.observation_count / total_count)
                          + new_loss / total_count)
        self.observation_count = total_count
        self._loss_cb = self._update_loss_cb(bound_of_range, data_dimension)
        self._loss_queue.append(new_loss)

    def _update_loss_cb(self, bound_of_range, data_dim,
                        bound_name='sample_complexity_bound'):
        """Calculate the width of the confidence bound.

        Raises:
            NotImplementedError: for any bound_name other than
                'sample_complexity_bound'
        """
        if bound_name != 'sample_complexity_bound':
            raise NotImplementedError
        # set the coefficient in the loss bound: the configured coefficient is
        # only applied to 'mae'-type results, everything else uses 0.001
        if 'mae' in self.result_type_name:
            coef = self._cb_coef * bound_of_range
        else:
            coef = 0.001 * bound_of_range
        comp_F = math.sqrt(data_dim)
        n = self.observation_count
        return coef * comp_F * math.sqrt((np.log10(n / OnlineResult.prob_delta)) / n)

    @property
    def result_type_name(self):
        return self._result_type_name

    @property
    def loss_avg(self):
        """Running average loss, or init_loss before the first observation."""
        return self._loss_avg if \
            self.observation_count != 0 else self._init_loss

    @property
    def loss_cb(self):
        return self._loss_cb

    @property
    def loss_lcb(self):
        """Lower confidence bound of the loss, floored at LOSS_MIN."""
        return max(self._loss_avg - self._loss_cb, OnlineResult.LOSS_MIN)

    @property
    def loss_ucb(self):
        """Upper confidence bound of the loss, capped at LOSS_MAX."""
        return min(self._loss_avg + self._loss_cb, OnlineResult.LOSS_MAX)

    @property
    def loss_avg_recent(self):
        """Average over the sliding window, or init_loss when it is empty."""
        return sum(self._loss_queue) / len(self._loss_queue) \
            if len(self._loss_queue) != 0 else self._init_loss

    def get_score(self, score_name, cb_ratio=1):
        """Get the score selected by a substring of score_name.

        'lcb', 'ucb' and 'avg' are checked in that order; cb_ratio scales the
        confidence-bound width for the first two.

        Raises:
            NotImplementedError: if score_name contains none of them
        """
        if 'lcb' in score_name:
            return max(self._loss_avg - cb_ratio * self._loss_cb, OnlineResult.LOSS_MIN)
        elif 'ucb' in score_name:
            return min(self._loss_avg + cb_ratio * self._loss_cb, OnlineResult.LOSS_MAX)
        elif 'avg' in score_name:
            return self._loss_avg
        else:
            raise NotImplementedError
class BaseOnlineTrial(Trial):
    """Class for online trial.

    Attributes:
        config: the config for this trial
        trial_id: the trial_id of this trial
        min_resource_lease (float): the minimum resource lease
        status: the status of this trial
        start_time: the start time of this trial
        custom_trial_name: a custom name for this trial

    Methods:
        set_resource_lease(resource)
        set_status(status)
        set_checked_under_current_champion(checked_under_current_champion)
    """

    def __init__(self,
                 config: dict,
                 min_resource_lease: float,
                 is_champion: Optional[bool] = False,
                 is_checked_under_current_champion: Optional[bool] = True,
                 custom_trial_name: Optional[str] = 'mae',
                 trial_id: Optional[str] = None,
                 ):
        """
        Args:
            config: the config dict
            min_resource_lease: the minimum resource lease
            is_champion: a bool variable
            is_checked_under_current_champion: a bool variable
            custom_trial_name: custom trial name
            trial_id: the trial id
        """
        # identity and bookkeeping
        self.trial_id = trial_id
        self.config = config
        self.custom_trial_name = custom_trial_name
        self.status = Trial.PENDING
        self.start_time = time.time()
        # resource budget: the current lease starts out as a copy of the minimum
        self._min_resource_lease = min_resource_lease
        self._resource_lease = copy.copy(self._min_resource_lease)
        # champion bookkeeping; the "checked" flag is expected to be true for
        # a freshly created trial
        self._is_checked_under_current_champion = is_checked_under_current_champion
        self._is_champion = is_champion

    @property
    def is_champion(self):
        return self._is_champion

    @property
    def is_checked_under_current_champion(self):
        return self._is_checked_under_current_champion

    @property
    def resource_lease(self):
        return self._resource_lease

    def set_checked_under_current_champion(self, checked_under_current_champion: bool):
        """Record whether this trial has been paused since the current champion
        was promoted.

        Rationale given by the original authors: when a new champion is
        promoted, a better trial is likely to be among the challengers it
        generates, and those cannot be scheduled until a slot opens. Waiting
        for running trials to reach their next scheduling check point
        (determined by resource used and resource lease) may waste a lot of
        resource on the old trials. So running trials are paused as soon as
        possible to free a slot, and this flag tracks whether that has
        already happened for this trial.
        """
        self._is_checked_under_current_champion = checked_under_current_champion

    def set_resource_lease(self, resource: float):
        self._resource_lease = resource

    def set_status(self, status):
        """Set the status of the trial; when it becomes RUNNING and no start
        time has been stored yet, record the current time as start time.
        """
        self.status = status
        if status == Trial.RUNNING and self.start_time is None:
            self.start_time = time.time()
class VowpalWabbitTrial(BaseOnlineTrial):
    """Implement BaseOnlineTrial for Vowpal Wabbit

    Attributes:
        model: the online model
        result: the anytime result for the online model
        trainable_class: the model class (set as pyvw.vw for VowpalWabbitTrial)
        config: the config for this trial
        trial_id: the trial_id of this trial
        min_resource_lease (float): the minimum resource lease
        status: the status of this trial
        start_time: the start time of this trial
        custom_trial_name: a custom name for this trial

    Methods:
        set_resource_lease(resource)
        set_status(status)
        set_checked_under_current_champion(checked_under_current_champion)

    NOTE:
        About result:
        1. training related results (need to be updated in the trainable class)
        2. result about resources lease (need to be updated externally)

        About namespaces in vw:
        - Wiki in vw:
        https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Namespaces
        - Namespace vs features:
        https://stackoverflow.com/questions/28586225/in-vowpal-wabbit-what-is-the-difference-between-a-namespace-and-feature
    """
    from vowpalwabbit import pyvw
    MODEL_CLASS = pyvw.vw
    cost_unit = 1.0  # resource charged per training sample
    interactions_config_key = 'interactions'
    MIN_RES_CONST = 5  # an 'auto' resource lease resolves to dim * MIN_RES_CONST

    def __init__(self,
                 config: dict,
                 min_resource_lease: float,
                 metric: str = 'mae',
                 is_champion: Optional[bool] = False,
                 is_checked_under_current_champion: Optional[bool] = True,
                 custom_trial_name: Optional[str] = 'vw_mae_clipped',
                 trial_id: Optional[str] = None,
                 cb_coef: Optional[float] = None,
                 ):
        """Constructor

        Args:
            config (dict): the config of the trial. Values may be sets: the
                trial id is built from their sorted members, and the
                'interactions' entry is converted to a list before the config
                is passed to VW
            min_resource_lease (float): the minimum resource lease; 'auto' or
                None is resolved from the feature dimensionality on the first
                training step
            metric (str): the loss metric
            is_champion (bool): indicates whether the trial is the current champion or not
            is_checked_under_current_champion (bool): indicates whether this trial has
                been paused under the current champion
            custom_trial_name (str): a custom name for this trial
            trial_id (str): id of the trial (if None, it will be generated in the constructor)
            cb_coef (float): confidence-bound coefficient handed to OnlineResult
        """
        # attributes
        self.trial_id = self._config_to_id(config) if trial_id is None else trial_id
        logger.info('Create trial with trial_id: %s', self.trial_id)
        super().__init__(config, min_resource_lease, is_champion, is_checked_under_current_champion,
                         custom_trial_name, self.trial_id)
        self.model = None  # model is None until the config is scheduled to run
        self.result = None
        self.trainable_class = self.MODEL_CLASS
        # variables that are needed during online training
        self._metric = metric
        self._y_min_observed = None
        self._y_max_observed = None
        # application dependent variables
        self._dim = None
        self._cb_coef = cb_coef

    @staticmethod
    def _config_to_id(config):
        """Generate an id for the provided config.

        Keys are visited in sorted order. Every value contributes '|' followed
        by its string form; a set contributes its sorted members joined by '_'.
        """
        id_parts = []
        for key in sorted(config):
            v = config[key]
            if isinstance(v, set):
                value_text = '_'.join(str(k) for k in sorted(v))
            else:
                value_text = str(v)
            id_parts.append('|' + value_text)
        return ''.join(id_parts)

    def _initialize_vw_model(self, vw_example):
        """Initialize a vw model using the trainable_class.

        Sets self._dim, self.model and self.result.
        """
        self._vw_config = self.config.copy()
        ns_interactions = self.config.get(VowpalWabbitTrial.interactions_config_key, None)
        # ensure the feature interaction config is a list (required by VW)
        if ns_interactions is not None:
            self._vw_config[VowpalWabbitTrial.interactions_config_key] \
                = list(ns_interactions)
        # get the dimensionality of the feature according to the namespace configuration
        namespace_feature_dim = get_ns_feature_dim_from_vw_example(vw_example)
        self._dim = self._get_dim_from_ns(namespace_feature_dim, ns_interactions)
        # construct an instance of vw model using the input config and fixed config
        self.model = self.trainable_class(**self._vw_config)
        self.result = OnlineResult(self._metric,
                                   cb_coef=self._cb_coef,
                                   init_loss=0.0, init_cb=100.0,)

    def train_eval_model_online(self, data_sample, y_pred):
        """Train and eval model online.

        Args:
            data_sample: one example in VW text format
            y_pred: the prediction made for data_sample, used to compute the loss
        """
        # The model has to be initialized first: _initialize_vw_model is what
        # sets self._dim, which an 'auto' resource lease is derived from.
        if self.model is None:
            # initialize self.model and self.result
            self._initialize_vw_model(data_sample)
        if self._resource_lease == 'auto' or self._resource_lease is None:
            assert self._dim is not None
            self._resource_lease = self._dim * self.MIN_RES_CONST
        y = self._get_y_from_vw_example(data_sample)
        self._update_y_range(y)
        # do one step of learning
        self.model.learn(data_sample)
        # update training related results accordingly
        new_loss = self._get_loss(y, y_pred, self._metric,
                                  self._y_min_observed, self._y_max_observed)
        # update sample size, sum of loss, and cost
        data_sample_size = 1
        bound_of_range = self._y_max_observed - self._y_min_observed
        if bound_of_range == 0:
            # fall back to a unit range while only one distinct target value
            # has been observed
            bound_of_range = 1.0
        self.result.update_result(new_loss,
                                  VowpalWabbitTrial.cost_unit * data_sample_size,
                                  self._dim, bound_of_range)

    def predict(self, x):
        """Predict using the model, initializing it from x on first use.
        """
        if self.model is None:
            # initialize self.model and self.result
            self._initialize_vw_model(x)
        return self.model.predict(x)

    def _get_loss(self, y_true, y_pred, loss_func_name, y_min_observed, y_max_observed):
        """Get instantaneous loss from y_true and y_pred, and loss_func_name.

        For mae_clip, we clip y_pred in the observed range of y.

        Raises:
            NotImplementedError: if loss_func_name names neither a squared
                nor an absolute loss
        """
        if 'mse' in loss_func_name or 'squared' in loss_func_name:
            loss_func = mean_squared_error
        elif 'mae' in loss_func_name or 'absolute' in loss_func_name:
            loss_func = mean_absolute_error
            if y_min_observed is not None and y_max_observed is not None and \
                    'clip' in loss_func_name:
                # clip y_pred in the observed range of y
                y_pred = min(y_max_observed, max(y_pred, y_min_observed))
        else:
            raise NotImplementedError
        return loss_func([y_true], [y_pred])

    def _update_y_range(self, y):
        """Maintain running observed minimum and maximum target value.
        """
        if self._y_min_observed is None or y < self._y_min_observed:
            self._y_min_observed = y
        if self._y_max_observed is None or y > self._y_max_observed:
            self._y_max_observed = y

    @staticmethod
    def _get_dim_from_ns(namespace_feature_dim: dict, namespace_interactions: [set, list]):
        """Get the dimensionality of the corresponding feature of input namespace set.

        The result is the sum of all per-namespace dimensions plus, for every
        interaction, the product of the dimensions of its elements (each
        element of an interaction is looked up as a key of
        namespace_feature_dim).
        """
        total_dim = sum(namespace_feature_dim.values())
        if namespace_interactions:
            for f in namespace_interactions:
                ns_dim = 1.0
                for c in f:
                    ns_dim *= namespace_feature_dim[c]
                total_dim += ns_dim
        return total_dim

    def clean_up_model(self):
        """Drop the model and its result.
        """
        self.model = None
        self.result = None

    @staticmethod
    def _get_y_from_vw_example(vw_example):
        """Get y from a vw_example. this works for regression datasets.

        Everything before the first '|' is parsed as a float.
        """
        return float(vw_example.split('|')[0])