mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-16 19:53:57 +00:00
158 lines
5.5 KiB
Python
158 lines
5.5 KiB
Python
![]() |
'''
|
||
|
Copyright 2020 The Ray Authors.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
|
||
|
This source file is adapted here because ray does not fully support Windows.
|
||
|
|
||
|
Copyright (c) Microsoft Corporation.
|
||
|
'''
|
||
|
from typing import Dict, Optional
|
||
|
|
||
|
from flaml.tune import trial_runner
|
||
|
from flaml.tune.result import DEFAULT_METRIC
|
||
|
from flaml.tune.trial import Trial
|
||
|
|
||
|
|
||
|
class TrialScheduler:
    """Base interface that every trial scheduler implements.

    A scheduler is told about trial lifecycle events (add, intermediate
    result, error, completion, removal) and picks which trial runs next.
    Apart from ``metric`` and ``set_search_properties``, every hook here
    raises ``NotImplementedError`` and must be overridden by a subclass.
    """

    CONTINUE = "CONTINUE"  #: Decision: keep the trial running
    PAUSE = "PAUSE"  #: Decision: pause the trial
    STOP = "STOP"  #: Decision: stop the trial

    # Name of the metric being optimized; ``None`` until configured via
    # ``set_search_properties``.
    _metric = None

    @property
    def metric(self):
        """Name of the metric this scheduler optimizes (may be ``None``)."""
        return self._metric

    def set_search_properties(self, metric: Optional[str],
                              mode: Optional[str]) -> bool:
        """Configure the optimization target after construction.

        This is an alternative to handing ``metric``/``mode`` to a
        scheduler's constructor.

        Args:
            metric (str): Name of the metric to optimize.
            mode (str): One of ["min", "max"]. Direction to optimize; not
                used by this base implementation.

        Returns:
            bool: ``False`` when a metric is already configured and a new
            non-empty one is supplied (the existing metric is kept).
            ``True`` otherwise. If the stored metric is still ``None`` and
            no metric is supplied, ``DEFAULT_METRIC`` is adopted.
        """
        if metric:
            if self._metric:
                # Refuse to silently replace an already-configured metric.
                return False
            self._metric = metric
        elif self._metric is None:
            # Nothing configured and nothing supplied: fall back to the
            # anonymous default metric.
            self._metric = DEFAULT_METRIC
        return True

    def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
                     trial: Trial):
        """Hook invoked when the trial runner registers a new trial."""
        raise NotImplementedError

    def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
                       trial: Trial):
        """Hook invoked when a trial fails.

        Only called for trials that are in the RUNNING state."""
        raise NotImplementedError

    def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial, result: Dict) -> str:
        """Hook invoked for every intermediate result a trial reports.

        The return value is the scheduling decision: one of CONTINUE,
        PAUSE or STOP. Only called for trials in the RUNNING state."""
        raise NotImplementedError

    def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
                          trial: Trial, result: Dict):
        """Hook invoked when a RUNNING trial finishes.

        Covers both natural completion and manual termination."""
        raise NotImplementedError

    def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial):
        """Hook invoked when a PAUSED or PENDING trial is removed.

        Trials in other states go through `on_trial_complete` instead."""
        raise NotImplementedError

    def choose_trial_to_run(
            self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
        """Pick the next trial to run.

        Must return a PENDING or PAUSED trial owned by ``trial_runner``, or
        None when nothing is ready. Implementations must be idempotent."""
        raise NotImplementedError

    def debug_string(self) -> str:
        """Describe the scheduler in a human-readable console message."""
        raise NotImplementedError

    def save(self, checkpoint_path: str):
        """Persist the scheduler's state to ``checkpoint_path``."""
        raise NotImplementedError

    def restore(self, checkpoint_path: str):
        """Reload the scheduler's state from ``checkpoint_path``."""
        raise NotImplementedError
|
||
|
|
||
|
|
||
|
class FIFOScheduler(TrialScheduler):
    """Scheduler that runs trials in submission order and never intervenes.

    Every reported result yields ``CONTINUE`` and all lifecycle
    notifications are ignored. Run order is taken from the order of
    ``trial_runner.get_trials()``.
    """

    def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
                     trial: Trial):
        """No-op: FIFO scheduling keeps no per-trial bookkeeping."""

    def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
                       trial: Trial):
        """No-op: errors do not affect FIFO ordering."""

    def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial, result: Dict) -> str:
        """Always let the trial keep running."""
        return TrialScheduler.CONTINUE

    def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
                          trial: Trial, result: Dict):
        """No-op: completion needs no cleanup under FIFO scheduling."""

    def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial):
        """No-op: removal needs no cleanup under FIFO scheduling."""

    def choose_trial_to_run(
            self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
        """Return the first runnable trial, preferring PENDING over PAUSED.

        A trial is runnable when the runner reports enough resources for
        it. Returns None when no PENDING or PAUSED trial fits.
        """
        # PENDING trials take priority; PAUSED ones are only considered if
        # no PENDING trial can be started. The trial list is re-fetched for
        # each pass on purpose (it may be a one-shot iterable).
        for wanted_status in (Trial.PENDING, Trial.PAUSED):
            for candidate in trial_runner.get_trials():
                if (candidate.status == wanted_status
                        and trial_runner.has_resources_for_trial(candidate)):
                    return candidate
        return None

    def debug_string(self) -> str:
        """Return a one-line description of this scheduler."""
        return "Using FIFO scheduling algorithm."
|