set_search_properties (#595)

* update the signature of set_search_properties
This commit is contained in:
Chi Wang 2022-06-16 16:30:50 -07:00 committed by GitHub
parent 4c044e88bd
commit 1b40b4b3a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 19 deletions

View File

@ -43,11 +43,14 @@ jobs:
pip install -e . pip install -e .
python -c "import flaml" python -c "import flaml"
pip install -e .[test] pip install -e .[test]
- name: If linux or mac, install ray and prophet - name: If linux or mac, install ray
if: matrix.os == 'macOS-latest' || matrix.os == 'ubuntu-latest'
run: |
pip install -e .[ray]
- name: If linux or mac, install prophet on python < 3.9
if: (matrix.os == 'macOS-latest' || matrix.os == 'ubuntu-latest') && matrix.python-version != '3.9' && matrix.python-version != '3.10' if: (matrix.os == 'macOS-latest' || matrix.os == 'ubuntu-latest') && matrix.python-version != '3.9' && matrix.python-version != '3.10'
run: | run: |
pip install -e .[ray,forecast] pip install -e .[forecast]
pip install 'tensorboardX<=2.2'
- name: Lint with flake8 - name: Lint with flake8
run: | run: |
# stop the build if there are Python syntax errors or undefined names # stop the build if there are Python syntax errors or undefined names
@ -55,17 +58,17 @@ jobs:
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest - name: Test with pytest
if: ${{ (matrix.python-version != '3.7' || matrix.os == 'macos-latest') && matrix.python-version != '3.10' }} if: (matrix.python-version != '3.7' || matrix.os == 'macos-latest') && matrix.python-version != '3.10'
run: | run: |
pytest test pytest test
- name: Coverage - name: Coverage
if: ${{ (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10' }} if: (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10'
run: | run: |
pip install coverage pip install coverage
coverage run -a -m pytest test coverage run -a -m pytest test
coverage xml coverage xml
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
if: ${{ (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10'}} if: (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10'
uses: codecov/codecov-action@v1 uses: codecov/codecov-action@v1
with: with:
file: ./coverage.xml file: ./coverage.xml

View File

@ -2992,9 +2992,7 @@ class AutoML(BaseEstimator):
search_state.search_alg.searcher.set_search_properties( search_state.search_alg.searcher.set_search_properties(
metric=None, metric=None,
mode=None, mode=None,
setting={ metric_target=self._state.best_loss,
"metric_target": self._state.best_loss,
},
) )
start_run_time = time.time() start_run_time = time.time()
analysis = tune.run( analysis = tune.run(

View File

@ -235,7 +235,7 @@ class BlendSearch(Searcher):
metric: Optional[str] = None, metric: Optional[str] = None,
mode: Optional[str] = None, mode: Optional[str] = None,
config: Optional[Dict] = None, config: Optional[Dict] = None,
setting: Optional[Dict] = None, **spec,
) -> bool: ) -> bool:
metric_changed = mode_changed = False metric_changed = mode_changed = False
if metric and self._metric != metric: if metric and self._metric != metric:
@ -272,21 +272,21 @@ class BlendSearch(Searcher):
) )
self._gs.space = self._ls.space self._gs.space = self._ls.space
self._init_search() self._init_search()
if setting: if spec:
# CFO doesn't need these settings # CFO doesn't need these settings
if "time_budget_s" in setting: if "time_budget_s" in spec:
self._time_budget_s = setting["time_budget_s"] # budget from now self._time_budget_s = spec["time_budget_s"] # budget from now
now = time.time() now = time.time()
self._time_used += now - self._start_time self._time_used += now - self._start_time
self._start_time = now self._start_time = now
self._set_deadline() self._set_deadline()
if self._input_cost_attr == "auto": if self._input_cost_attr == "auto":
self.cost_attr = TIME_TOTAL_S self.cost_attr = TIME_TOTAL_S
if "metric_target" in setting: if "metric_target" in spec:
self._metric_target = setting.get("metric_target") self._metric_target = spec.get("metric_target")
if "num_samples" in setting: if "num_samples" in spec:
self._num_samples = ( self._num_samples = (
setting["num_samples"] spec["num_samples"]
+ len(self._result) + len(self._result)
+ len(self._trial_proposed_by) + len(self._trial_proposed_by)
) )

View File

@ -425,7 +425,7 @@ def run(
setting["time_budget_s"] = time_budget_s setting["time_budget_s"] = time_budget_s
if num_samples > 0: if num_samples > 0:
setting["num_samples"] = num_samples setting["num_samples"] = num_samples
searcher.set_search_properties(metric, mode, config, setting) searcher.set_search_properties(metric, mode, config, **setting)
else: else:
searcher.set_search_properties(metric, mode, config) searcher.set_search_properties(metric, mode, config)
if scheduler in ("asha", "asynchyperband", "async_hyperband"): if scheduler in ("asha", "asynchyperband", "async_hyperband"):

View File

@ -203,7 +203,9 @@ def test_searcher():
points_to_evaluate=[{"a": 1, "b": 0.01}], points_to_evaluate=[{"a": 1, "b": 0.01}],
) )
searcher.set_search_properties( searcher.set_search_properties(
metric="m2", config=config, setting={"time_budget_s": 0} metric="m2",
config=config,
time_budget_s=0,
) )
c = searcher.suggest("t1") c = searcher.suggest("t1")
print("t1", c) print("t1", c)