From 1b40b4b3a6602db1df2d315d4b9e12fa0c69bb65 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Thu, 16 Jun 2022 16:30:50 -0700 Subject: [PATCH] set_search_properties (#595) * update the signature of set_search_properties --- .github/workflows/python-package.yml | 15 +++++++++------ flaml/automl.py | 4 +--- flaml/searcher/blendsearch.py | 16 ++++++++-------- flaml/tune/tune.py | 2 +- test/tune/test_searcher.py | 4 +++- 5 files changed, 22 insertions(+), 19 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 32f82d5c0..19f3a3f40 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -43,11 +43,14 @@ jobs: pip install -e . python -c "import flaml" pip install -e .[test] - - name: If linux or mac, install ray and prophet + - name: If linux or mac, install ray + if: matrix.os == 'macOS-latest' || matrix.os == 'ubuntu-latest' + run: | + pip install -e .[ray] + - name: If linux or mac, install prophet on python < 3.9 if: (matrix.os == 'macOS-latest' || matrix.os == 'ubuntu-latest') && matrix.python-version != '3.9' && matrix.python-version != '3.10' run: | - pip install -e .[ray,forecast] - pip install 'tensorboardX<=2.2' + pip install -e .[forecast] - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -55,17 +58,17 @@ jobs: # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest - if: ${{ (matrix.python-version != '3.7' || matrix.os == 'macos-latest') && matrix.python-version != '3.10' }} + if: (matrix.python-version != '3.7' || matrix.os == 'macos-latest') && matrix.python-version != '3.10' run: | pytest test - name: Coverage - if: ${{ (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10' }} + if: (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10' run: | pip install coverage coverage run -a -m pytest test coverage xml - name: Upload coverage to Codecov - if: ${{ (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10'}} + if: (matrix.python-version == '3.7') && matrix.os != 'macos-latest' || matrix.python-version == '3.10' uses: codecov/codecov-action@v1 with: file: ./coverage.xml diff --git a/flaml/automl.py b/flaml/automl.py index eecb59456..65b1d29bb 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -2992,9 +2992,7 @@ class AutoML(BaseEstimator): search_state.search_alg.searcher.set_search_properties( metric=None, mode=None, - setting={ - "metric_target": self._state.best_loss, - }, + metric_target=self._state.best_loss, ) start_run_time = time.time() analysis = tune.run( diff --git a/flaml/searcher/blendsearch.py b/flaml/searcher/blendsearch.py index 0da3d5f58..ae681b62f 100644 --- a/flaml/searcher/blendsearch.py +++ b/flaml/searcher/blendsearch.py @@ -235,7 +235,7 @@ class BlendSearch(Searcher): metric: Optional[str] = None, mode: Optional[str] = None, config: Optional[Dict] = None, - setting: Optional[Dict] = None, + **spec, ) -> bool: metric_changed = mode_changed = False if metric and self._metric != metric: @@ -272,21 +272,21 @@ class BlendSearch(Searcher): ) self._gs.space = self._ls.space self._init_search() - if setting: + if spec: # CFO doesn't need these settings - if "time_budget_s" in setting: - self._time_budget_s = setting["time_budget_s"] # budget from now + if "time_budget_s" in spec: + self._time_budget_s = spec["time_budget_s"] # budget from now now = time.time() self._time_used += now - self._start_time self._start_time = now self._set_deadline() if self._input_cost_attr == "auto": self.cost_attr = TIME_TOTAL_S - if "metric_target" in setting: - self._metric_target = setting.get("metric_target") - if "num_samples" in setting: + if "metric_target" in spec: + self._metric_target = spec.get("metric_target") + if "num_samples" in spec: self._num_samples = ( - setting["num_samples"] + spec["num_samples"] + len(self._result) + len(self._trial_proposed_by) ) diff --git a/flaml/tune/tune.py b/flaml/tune/tune.py index 89a4fe380..264a5eba6 100644 --- a/flaml/tune/tune.py +++ b/flaml/tune/tune.py @@ -425,7 +425,7 @@ def run( setting["time_budget_s"] = time_budget_s if num_samples > 0: setting["num_samples"] = num_samples - searcher.set_search_properties(metric, mode, config, setting) + searcher.set_search_properties(metric, mode, config, **setting) else: searcher.set_search_properties(metric, mode, config) if scheduler in ("asha", "asynchyperband", "async_hyperband"): diff --git a/test/tune/test_searcher.py b/test/tune/test_searcher.py index fd0dd7b3d..d3002c3b2 100644 --- a/test/tune/test_searcher.py +++ b/test/tune/test_searcher.py @@ -203,7 +203,9 @@ def test_searcher(): points_to_evaluate=[{"a": 1, "b": 0.01}], ) searcher.set_search_properties( - metric="m2", config=config, setting={"time_budget_s": 0} + metric="m2", + config=config, + time_budget_s=0, ) c = searcher.suggest("t1") print("t1", c)