2020-12-04 09:40:27 -08:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
from sklearn.datasets import fetch_openml
|
|
|
|
from flaml.automl import AutoML
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
|
|
|
|
|
2021-06-15 18:52:57 -07:00
|
|
|
# OpenML dataset name fetched by the tests in this module; it is also
# used to build the AutoML log file path ("test/<name>.log").
dataset = 'credit'
|
2020-12-04 09:40:27 -08:00
|
|
|
|
|
|
|
|
|
|
|
def _test(split_type):
    """Fit AutoML for classification with the given split type and print accuracy.

    The OpenML ``dataset`` is used when it can be fetched; otherwise the
    test falls back to scikit-learn's bundled wine dataset so it can still
    run without network access.

    Args:
        split_type: Forwarded unchanged to ``AutoML.fit`` as ``split_type``
            (e.g. ``"uniform"``).

    Returns:
        The accuracy of the fitted model on a held-out 33% test split
        (``random_state=42``). It is also printed, as before.
    """
    from sklearn.externals._arff import ArffException

    try:
        X, y = fetch_openml(name=dataset, return_X_y=True)
    except (ArffException, ValueError, OSError):
        # OSError covers network failures (urllib's URLError subclasses it),
        # so the smoke test still runs when OpenML is unreachable.
        from sklearn.datasets import load_wine
        X, y = load_wine(return_X_y=True)

    automl = AutoML()
    automl_settings = {
        "time_budget": 2,
        "task": 'classification',
        "log_file_name": "test/{}.log".format(dataset),
        "model_history": True,
        "log_training_metric": True,
        "split_type": split_type,
    }
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)
    automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
    pred = automl.predict(X_test)
    acc = accuracy_score(y_test, pred)
    print(acc)
    # Returned so callers can assert on the score instead of only reading stdout.
    return acc
|
|
|
|
|
2021-04-08 09:29:55 -07:00
|
|
|
|
2020-12-14 23:10:03 -08:00
|
|
|
def _test_uniform():
    """Run the ``_test`` classification smoke test with split_type "uniform".

    The leading underscore keeps pytest from collecting it automatically.
    """
    return_value = _test("uniform")
    del return_value
|
|
|
|
|
|
|
|
|
2021-06-15 18:52:57 -07:00
|
|
|
def test_groups():
    """Fit AutoML with cross-validation and a per-row ``groups`` array.

    Uses the OpenML ``dataset`` when it can be fetched and falls back to
    scikit-learn's bundled wine dataset otherwise. Each row is assigned one
    of 10 group labels, drawn from a seeded generator so the CV folds are
    the same on every run.
    """
    import numpy as np
    from sklearn.externals._arff import ArffException

    try:
        X, y = fetch_openml(name=dataset, return_X_y=True)
    except (ArffException, ValueError, OSError):
        # OSError covers network failures (urllib's URLError subclasses it),
        # so the test still runs without OpenML access.
        from sklearn.datasets import load_wine
        X, y = load_wine(return_X_y=True)

    automl = AutoML()
    # Seeded local generator: the unseeded global np.random made the group
    # assignment (and hence the CV folds) differ between runs, which makes
    # failures hard to reproduce.
    rng = np.random.RandomState(0)
    automl_settings = {
        "time_budget": 2,
        "task": 'classification',
        "log_file_name": "test/{}.log".format(dataset),
        "model_history": True,
        "eval_method": "cv",
        "groups": rng.randint(low=0, high=10, size=len(y)),
    }
    automl.fit(X, y, **automl_settings)
|
|
|
|
|
|
|
|
|
2020-12-04 09:40:27 -08:00
|
|
|
if __name__ == "__main__":
    # This module defines plain (pytest-style) test functions, not
    # unittest.TestCase subclasses, so unittest.main() would collect zero
    # tests and report success. Wrap the enabled test explicitly instead.
    _result = unittest.TextTestRunner().run(
        unittest.FunctionTestCase(test_groups))
    # Mirror unittest.main(): exit status 1 on failure, 0 on success.
    raise SystemExit(not _result.wasSuccessful())
|