| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | from utils import ( | 
					
						
							|  |  |  |     get_toy_data_regression, | 
					
						
							|  |  |  |     get_toy_data_binclassification, | 
					
						
							|  |  |  |     get_toy_data_multiclassclassification, | 
					
						
							|  |  |  |     get_automl_settings, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import shutil | 
					
						
							| 
									
										
										
										
											2022-04-28 14:06:29 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | data_list = [ | 
					
						
							|  |  |  |     "get_toy_data_regression", | 
					
						
							|  |  |  |     "get_toy_data_binclassification", | 
					
						
							|  |  |  |     "get_toy_data_multiclassclassification", | 
					
						
							|  |  |  | ] | 
					
						
							|  |  |  | model_path_list = [ | 
					
						
							|  |  |  |     "textattack/bert-base-uncased-STS-B", | 
					
						
							|  |  |  |     "textattack/bert-base-uncased-SST-2", | 
					
						
							|  |  |  |     "textattack/bert-base-uncased-MNLI", | 
					
						
							|  |  |  | ] | 
					
						
							| 
									
										
										
										
											2022-04-28 14:06:29 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | def test_switch_1_1(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 0, 0 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_1_2(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 0, 1 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_1_3(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 0, 2 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_2_1(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 1, 0 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_2_2(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 1, 1 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_2_3(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 1, 2 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_3_1(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 2, 0 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_3_2(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 2, 1 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_switch_3_3(): | 
					
						
							|  |  |  |     data_idx, model_path_idx = 2, 2 | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     _test_switch_classificationhead(data_list[data_idx], model_path_list[model_path_idx]) | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _test_switch_classificationhead(each_data, each_model_path): | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     from flaml import AutoML | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  |     import requests | 
					
						
							| 
									
										
										
										
											2022-01-24 17:24:14 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  |     automl = AutoML() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  |     X_train, y_train, X_val, y_val = globals()[each_data]() | 
					
						
							| 
									
										
										
										
											2022-04-28 14:06:29 -04:00
										 |  |  |     automl_settings = get_automl_settings() | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  |     automl_settings["model_path"] = each_model_path | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if each_data == "get_toy_data_regression": | 
					
						
							|  |  |  |         automl_settings["task"] = "seq-regression" | 
					
						
							|  |  |  |         automl_settings["metric"] = "pearsonr" | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         automl_settings["task"] = "seq-classification" | 
					
						
							|  |  |  |         automl_settings["metric"] = "accuracy" | 
					
						
							| 
									
										
										
										
											2021-11-16 14:06:20 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |         automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings) | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  |     except requests.exceptions.HTTPError: | 
					
						
							|  |  |  |         return | 
					
						
							| 
									
										
										
										
											2022-03-26 14:08:51 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  |     if os.path.exists("test/data/output/"): | 
					
						
							| 
									
										
										
										
											2022-11-27 11:22:54 -08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             shutil.rmtree("test/data/output/") | 
					
						
							|  |  |  |         except PermissionError: | 
					
						
							|  |  |  |             print("PermissionError when deleting test/data/output/") | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-26 14:08:51 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  |     _test_switch_classificationhead(data_list[0], model_path_list[0]) |