| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | """Require: pip install flaml[test,ray]
 | 
					
						
							|  |  |  | """
 | 
					
						
							| 
									
										
										
										
											2021-08-02 19:10:26 -04:00
										 |  |  | from flaml.searcher.blendsearch import BlendSearch | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | from sklearn.model_selection import train_test_split | 
					
						
							|  |  |  | import sklearn.metrics | 
					
						
							|  |  |  | import sklearn.datasets | 
					
						
							| 
									
										
										
										
											2022-01-30 13:02:18 -08:00
										 |  |  | import xgboost as xgb | 
					
						
							|  |  |  | import logging | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | try: | 
					
						
							|  |  |  |     from ray.tune.integration.xgboost import TuneReportCheckpointCallback | 
					
						
							|  |  |  | except ImportError: | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |     print("skip test_xgboost because ray tune cannot be imported.") | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | logger = logging.getLogger(__name__) | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | os.makedirs("logs", exist_ok=True) | 
					
						
							|  |  |  | logger.addHandler(logging.FileHandler("logs/tune.log")) | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  | logger.setLevel(logging.INFO) | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def train_breast_cancer(config: dict): | 
					
						
							|  |  |  |     # This is a simple training function to be passed into Tune | 
					
						
							|  |  |  |     # Load dataset | 
					
						
							|  |  |  |     data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True) | 
					
						
							|  |  |  |     # Split into train and test set | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.25) | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |     # Build input matrices for XGBoost | 
					
						
							|  |  |  |     train_set = xgb.DMatrix(train_x, label=train_y) | 
					
						
							|  |  |  |     test_set = xgb.DMatrix(test_x, label=test_y) | 
					
						
							|  |  |  |     # HyperOpt returns a tuple | 
					
						
							|  |  |  |     config = config.copy() | 
					
						
							|  |  |  |     config["eval_metric"] = ["logloss", "error"] | 
					
						
							|  |  |  |     config["objective"] = "binary:logistic" | 
					
						
							|  |  |  |     # Train the classifier, using the Tune callback | 
					
						
							|  |  |  |     xgb.train( | 
					
						
							|  |  |  |         config, | 
					
						
							|  |  |  |         train_set, | 
					
						
							|  |  |  |         evals=[(test_set, "eval")], | 
					
						
							|  |  |  |         verbose_eval=False, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         callbacks=[TuneReportCheckpointCallback(filename="model.xgb")], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | def _test_xgboost(method="BlendSearch"): | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |     try: | 
					
						
							|  |  |  |         import ray | 
					
						
							|  |  |  |     except ImportError: | 
					
						
							|  |  |  |         return | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     if method == "BlendSearch": | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |         from flaml import tune | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         from ray import tune | 
					
						
							|  |  |  |     search_space = { | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         "max_depth": tune.randint(1, 9) | 
					
						
							|  |  |  |         if method in ["BlendSearch", "BOHB", "Optuna"] | 
					
						
							|  |  |  |         else tune.randint(1, 9), | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |         "min_child_weight": tune.choice([1, 2, 3]), | 
					
						
							|  |  |  |         "subsample": tune.uniform(0.5, 1.0), | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         "eta": tune.loguniform(1e-4, 1e-1), | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |     } | 
					
						
							|  |  |  |     max_iter = 10 | 
					
						
							| 
									
										
										
										
											2021-03-05 23:39:14 -08:00
										 |  |  |     for num_samples in [128]: | 
					
						
							| 
									
										
										
										
											2021-04-08 09:29:55 -07:00
										 |  |  |         time_budget_s = 60 | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  |         for n_cpu in [2]: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |             start_time = time.time() | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |             # ray.init(address='auto') | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             if method == "BlendSearch": | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                 analysis = tune.run( | 
					
						
							|  |  |  |                     train_breast_cancer, | 
					
						
							| 
									
										
										
										
											2021-04-06 11:37:52 -07:00
										 |  |  |                     config=search_space, | 
					
						
							|  |  |  |                     low_cost_partial_config={ | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                         "max_depth": 1, | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                     cat_hp_cost={ | 
					
						
							|  |  |  |                         "min_child_weight": [6, 3, 2], | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                     metric="eval-logloss", | 
					
						
							|  |  |  |                     mode="min", | 
					
						
							|  |  |  |                     max_resource=max_iter, | 
					
						
							|  |  |  |                     min_resource=1, | 
					
						
							| 
									
										
										
										
											2021-12-04 21:52:20 -05:00
										 |  |  |                     scheduler="asha", | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     # You can add "gpu": 0.1 to allocate GPUs | 
					
						
							|  |  |  |                     resources_per_trial={"cpu": 1}, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                     local_dir="logs/", | 
					
						
							| 
									
										
										
										
											2021-04-08 09:29:55 -07:00
										 |  |  |                     num_samples=num_samples * n_cpu, | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     time_budget_s=time_budget_s, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                     use_ray=True, | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 if "ASHA" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = None | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "BOHB" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.schedulers import HyperBandForBOHB | 
					
						
							|  |  |  |                     from ray.tune.suggest.bohb import TuneBOHB | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = TuneBOHB(max_concurrent=n_cpu) | 
					
						
							|  |  |  |                     scheduler = HyperBandForBOHB(max_t=max_iter) | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "Optuna" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.optuna import OptunaSearch | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = OptunaSearch() | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "CFO" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from flaml import CFO | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     algo = CFO( | 
					
						
							|  |  |  |                         low_cost_partial_config={ | 
					
						
							|  |  |  |                             "max_depth": 1, | 
					
						
							|  |  |  |                         }, | 
					
						
							|  |  |  |                         cat_hp_cost={ | 
					
						
							|  |  |  |                             "min_child_weight": [6, 3, 2], | 
					
						
							|  |  |  |                         }, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 elif "CFOCat" == method: | 
					
						
							| 
									
										
										
										
											2021-07-20 17:00:44 -07:00
										 |  |  |                     from flaml.searcher.cfo_cat import CFOCat | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     algo = CFOCat( | 
					
						
							|  |  |  |                         low_cost_partial_config={ | 
					
						
							|  |  |  |                             "max_depth": 1, | 
					
						
							|  |  |  |                         }, | 
					
						
							|  |  |  |                         cat_hp_cost={ | 
					
						
							|  |  |  |                             "min_child_weight": [6, 3, 2], | 
					
						
							|  |  |  |                         }, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 elif "Dragonfly" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.dragonfly import DragonflySearch | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = DragonflySearch() | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "SkOpt" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.skopt import SkOptSearch | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = SkOptSearch() | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "Nevergrad" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.nevergrad import NevergradSearch | 
					
						
							|  |  |  |                     import nevergrad as ng | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = NevergradSearch(optimizer=ng.optimizers.OnePlusOne) | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "ZOOpt" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.zoopt import ZOOptSearch | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 09:29:55 -07:00
										 |  |  |                     algo = ZOOptSearch(budget=num_samples * n_cpu) | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "Ax" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.ax import AxSearch | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = AxSearch() | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 elif "HyperOpt" == method: | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.suggest.hyperopt import HyperOptSearch | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     algo = HyperOptSearch() | 
					
						
							|  |  |  |                     scheduler = None | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 if method != "BOHB": | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                     from ray.tune.schedulers import ASHAScheduler | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     scheduler = ASHAScheduler(max_t=max_iter, grace_period=1) | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |                 analysis = tune.run( | 
					
						
							|  |  |  |                     train_breast_cancer, | 
					
						
							|  |  |  |                     metric="eval-logloss", | 
					
						
							|  |  |  |                     mode="min", | 
					
						
							|  |  |  |                     # You can add "gpu": 0.1 to allocate GPUs | 
					
						
							|  |  |  |                     resources_per_trial={"cpu": 1}, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                     config=search_space, | 
					
						
							|  |  |  |                     local_dir="logs/", | 
					
						
							| 
									
										
										
										
											2021-04-08 09:29:55 -07:00
										 |  |  |                     num_samples=num_samples * n_cpu, | 
					
						
							|  |  |  |                     time_budget_s=time_budget_s, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                     scheduler=scheduler, | 
					
						
							|  |  |  |                     search_alg=algo, | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |             # # Load the best model checkpoint | 
					
						
							| 
									
										
										
										
											2021-02-13 10:43:11 -08:00
										 |  |  |             # import os | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |             # best_bst = xgb.Booster() | 
					
						
							|  |  |  |             # best_bst.load_model(os.path.join(analysis.best_checkpoint, | 
					
						
							|  |  |  |             #  "model.xgb")) | 
					
						
							| 
									
										
										
										
											2021-04-08 09:29:55 -07:00
										 |  |  |             best_trial = analysis.get_best_trial("eval-logloss", "min", "all") | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             accuracy = 1.0 - best_trial.metric_analysis["eval-error"]["min"] | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  |             logloss = best_trial.metric_analysis["eval-logloss"]["min"] | 
					
						
							|  |  |  |             logger.info(f"method={method}") | 
					
						
							|  |  |  |             logger.info(f"n_samples={num_samples*n_cpu}") | 
					
						
							|  |  |  |             logger.info(f"time={time.time()-start_time}") | 
					
						
							|  |  |  |             logger.info(f"Best model eval loss: {logloss:.4f}") | 
					
						
							|  |  |  |             logger.info(f"Best model total accuracy: {accuracy:.4f}") | 
					
						
							|  |  |  |             logger.info(f"Best model parameters: {best_trial.config}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  | def test_nested(): | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  |     from flaml import tune, CFO | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  |     search_space = { | 
					
						
							|  |  |  |         # test nested search space | 
					
						
							|  |  |  |         "cost_related": { | 
					
						
							| 
									
										
										
										
											2021-08-12 02:02:22 -04:00
										 |  |  |             "a": tune.randint(1, 9), | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  |         }, | 
					
						
							|  |  |  |         "b": tune.uniform(0.5, 1.0), | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def simple_func(config): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         obj = (config["cost_related"]["a"] - 4) ** 2 + ( | 
					
						
							|  |  |  |             config["b"] - config["cost_related"]["a"] | 
					
						
							|  |  |  |         ) ** 2 | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |         tune.report(obj=obj) | 
					
						
							|  |  |  |         tune.report(obj=obj, ab=config["cost_related"]["a"] * config["b"]) | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  |     analysis = tune.run( | 
					
						
							|  |  |  |         simple_func, | 
					
						
							|  |  |  |         search_alg=CFO( | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             space=search_space, | 
					
						
							|  |  |  |             metric="obj", | 
					
						
							|  |  |  |             mode="min", | 
					
						
							|  |  |  |             low_cost_partial_config={"cost_related": {"a": 1}}, | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  |             points_to_evaluate=[ | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |                 {"b": 0.99, "cost_related": {"a": 3}}, | 
					
						
							|  |  |  |                 {"b": 0.99, "cost_related": {"a": 2}}, | 
					
						
							|  |  |  |                 {"cost_related": {"a": 8}}, | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  |             ], | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             metric_constraints=[("ab", "<=", 4)], | 
					
						
							|  |  |  |         ), | 
					
						
							|  |  |  |         local_dir="logs/", | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  |         num_samples=-1, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         time_budget_s=1, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     best_trial = analysis.get_best_trial() | 
					
						
							|  |  |  |     logger.info(f"CFO best config: {best_trial.config}") | 
					
						
							|  |  |  |     logger.info(f"CFO best result: {best_trial.last_result}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-23 16:22:34 -07:00
										 |  |  |     bs = BlendSearch( | 
					
						
							|  |  |  |         experimental=True, | 
					
						
							|  |  |  |         space=search_space, | 
					
						
							|  |  |  |         metric="obj", | 
					
						
							|  |  |  |         mode="min", | 
					
						
							|  |  |  |         low_cost_partial_config={"cost_related": {"a": 1}}, | 
					
						
							|  |  |  |         points_to_evaluate=[ | 
					
						
							|  |  |  |             {"b": 0.99, "cost_related": {"a": 3}}, | 
					
						
							|  |  |  |             {"b": 0.99, "cost_related": {"a": 2}}, | 
					
						
							|  |  |  |             {"cost_related": {"a": 8}}, | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |         metric_constraints=[("ab", "<=", 4)], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-08-02 19:10:26 -04:00
										 |  |  |     analysis = tune.run( | 
					
						
							|  |  |  |         simple_func, | 
					
						
							| 
									
										
										
										
											2022-04-23 16:22:34 -07:00
										 |  |  |         search_alg=bs, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         local_dir="logs/", | 
					
						
							| 
									
										
										
										
											2021-08-02 19:10:26 -04:00
										 |  |  |         num_samples=-1, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         time_budget_s=1, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-23 16:22:34 -07:00
										 |  |  |     print(bs.results) | 
					
						
							| 
									
										
										
										
											2021-08-02 19:10:26 -04:00
										 |  |  |     best_trial = analysis.get_best_trial() | 
					
						
							|  |  |  |     logger.info(f"BlendSearch exp best config: {best_trial.config}") | 
					
						
							|  |  |  |     logger.info(f"BlendSearch exp best result: {best_trial.last_result}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-04 01:42:21 -07:00
										 |  |  |     points_to_evaluate = [ | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         {"b": 0.99, "cost_related": {"a": 3}}, | 
					
						
							|  |  |  |         {"b": 0.99, "cost_related": {"a": 2}}, | 
					
						
							| 
									
										
										
										
											2022-04-23 16:22:34 -07:00
										 |  |  |         {"cost_related": {"a": 8}}, | 
					
						
							| 
									
										
										
										
											2021-09-04 01:42:21 -07:00
										 |  |  |     ] | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |     analysis = tune.run( | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  |         simple_func, | 
					
						
							| 
									
										
										
										
											2021-04-06 11:37:52 -07:00
										 |  |  |         config=search_space, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         low_cost_partial_config={"cost_related": {"a": 1}}, | 
					
						
							| 
									
										
										
										
											2021-09-04 01:42:21 -07:00
										 |  |  |         points_to_evaluate=points_to_evaluate, | 
					
						
							|  |  |  |         evaluated_rewards=[ | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             (config["cost_related"]["a"] - 4) ** 2 | 
					
						
							|  |  |  |             + (config["b"] - config["cost_related"]["a"]) ** 2 | 
					
						
							| 
									
										
										
										
											2022-04-23 16:22:34 -07:00
										 |  |  |             for config in points_to_evaluate[:-1] | 
					
						
							| 
									
										
										
										
											2021-09-04 01:42:21 -07:00
										 |  |  |         ], | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |         metric="obj", | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  |         mode="min", | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |         metric_constraints=[("ab", "<=", 4)], | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         local_dir="logs/", | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  |         num_samples=-1, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         time_budget_s=1, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |     best_trial = analysis.get_best_trial() | 
					
						
							| 
									
										
										
										
											2021-07-31 16:39:31 -04:00
										 |  |  |     logger.info(f"BlendSearch best config: {best_trial.config}") | 
					
						
							|  |  |  |     logger.info(f"BlendSearch best result: {best_trial.last_result}") | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-28 12:43:43 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  | def test_run_training_function_return_value(): | 
					
						
							|  |  |  |     from flaml import tune | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Test dict return value | 
					
						
							|  |  |  |     def evaluate_config_dict(config): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         metric = (round(config["x"]) - 85000) ** 2 - config["x"] / config["y"] | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  |         return {"metric": metric} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tune.run( | 
					
						
							|  |  |  |         evaluate_config_dict, | 
					
						
							|  |  |  |         config={ | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             "x": tune.qloguniform(lower=1, upper=100000, q=1), | 
					
						
							|  |  |  |             "y": tune.qrandint(lower=2, upper=100000, q=2), | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  |         }, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         metric="metric", | 
					
						
							|  |  |  |         mode="max", | 
					
						
							| 
									
										
										
										
											2021-08-12 02:02:22 -04:00
										 |  |  |         num_samples=100, | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Test scalar return value | 
					
						
							|  |  |  |     def evaluate_config_scalar(config): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         metric = (round(config["x"]) - 85000) ** 2 - config["x"] / config["y"] | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  |         return metric | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tune.run( | 
					
						
							|  |  |  |         evaluate_config_scalar, | 
					
						
							|  |  |  |         config={ | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |             "x": tune.qloguniform(lower=1, upper=100000, q=1), | 
					
						
							|  |  |  |             "y": tune.qlograndint(lower=2, upper=100000, q=2), | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  |         }, | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |         num_samples=100, | 
					
						
							|  |  |  |         mode="max", | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-28 16:57:52 -07:00
										 |  |  |     # Test empty return value | 
					
						
							|  |  |  |     def evaluate_config_empty(config): | 
					
						
							|  |  |  |         return {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tune.run( | 
					
						
							|  |  |  |         evaluate_config_empty, | 
					
						
							|  |  |  |         config={ | 
					
						
							|  |  |  |             "x": tune.qloguniform(lower=1, upper=100000, q=1), | 
					
						
							|  |  |  |             "y": tune.qlograndint(lower=2, upper=100000, q=2), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         num_samples=10, | 
					
						
							|  |  |  |         mode="max", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-01 21:55:38 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | def test_xgboost_bs(): | 
					
						
							|  |  |  |     _test_xgboost() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-20 17:00:44 -07:00
										 |  |  | def _test_xgboost_cfo(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("CFO") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-20 17:00:44 -07:00
										 |  |  | def test_xgboost_cfocat(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("CFOCat") | 
					
						
							| 
									
										
										
										
											2021-07-20 17:00:44 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | def _test_xgboost_dragonfly(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("Dragonfly") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_skopt(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("SkOpt") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_nevergrad(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("Nevergrad") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_zoopt(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("ZOOpt") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_ax(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("Ax") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def __test_xgboost_hyperopt(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("HyperOpt") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_optuna(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("Optuna") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_asha(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("ASHA") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_xgboost_bohb(): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     _test_xgboost("BOHB") | 
					
						
							| 
									
										
										
										
											2021-02-05 21:41:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2021-05-18 15:57:42 -07:00
										 |  |  |     test_xgboost_bs() |