mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-31 01:40:58 +00:00 
			
		
		
		
	fix a bug when using ray & update ray on aml (#455)
* fix a bug when using ray & update ray on aml When using with_parameters(), the config argument must be the first argument in the trainable function. * make training function runnable standalone
This commit is contained in:
		
							parent
							
								
									b4d312412a
								
							
						
					
					
						commit
						9e88f22167
					
				| @ -213,41 +213,42 @@ class AutoMLState: | |||||||
|                 groups = self.groups_all |                 groups = self.groups_all | ||||||
|         return sampled_X_train, sampled_y_train, sampled_weight, groups |         return sampled_X_train, sampled_y_train, sampled_weight, groups | ||||||
| 
 | 
 | ||||||
|     def _compute_with_config_base(self, estimator, config_w_resource): |     @staticmethod | ||||||
|  |     def _compute_with_config_base(config_w_resource, state, estimator): | ||||||
|         if "FLAML_sample_size" in config_w_resource: |         if "FLAML_sample_size" in config_w_resource: | ||||||
|             sample_size = int(config_w_resource["FLAML_sample_size"]) |             sample_size = int(config_w_resource["FLAML_sample_size"]) | ||||||
|         else: |         else: | ||||||
|             sample_size = self.data_size[0] |             sample_size = state.data_size[0] | ||||||
|         ( |         ( | ||||||
|             sampled_X_train, |             sampled_X_train, | ||||||
|             sampled_y_train, |             sampled_y_train, | ||||||
|             sampled_weight, |             sampled_weight, | ||||||
|             groups, |             groups, | ||||||
|         ) = self._prepare_sample_train_data(sample_size) |         ) = state._prepare_sample_train_data(sample_size) | ||||||
|         if sampled_weight is not None: |         if sampled_weight is not None: | ||||||
|             weight = self.fit_kwargs["sample_weight"] |             weight = state.fit_kwargs["sample_weight"] | ||||||
|             self.fit_kwargs["sample_weight"] = sampled_weight |             state.fit_kwargs["sample_weight"] = sampled_weight | ||||||
|         else: |         else: | ||||||
|             weight = None |             weight = None | ||||||
|         if groups is not None: |         if groups is not None: | ||||||
|             self.fit_kwargs["groups"] = groups |             state.fit_kwargs["groups"] = groups | ||||||
|         config = config_w_resource.copy() |         config = config_w_resource.copy() | ||||||
|         if "FLAML_sample_size" in config: |         if "FLAML_sample_size" in config: | ||||||
|             del config["FLAML_sample_size"] |             del config["FLAML_sample_size"] | ||||||
|         budget = ( |         budget = ( | ||||||
|             None |             None | ||||||
|             if self.time_budget is None |             if state.time_budget is None | ||||||
|             else self.time_budget - self.time_from_start |             else state.time_budget - state.time_from_start | ||||||
|             if sample_size == self.data_size[0] |             if sample_size == state.data_size[0] | ||||||
|             else (self.time_budget - self.time_from_start) |             else (state.time_budget - state.time_from_start) | ||||||
|             / 2 |             / 2 | ||||||
|             * sample_size |             * sample_size | ||||||
|             / self.data_size[0] |             / state.data_size[0] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if _is_nlp_task(self.task): |         if _is_nlp_task(state.task): | ||||||
|             self.fit_kwargs["X_val"] = self.X_val |             state.fit_kwargs["X_val"] = state.X_val | ||||||
|             self.fit_kwargs["y_val"] = self.y_val |             state.fit_kwargs["y_val"] = state.y_val | ||||||
| 
 | 
 | ||||||
|         ( |         ( | ||||||
|             trained_estimator, |             trained_estimator, | ||||||
| @ -258,41 +259,42 @@ class AutoMLState: | |||||||
|         ) = compute_estimator( |         ) = compute_estimator( | ||||||
|             sampled_X_train, |             sampled_X_train, | ||||||
|             sampled_y_train, |             sampled_y_train, | ||||||
|             self.X_val, |             state.X_val, | ||||||
|             self.y_val, |             state.y_val, | ||||||
|             self.weight_val, |             state.weight_val, | ||||||
|             self.groups_val, |             state.groups_val, | ||||||
|             self.train_time_limit |             state.train_time_limit | ||||||
|             if budget is None |             if budget is None | ||||||
|             else min(budget, self.train_time_limit), |             else min(budget, state.train_time_limit), | ||||||
|             self.kf, |             state.kf, | ||||||
|             config, |             config, | ||||||
|             self.task, |             state.task, | ||||||
|             estimator, |             estimator, | ||||||
|             self.eval_method, |             state.eval_method, | ||||||
|             self.metric, |             state.metric, | ||||||
|             self.best_loss, |             state.best_loss, | ||||||
|             self.n_jobs, |             state.n_jobs, | ||||||
|             self.learner_classes.get(estimator), |             state.learner_classes.get(estimator), | ||||||
|             self.log_training_metric, |             state.log_training_metric, | ||||||
|             self.fit_kwargs, |             state.fit_kwargs, | ||||||
|         ) |         ) | ||||||
|         if self.retrain_final and not self.model_history: |         if state.retrain_final and not state.model_history: | ||||||
|             trained_estimator.cleanup() |             trained_estimator.cleanup() | ||||||
| 
 | 
 | ||||||
|         if _is_nlp_task(self.task): |         if _is_nlp_task(state.task): | ||||||
|             del self.fit_kwargs["X_val"] |             del state.fit_kwargs["X_val"] | ||||||
|             del self.fit_kwargs["y_val"] |             del state.fit_kwargs["y_val"] | ||||||
| 
 | 
 | ||||||
|         result = { |         result = { | ||||||
|             "pred_time": pred_time, |             "pred_time": pred_time, | ||||||
|             "wall_clock_time": time.time() - self._start_time_flag, |             "wall_clock_time": time.time() - state._start_time_flag, | ||||||
|             "metric_for_logging": metric_for_logging, |             "metric_for_logging": metric_for_logging, | ||||||
|             "val_loss": val_loss, |             "val_loss": val_loss, | ||||||
|             "trained_estimator": trained_estimator, |             "trained_estimator": trained_estimator, | ||||||
|         } |         } | ||||||
|         if sampled_weight is not None: |         if sampled_weight is not None: | ||||||
|             self.fit_kwargs["sample_weight"] = weight |             state.fit_kwargs["sample_weight"] = weight | ||||||
|  |         tune.report(**result) | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
|     def _train_with_config( |     def _train_with_config( | ||||||
| @ -1672,12 +1674,14 @@ class AutoML(BaseEstimator): | |||||||
| 
 | 
 | ||||||
|                     search_state.training_function = with_parameters( |                     search_state.training_function = with_parameters( | ||||||
|                         AutoMLState._compute_with_config_base, |                         AutoMLState._compute_with_config_base, | ||||||
|                         self=self._state, |                         state=self._state, | ||||||
|                         estimator=estimator, |                         estimator=estimator, | ||||||
|                     ) |                     ) | ||||||
|                 else: |                 else: | ||||||
|                     search_state.training_function = partial( |                     search_state.training_function = partial( | ||||||
|                         AutoMLState._compute_with_config_base, self._state, estimator |                         AutoMLState._compute_with_config_base, | ||||||
|  |                         state=self._state, | ||||||
|  |                         estimator=estimator, | ||||||
|                     ) |                     ) | ||||||
|         states = self._search_states |         states = self._search_states | ||||||
|         mem_res = self._mem_thres |         mem_res = self._mem_thres | ||||||
| @ -2461,7 +2465,9 @@ class AutoML(BaseEstimator): | |||||||
|             ) |             ) | ||||||
|             if not search_state.search_alg: |             if not search_state.search_alg: | ||||||
|                 search_state.training_function = partial( |                 search_state.training_function = partial( | ||||||
|                     AutoMLState._compute_with_config_base, self._state, estimator |                     AutoMLState._compute_with_config_base, | ||||||
|  |                     state=self._state, | ||||||
|  |                     estimator=estimator, | ||||||
|                 ) |                 ) | ||||||
|                 search_space = search_state.search_space |                 search_space = search_state.search_space | ||||||
|                 if self._sample: |                 if self._sample: | ||||||
|  | |||||||
| @ -125,3 +125,7 @@ class SequentialTrialRunner(BaseTrialRunner): | |||||||
|             trial = None |             trial = None | ||||||
|         self.running_trial = trial |         self.running_trial = trial | ||||||
|         return trial |         return trial | ||||||
|  | 
 | ||||||
|  |     def stop_trial(self, trial): | ||||||
|  |         super().stop_trial(trial) | ||||||
|  |         self.running_trial = None | ||||||
|  | |||||||
| @ -90,7 +90,9 @@ def report(_metric=None, **kwargs): | |||||||
|         result = kwargs |         result = kwargs | ||||||
|         if _metric: |         if _metric: | ||||||
|             result[DEFAULT_METRIC] = _metric |             result[DEFAULT_METRIC] = _metric | ||||||
|         trial = _runner.running_trial |         trial = getattr(_runner, "running_trial", None) | ||||||
|  |         if not trial: | ||||||
|  |             return None | ||||||
|         if _running_trial == trial: |         if _running_trial == trial: | ||||||
|             _training_iteration += 1 |             _training_iteration += 1 | ||||||
|         else: |         else: | ||||||
| @ -102,11 +104,11 @@ def report(_metric=None, **kwargs): | |||||||
|             del result["config"][INCUMBENT_RESULT] |             del result["config"][INCUMBENT_RESULT] | ||||||
|         for key, value in trial.config.items(): |         for key, value in trial.config.items(): | ||||||
|             result["config/" + key] = value |             result["config/" + key] = value | ||||||
|         _runner.process_trial_result(_runner.running_trial, result) |         _runner.process_trial_result(trial, result) | ||||||
|         result["time_total_s"] = trial.last_update_time - trial.start_time |         result["time_total_s"] = trial.last_update_time - trial.start_time | ||||||
|         if _verbose > 2: |         if _verbose > 2: | ||||||
|             logger.info(f"result: {result}") |             logger.info(f"result: {result}") | ||||||
|         if _runner.running_trial.is_finished(): |         if trial.is_finished(): | ||||||
|             return None |             return None | ||||||
|         else: |         else: | ||||||
|             return True |             return True | ||||||
|  | |||||||
| @ -1 +1 @@ | |||||||
| __version__ = "0.9.6" | __version__ = "0.9.7" | ||||||
|  | |||||||
| @ -1,8 +1,8 @@ | |||||||
| FROM python:3.7 | FROM mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04 | ||||||
| 
 | 
 | ||||||
| RUN pip install azureml-core | RUN pip install azureml-core | ||||||
| RUN pip install flaml[blendsearch,ray] | RUN pip install flaml[blendsearch,ray] | ||||||
| 
 | RUN pip install ray-on-aml | ||||||
| 
 | 
 | ||||||
| EXPOSE 8265 | EXPOSE 8265 | ||||||
| EXPOSE 6379 | EXPOSE 6379 | ||||||
|  | |||||||
| @ -1,15 +1,17 @@ | |||||||
|  | from ray_on_aml.core import Ray_On_AML | ||||||
| from flaml import AutoML | from flaml import AutoML | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _test_ray_classification(): | def _test_ray_classification(): | ||||||
|     from sklearn.datasets import make_classification |     from sklearn.datasets import make_classification | ||||||
|     import ray |  | ||||||
| 
 | 
 | ||||||
|     ray.init(address="auto") |  | ||||||
|     X, y = make_classification(1000, 10) |     X, y = make_classification(1000, 10) | ||||||
|     automl = AutoML() |     automl = AutoML() | ||||||
|     automl.fit(X, y, time_budget=10, task="classification", n_concurrent_trials=2) |     automl.fit(X, y, time_budget=10, task="classification", n_concurrent_trials=2) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |     ray_on_aml = Ray_On_AML() | ||||||
|  |     ray = ray_on_aml.getRay() | ||||||
|  |     if ray: | ||||||
|         _test_ray_classification() |         _test_ray_classification() | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| import ray | from ray_on_aml.core import Ray_On_AML | ||||||
| import lightgbm as lgb | import lightgbm as lgb | ||||||
| import numpy as np | import numpy as np | ||||||
| from sklearn.datasets import load_breast_cancer | from sklearn.datasets import load_breast_cancer | ||||||
| @ -7,11 +7,6 @@ from sklearn.model_selection import train_test_split | |||||||
| from flaml import tune | from flaml import tune | ||||||
| from flaml.model import LGBMEstimator | from flaml.model import LGBMEstimator | ||||||
| 
 | 
 | ||||||
| X, y = load_breast_cancer(return_X_y=True) |  | ||||||
| X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25) |  | ||||||
| ray.init(address="auto") |  | ||||||
| X_train_ref = ray.put(X_train) |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def train_breast_cancer(config): | def train_breast_cancer(config): | ||||||
|     params = LGBMEstimator(**config).params |     params = LGBMEstimator(**config).params | ||||||
| @ -24,6 +19,12 @@ def train_breast_cancer(config): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |     ray_on_aml = Ray_On_AML() | ||||||
|  |     ray = ray_on_aml.getRay() | ||||||
|  |     if ray: | ||||||
|  |         X, y = load_breast_cancer(return_X_y=True) | ||||||
|  |         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25) | ||||||
|  |         X_train_ref = ray.put(X_train) | ||||||
|         flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape) |         flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape) | ||||||
|         config_search_space = { |         config_search_space = { | ||||||
|             hp: space["domain"] for hp, space in flaml_lgbm_search_space.items() |             hp: space["domain"] for hp, space in flaml_lgbm_search_space.items() | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| import time | import time | ||||||
| from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment | from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment | ||||||
|  | from azureml.core.runconfig import RunConfiguration, DockerConfiguration | ||||||
| 
 | 
 | ||||||
| ws = Workspace.from_config() | ws = Workspace.from_config() | ||||||
| ray_environment_name = "aml-ray-cpu" | ray_environment_name = "aml-ray-cpu" | ||||||
| @ -18,20 +19,21 @@ while ray_cpu_build_details.status not in ["Succeeded", "Failed"]: | |||||||
|     ) |     ) | ||||||
|     time.sleep(10) |     time.sleep(10) | ||||||
| 
 | 
 | ||||||
|  | command = ["python distribute_automl.py"] | ||||||
| env = Environment.get(workspace=ws, name=ray_environment_name) | env = Environment.get(workspace=ws, name=ray_environment_name) | ||||||
| compute_target = ws.compute_targets["cpucluster"] | compute_target = ws.compute_targets["cpucluster"] | ||||||
| command = ["python automl.py"] | aml_run_config = RunConfiguration(communicator="OpenMpi") | ||||||
|  | aml_run_config.target = compute_target | ||||||
|  | aml_run_config.docker = DockerConfiguration(use_docker=True) | ||||||
|  | aml_run_config.environment = env | ||||||
|  | aml_run_config.node_count = 2 | ||||||
| config = ScriptRunConfig( | config = ScriptRunConfig( | ||||||
|     source_directory="ray/", |     source_directory="ray/", | ||||||
|     command=command, |     command=command, | ||||||
|     compute_target=compute_target, |     run_config=aml_run_config, | ||||||
|     environment=env, |  | ||||||
| ) | ) | ||||||
| config.run_config.node_count = 2 |  | ||||||
| config.run_config.environment_variables["_AZUREML_CR_START_RAY"] = "true" |  | ||||||
| config.run_config.environment_variables["AZUREML_COMPUTE_USE_COMMON_RUNTIME"] = "true" |  | ||||||
| 
 | 
 | ||||||
| exp = Experiment(ws, "test-ray") | exp = Experiment(ws, "distribute-automl") | ||||||
| run = exp.submit(config) | run = exp.submit(config) | ||||||
| print(run.get_portal_url())  # link to ml.azure.com | print(run.get_portal_url())  # link to ml.azure.com | ||||||
| run.wait_for_completion(show_output=True) | run.wait_for_completion(show_output=True) | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| import time | import time | ||||||
| from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment | from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment | ||||||
|  | from azureml.core.runconfig import RunConfiguration, DockerConfiguration | ||||||
| 
 | 
 | ||||||
| ws = Workspace.from_config() | ws = Workspace.from_config() | ||||||
| ray_environment_name = "aml-ray-cpu" | ray_environment_name = "aml-ray-cpu" | ||||||
| @ -18,20 +19,21 @@ while ray_cpu_build_details.status not in ["Succeeded", "Failed"]: | |||||||
|     ) |     ) | ||||||
|     time.sleep(10) |     time.sleep(10) | ||||||
| 
 | 
 | ||||||
|  | command = ["python distribute_tune.py"] | ||||||
| env = Environment.get(workspace=ws, name=ray_environment_name) | env = Environment.get(workspace=ws, name=ray_environment_name) | ||||||
| compute_target = ws.compute_targets["cpucluster"] | compute_target = ws.compute_targets["cpucluster"] | ||||||
| command = ["python distribute_tune.py"] | aml_run_config = RunConfiguration(communicator="OpenMpi") | ||||||
|  | aml_run_config.target = compute_target | ||||||
|  | aml_run_config.docker = DockerConfiguration(use_docker=True) | ||||||
|  | aml_run_config.environment = env | ||||||
|  | aml_run_config.node_count = 2 | ||||||
| config = ScriptRunConfig( | config = ScriptRunConfig( | ||||||
|     source_directory="ray/", |     source_directory="ray/", | ||||||
|     command=command, |     command=command, | ||||||
|     compute_target=compute_target, |     run_config=aml_run_config, | ||||||
|     environment=env, |  | ||||||
| ) | ) | ||||||
| config.run_config.node_count = 2 |  | ||||||
| config.run_config.environment_variables["_AZUREML_CR_START_RAY"] = "true" |  | ||||||
| config.run_config.environment_variables["AZUREML_COMPUTE_USE_COMMON_RUNTIME"] = "true" |  | ||||||
| 
 | 
 | ||||||
| exp = Experiment(ws, "test-ray") | exp = Experiment(ws, "distribute-tune") | ||||||
| run = exp.submit(config) | run = exp.submit(config) | ||||||
| print(run.get_portal_url())  # link to ml.azure.com | print(run.get_portal_url())  # link to ml.azure.com | ||||||
| run.wait_for_completion(show_output=True) | run.wait_for_completion(show_output=True) | ||||||
|  | |||||||
| @ -128,40 +128,35 @@ If the compute target "cpucluster" already exists, it will not be recreated. | |||||||
| 
 | 
 | ||||||
| #### Run distributed AutoML job | #### Run distributed AutoML job | ||||||
| 
 | 
 | ||||||
| Assume you have an automl script like [ray/distribute_automl.py](https://github.com/microsoft/FLAML/blob/main/test/ray/distribute_automl.py). It uses `ray.init(address="auto")` to initialize the cluster, and uses `n_concurrent_trials=k` to inform `AutoML.fit()` to perform k concurrent trials in parallel. | Assume you have an automl script like [ray/distribute_automl.py](https://github.com/microsoft/FLAML/blob/main/test/ray/distribute_automl.py). It uses `n_concurrent_trials=k` to inform `AutoML.fit()` to perform k concurrent trials in parallel. | ||||||
| 
 | 
 | ||||||
| Submit an AzureML job as follows: | Submit an AzureML job as follows: | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment | from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment | ||||||
|  | from azureml.core.runconfig import RunConfiguration, DockerConfiguration | ||||||
| 
 | 
 | ||||||
| command = ["python distribute_automl.py"] | command = ["python distribute_automl.py"] | ||||||
| ray_environment_name = 'aml-ray-cpu' | ray_environment_name = "aml-ray-cpu" | ||||||
| env = Environment.get(workspace=ws, name=ray_environment_name) | env = Environment.get(workspace=ws, name=ray_environment_name) | ||||||
|  | aml_run_config = RunConfiguration(communicator="OpenMpi") | ||||||
|  | aml_run_config.target = compute_target | ||||||
|  | aml_run_config.docker = DockerConfiguration(use_docker=True) | ||||||
|  | aml_run_config.environment = env | ||||||
|  | aml_run_config.node_count = 2 | ||||||
| config = ScriptRunConfig( | config = ScriptRunConfig( | ||||||
|     source_directory='ray/', |     source_directory="ray/", | ||||||
|     command=command, |     command=command, | ||||||
|     compute_target=compute_target, |     run_config=aml_run_config, | ||||||
|     environment=env, |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| config.run_config.node_count = 2 | exp = Experiment(ws, "distribute-automl") | ||||||
| config.run_config.environment_variables["_AZUREML_CR_START_RAY"] = "true" |  | ||||||
| config.run_config.environment_variables["AZUREML_COMPUTE_USE_COMMON_RUNTIME"] = "true" |  | ||||||
| 
 |  | ||||||
| exp = Experiment(ws, 'distribute-automl') |  | ||||||
| run = exp.submit(config) | run = exp.submit(config) | ||||||
| 
 | 
 | ||||||
| print(run.get_portal_url())  # link to ml.azure.com | print(run.get_portal_url())  # link to ml.azure.com | ||||||
| run.wait_for_completion(show_output=True) | run.wait_for_completion(show_output=True) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| The line |  | ||||||
| ` |  | ||||||
| config.run_config.environment_variables["_AZUREML_CR_START_RAY"] = "true" |  | ||||||
| ` |  | ||||||
| tells AzureML to start ray on each node of the cluster. This is a feature in preview and it is subject to change in the future. It is applicable to dedicated VMs only. |  | ||||||
| 
 |  | ||||||
| #### Run distributed tune job | #### Run distributed tune job | ||||||
| 
 | 
 | ||||||
| Prepare a script like [ray/distribute_tune.py](https://github.com/microsoft/FLAML/blob/main/test/ray/distribute_tune.py). Replace the command in the above example with: | Prepare a script like [ray/distribute_tune.py](https://github.com/microsoft/FLAML/blob/main/test/ray/distribute_tune.py). Replace the command in the above example with: | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Chi Wang
						Chi Wang