diff --git a/flaml/tune/tune.py b/flaml/tune/tune.py index 7b5b2a62f..07e29d566 100644 --- a/flaml/tune/tune.py +++ b/flaml/tune/tune.py @@ -492,9 +492,7 @@ def run( SearchAlgorithm = BlendSearch logger.info( - "Using search algorithm {}.".format( - SearchAlgorithm.__class__.__name__ - ) + "Using search algorithm {}.".format(SearchAlgorithm.__name__) ) except ImportError: SearchAlgorithm = CFO @@ -504,9 +502,7 @@ def run( metric = metric or DEFAULT_METRIC else: SearchAlgorithm = CFO - logger.info( - "Using search algorithm {}.".format(SearchAlgorithm.__class__.__name__) - ) + logger.info("Using search algorithm {}.".format(SearchAlgorithm.__name__)) metric = lexico_objectives["metrics"][0] or DEFAULT_METRIC search_alg = SearchAlgorithm( metric=metric, @@ -675,14 +671,14 @@ def run( num_trials = 0 if time_budget_s is None: time_budget_s = np.inf - fail = 0 - ub = ( + num_failures = 0 + upperbound_num_failures = ( len(evaluated_rewards) if evaluated_rewards else 0 ) + max_failure while ( time.time() - time_start < time_budget_s and (num_samples < 0 or num_trials < num_samples) - and fail < ub + and num_failures < upperbound_num_failures ): while len(_runner.running_trials) < n_concurrent_trials: # suggest trials for spark @@ -690,9 +686,9 @@ def run( if trial_next: num_trials += 1 else: - fail += 1 # break with ub consecutive failures - logger.debug(f"consecutive failures is {fail}") - if fail >= ub: + num_failures += 1 # break with upperbound_num_failures consecutive failures + logger.debug(f"consecutive failures is {num_failures}") + if num_failures >= upperbound_num_failures: break trials_to_run = _runner.running_trials if not trials_to_run: @@ -730,7 +726,7 @@ def run( ) report(_metric=result) _runner.stop_trial(trial_to_run) - fail = 0 + num_failures = 0 analysis = ExperimentAnalysis( _runner.get_trials(), metric=metric, @@ -766,12 +762,14 @@ def run( num_trials = 0 if time_budget_s is None: time_budget_s = np.inf - fail = 0 - ub = (len(evaluated_rewards) if evaluated_rewards else 0) + max_failure + num_failures = 0 + upperbound_num_failures = ( + len(evaluated_rewards) if evaluated_rewards else 0 + ) + max_failure while ( time.time() - time_start < time_budget_s and (num_samples < 0 or num_trials < num_samples) - and fail < ub + and num_failures < upperbound_num_failures ): trial_to_run = _runner.step() if trial_to_run: @@ -789,10 +787,11 @@ def run( else: report(_metric=result) _runner.stop_trial(trial_to_run) - fail = 0 + num_failures = 0 else: - fail += 1 # break with ub consecutive failures - if fail == ub: + # break with upperbound_num_failures consecutive failures + num_failures += 1 + if num_failures == upperbound_num_failures: logger.warning( f"fail to sample a trial for {max_failure} times in a row, stopping." )