add periods

This commit is contained in:
Chi Wang 2021-11-06 12:44:10 -07:00
parent c4d5986ee8
commit 03bc62363f
4 changed files with 9 additions and 10 deletions

View File

@ -205,7 +205,7 @@ def concat(X1, X2):
class DataTransformer: class DataTransformer:
"""Transform input training data""" """Transform input training data."""
def fit_transform(self, X, y, task): def fit_transform(self, X, y, task):
"""Fit transformer and process the input training data according to the task type. """Fit transformer and process the input training data according to the task type.

View File

@ -169,7 +169,7 @@ class AutoVW:
self._trial_runner.step(data_sample, (self._y_predict, self._best_trial)) self._trial_runner.step(data_sample, (self._y_predict, self._best_trial))
def _select_best_trial(self): def _select_best_trial(self):
"""Select a best trial from the running trials accoring to the _model_select_policy""" """Select a best trial from the running trials accoring to the _model_select_policy."""
best_score = ( best_score = (
float("+inf") if self._model_selection_mode == "min" else float("-inf") float("+inf") if self._model_selection_mode == "min" else float("-inf")
) )

View File

@ -4,7 +4,7 @@ import time
import math import math
import copy import copy
import collections import collections
from typing import Optional from typing import Optional, Union
from sklearn.metrics import mean_squared_error, mean_absolute_error from sklearn.metrics import mean_squared_error, mean_absolute_error
from flaml.tune import Trial from flaml.tune import Trial
@ -113,7 +113,7 @@ class OnlineResult:
def _update_loss_cb( def _update_loss_cb(
self, bound_of_range, data_dim, bound_name="sample_complexity_bound" self, bound_of_range, data_dim, bound_name="sample_complexity_bound"
): ):
"""Calculate bound coef""" """Calculate bound coef."""
if bound_name == "sample_complexity_bound": if bound_name == "sample_complexity_bound":
# set the coefficient in the loss bound # set the coefficient in the loss bound
if "mae" in self.result_type_name: if "mae" in self.result_type_name:
@ -313,7 +313,6 @@ class VowpalWabbitTrial(BaseOnlineTrial):
is_checked_under_current_champion (bool): indicates whether this trials has is_checked_under_current_champion (bool): indicates whether this trials has
been paused under the current champion. been paused under the current champion.
trial_id (str): id of the trial (if None, it will be generated in the constructor). trial_id (str): id of the trial (if None, it will be generated in the constructor).
""" """
try: try:
from vowpalwabbit import pyvw from vowpalwabbit import pyvw
@ -345,7 +344,7 @@ class VowpalWabbitTrial(BaseOnlineTrial):
@staticmethod @staticmethod
def _config_to_id(config): def _config_to_id(config):
"""Generate an id for the provided config""" """Generate an id for the provided config."""
# sort config keys # sort config keys
sorted_k_list = sorted(list(config.keys())) sorted_k_list = sorted(list(config.keys()))
config_id_full = "" config_id_full = ""
@ -439,7 +438,7 @@ class VowpalWabbitTrial(BaseOnlineTrial):
return loss_func([y_true], [y_pred]) return loss_func([y_true], [y_pred])
def _update_y_range(self, y): def _update_y_range(self, y):
"""Maintain running observed minimum and maximum target value""" """Maintain running observed minimum and maximum target value."""
if self._y_min_observed is None or y < self._y_min_observed: if self._y_min_observed is None or y < self._y_min_observed:
self._y_min_observed = y self._y_min_observed = y
if self._y_max_observed is None or y > self._y_max_observed: if self._y_max_observed is None or y > self._y_max_observed:
@ -447,9 +446,9 @@ class VowpalWabbitTrial(BaseOnlineTrial):
@staticmethod @staticmethod
def _get_dim_from_ns( def _get_dim_from_ns(
namespace_feature_dim: dict, namespace_interactions: [set, list] namespace_feature_dim: dict, namespace_interactions: Union[set, list]
): ):
"""Get the dimensionality of the corresponding feature of input namespace set""" """Get the dimensionality of the corresponding feature of input namespace set."""
total_dim = sum(namespace_feature_dim.values()) total_dim = sum(namespace_feature_dim.values())
if namespace_interactions: if namespace_interactions:
for f in namespace_interactions: for f in namespace_interactions:

View File

@ -89,7 +89,7 @@ class OnlineSuccessiveDoublingScheduler(OnlineScheduler):
class ChaChaScheduler(OnlineSuccessiveDoublingScheduler): class ChaChaScheduler(OnlineSuccessiveDoublingScheduler):
"""Keep the top performing learners running """Keep the top performing learners running.
Methods: Methods:
* on_trial_result(trial_runner, trial, result): * on_trial_result(trial_runner, trial, result):