2022-02-11 20:14:10 -08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from ray_on_aml.core import Ray_On_AML
							 | 
						
					
						
							
								
									
										
										
										
											2021-12-23 13:37:07 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import lightgbm as lgb
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import numpy as np
							 | 
						
					
						
							
								
									
										
										
										
											2021-12-25 16:13:39 -08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from sklearn.datasets import load_breast_cancer
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from sklearn.metrics import accuracy_score
							 | 
						
					
						
							
								
									
										
										
										
											2021-12-23 13:37:07 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from sklearn.model_selection import train_test_split
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from flaml import tune
							 | 
						
					
						
							
								
									
										
										
										
											2022-12-06 20:46:08 +00:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from flaml.automl.model import LGBMEstimator
							 | 
						
					
						
							
								
									
										
										
										
											2021-12-23 13:37:07 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								def train_breast_cancer(config):
								    """Fit one LightGBM model for ``config`` and report its test accuracy.

								    Args:
								        config: Hyperparameter values unpacked as keyword arguments into
								            ``LGBMEstimator``; the estimator's ``params`` attribute is what
								            gets passed to ``lgb.train``.

								    Returns:
								        None. The score is published through ``tune.report`` under the
								        ``mean_accuracy`` key, with ``done=True``.

								    Note:
								        Reads the module-level names ``ray``, ``X_train_ref``, ``y_train``,
								        ``X_test`` and ``y_test``, which are bound by the script's
								        ``__main__`` section before tuning starts.
								    """
								    lgbm_params = LGBMEstimator(**config).params
								    # The training features are fetched through the object reference
								    # created with ray.put() in the main section.
								    features = ray.get(X_train_ref)
								    booster = lgb.train(lgbm_params, lgb.Dataset(features, label=y_train))
								    # Round raw predictions to the nearest integer to obtain labels
								    # (presumably the booster emits scores in [0, 1] - confirm objective).
								    rounded = np.rint(booster.predict(X_test))
								    accuracy = accuracy_score(y_test, rounded)
								    tune.report(mean_accuracy=accuracy, done=True)
							 | 
						
					
						
							
								
									
										
										
										
											2021-12-23 13:37:07 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								if __name__ == "__main__":
								    # Obtain a ray handle from the Ray_On_AML helper; the experiment below
								    # only runs when getRay() returns something truthy.
								    ray_on_aml = Ray_On_AML()
								    ray = ray_on_aml.getRay()
								    if ray:
								        # NOTE: ray, X_train_ref, y_train, X_test and y_test are read as
								        # globals by train_breast_cancer, so their names must not change.
								        X, y = load_breast_cancer(return_X_y=True)
								        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
								        # Store the training features once and hand trials a reference.
								        X_train_ref = ray.put(X_train)

								        # Derive the tuning inputs from FLAML's LightGBM search space:
								        # the sampling domain of every hyperparameter, plus the low-cost
								        # initial value for those hyperparameters that declare one.
								        flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape)
								        config_search_space = {
								            hp: space["domain"] for hp, space in flaml_lgbm_search_space.items()
								        }
								        low_cost_partial_config = {
								            hp: space["low_cost_init_value"]
								            for hp, space in flaml_lgbm_search_space.items()
								            if "low_cost_init_value" in space
								        }

								        analysis = tune.run(
								            train_breast_cancer,
								            metric="mean_accuracy",
								            mode="max",
								            config=config_search_space,
								            # Previously computed but never used: forward it so the
								            # search actually starts from the low-cost configuration.
								            low_cost_partial_config=low_cost_partial_config,
								            num_samples=-1,
								            time_budget_s=60,
								            use_ray=True,
								        )

								        print("The best trial's result: ", analysis.best_trial.last_result)
							 |