| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  | import sys | 
					
						
							|  |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  | import requests | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | import os | 
					
						
							|  |  |  | import shutil | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  | from utils import ( | 
					
						
							|  |  |  |     get_toy_data_tokenclassification_idlabel, | 
					
						
							|  |  |  |     get_toy_data_tokenclassification_tokenlabel, | 
					
						
							|  |  |  |     get_automl_settings, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-06-26 08:32:28 -07:00
										 |  |  | @pytest.mark.skipif( | 
					
						
							| 
									
										
										
										
											2022-11-27 11:22:54 -08:00
										 |  |  |     sys.platform in ["darwin", "win32"] or sys.version < "3.7", | 
					
						
							|  |  |  |     reason="do not run on mac os, windows or py<3.7", | 
					
						
							| 
									
										
										
										
											2022-06-26 08:32:28 -07:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  | def test_tokenclassification_idlabel(): | 
					
						
							| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  |     from flaml import AutoML | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  |     X_train, y_train, X_val, y_val = get_toy_data_tokenclassification_idlabel() | 
					
						
							| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  |     automl = AutoML() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-28 14:06:29 -04:00
										 |  |  |     automl_settings = get_automl_settings() | 
					
						
							|  |  |  |     automl_settings["task"] = "token-classification" | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     automl_settings["metric"] = "seqeval:overall_f1"  # evaluating based on the overall_f1 of seqeval | 
					
						
							| 
									
										
										
										
											2022-05-10 17:22:57 -04:00
										 |  |  |     automl_settings["fit_kwargs_by_estimator"]["transformer"]["label_list"] = [ | 
					
						
							|  |  |  |         "O", | 
					
						
							|  |  |  |         "B-PER", | 
					
						
							|  |  |  |         "I-PER", | 
					
						
							|  |  |  |         "B-ORG", | 
					
						
							|  |  |  |         "I-ORG", | 
					
						
							|  |  |  |         "B-LOC", | 
					
						
							|  |  |  |         "I-LOC", | 
					
						
							|  |  |  |         "B-MISC", | 
					
						
							|  |  |  |         "I-MISC", | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |         automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings) | 
					
						
							| 
									
										
										
										
											2022-01-30 01:53:32 -05:00
										 |  |  |     except requests.exceptions.HTTPError: | 
					
						
							|  |  |  |         return | 
					
						
							| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  |     # perf test | 
					
						
							|  |  |  |     import json | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with open("seqclass.log", "r") as fin: | 
					
						
							|  |  |  |         for line in fin: | 
					
						
							|  |  |  |             each_log = json.loads(line.strip("\n")) | 
					
						
							|  |  |  |             if "validation_loss" in each_log: | 
					
						
							|  |  |  |                 val_loss = each_log["validation_loss"] | 
					
						
							|  |  |  |                 min_inter_result = min( | 
					
						
							|  |  |  |                     each_dict.get("eval_automl_metric", sys.maxsize) | 
					
						
							|  |  |  |                     for each_dict in each_log["logged_metric"]["intermediate_results"] | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if min_inter_result != sys.maxsize: | 
					
						
							|  |  |  |                     assert val_loss == min_inter_result | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  |     if os.path.exists("test/data/output/"): | 
					
						
							| 
									
										
										
										
											2022-11-27 11:22:54 -08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             shutil.rmtree("test/data/output/") | 
					
						
							|  |  |  |         except PermissionError: | 
					
						
							|  |  |  |             print("PermissionError when deleting test/data/output/") | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.skipif( | 
					
						
							| 
									
										
										
										
											2022-11-27 11:22:54 -08:00
										 |  |  |     sys.platform in ["darwin", "win32"] or sys.version < "3.7", | 
					
						
							|  |  |  |     reason="do not run on mac os, windows or py<3.7", | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  | ) | 
					
						
							|  |  |  | def test_tokenclassification_tokenlabel(): | 
					
						
							|  |  |  |     from flaml import AutoML | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     X_train, y_train, X_val, y_val = get_toy_data_tokenclassification_tokenlabel() | 
					
						
							|  |  |  |     automl = AutoML() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     automl_settings = get_automl_settings() | 
					
						
							|  |  |  |     automl_settings["task"] = "token-classification" | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |     automl_settings["metric"] = "seqeval:overall_f1"  # evaluating based on the overall_f1 of seqeval | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2023-04-10 21:50:40 +02:00
										 |  |  |         automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings) | 
					
						
							| 
									
										
										
										
											2022-07-05 13:38:21 -04:00
										 |  |  |     except requests.exceptions.HTTPError: | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # perf test | 
					
						
							|  |  |  |     import json | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with open("seqclass.log", "r") as fin: | 
					
						
							|  |  |  |         for line in fin: | 
					
						
							|  |  |  |             each_log = json.loads(line.strip("\n")) | 
					
						
							|  |  |  |             if "validation_loss" in each_log: | 
					
						
							|  |  |  |                 val_loss = each_log["validation_loss"] | 
					
						
							|  |  |  |                 min_inter_result = min( | 
					
						
							|  |  |  |                     each_dict.get("eval_automl_metric", sys.maxsize) | 
					
						
							|  |  |  |                     for each_dict in each_log["logged_metric"]["intermediate_results"] | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if min_inter_result != sys.maxsize: | 
					
						
							|  |  |  |                     assert val_loss == min_inter_result | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  |     if os.path.exists("test/data/output/"): | 
					
						
							| 
									
										
										
										
											2022-11-27 11:22:54 -08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             shutil.rmtree("test/data/output/") | 
					
						
							|  |  |  |         except PermissionError: | 
					
						
							|  |  |  |             print("PermissionError when deleting test/data/output/") | 
					
						
							| 
									
										
										
										
											2022-10-12 20:04:42 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-03 13:44:10 -05:00
										 |  |  | 
 | 
					
						
# Script entry point: when this file is executed directly (rather than
# collected by pytest), run only the id-label test. The pytest skipif marker
# is not evaluated on this path, so the test body always executes.
if __name__ == "__main__":
    test_tokenclassification_idlabel()