mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-03 21:37:17 +00:00
install editable package in codespace (#826)
* install editable package in codespace * fix test error in test_forecast * fix test error in test_space * openml version * break tests; pre-commit * skip on py10+win32 * install mlflow in test * install mlflow in [test] * skip test in windows * import * handle PermissionError * skip test in windows * skip test in windows * skip test in windows * skip test in windows * remove ts_forecast_panel from doc
This commit is contained in:
parent
586afe0d6b
commit
595af7a04f
@ -17,10 +17,7 @@ RUN apt-get update \
|
|||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
ENV DEBIAN_FRONTEND=dialog
|
ENV DEBIAN_FRONTEND=dialog
|
||||||
|
|
||||||
#
|
RUN pip3 --disable-pip-version-check --no-cache-dir install flaml
|
||||||
# Install extras for development
|
|
||||||
#
|
|
||||||
RUN pip3 --disable-pip-version-check --no-cache-dir install flaml[test,notebook]
|
|
||||||
# For docs
|
# For docs
|
||||||
RUN npm install --global yarn
|
RUN npm install --global yarn
|
||||||
RUN pip install pydoc-markdown==4.5.0
|
RUN pip install pydoc-markdown==4.5.0
|
||||||
|
@ -8,5 +8,6 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"terminal.integrated.defaultProfile.linux": "bash"
|
"terminal.integrated.defaultProfile.linux": "bash"
|
||||||
}
|
},
|
||||||
|
"updateContentCommand": "pip install -e .[test,notebook] && pre-commit install"
|
||||||
}
|
}
|
@ -2213,7 +2213,7 @@ class AutoML(BaseEstimator):
|
|||||||
```
|
```
|
||||||
task: A string of the task type, e.g.,
|
task: A string of the task type, e.g.,
|
||||||
'classification', 'regression', 'ts_forecast_regression',
|
'classification', 'regression', 'ts_forecast_regression',
|
||||||
'ts_forecast_classification', 'ts_forecast_panel', 'rank', 'seq-classification',
|
'ts_forecast_classification', 'rank', 'seq-classification',
|
||||||
'seq-regression', 'summarization'.
|
'seq-regression', 'summarization'.
|
||||||
n_jobs: An integer of the number of threads for training | default=-1.
|
n_jobs: An integer of the number of threads for training | default=-1.
|
||||||
Use all available resources when n_jobs == -1.
|
Use all available resources when n_jobs == -1.
|
||||||
|
@ -2266,18 +2266,13 @@ class TemporalFusionTransformerEstimator(SKLearnEstimator):
|
|||||||
return training, train_dataloader, val_dataloader
|
return training, train_dataloader, val_dataloader
|
||||||
|
|
||||||
def fit(self, X_train, y_train, budget=None, **kwargs):
|
def fit(self, X_train, y_train, budget=None, **kwargs):
|
||||||
import copy
|
|
||||||
from pathlib import Path
|
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
|
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
import torch
|
import torch
|
||||||
from pytorch_forecasting import TemporalFusionTransformer
|
from pytorch_forecasting import TemporalFusionTransformer
|
||||||
from pytorch_forecasting.metrics import QuantileLoss
|
from pytorch_forecasting.metrics import QuantileLoss
|
||||||
import tensorboard as tb
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
@ -18,7 +18,6 @@
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .trial import Trial
|
from .trial import Trial
|
||||||
from collections import defaultdict
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -428,7 +428,12 @@ class Categorical(Domain):
|
|||||||
):
|
):
|
||||||
if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
|
if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
|
||||||
random_state = _BackwardsCompatibleNumpyRng(random_state)
|
random_state = _BackwardsCompatibleNumpyRng(random_state)
|
||||||
items = random_state.choice(domain.categories, size=size).tolist()
|
# do not use .choice() directly on domain.categories
|
||||||
|
# as that will coerce them to a single dtype
|
||||||
|
indices = random_state.choice(
|
||||||
|
np.arange(0, len(domain.categories)), size=size
|
||||||
|
)
|
||||||
|
items = [domain.categories[index] for index in indices]
|
||||||
return items if len(items) > 1 else domain.cast(items[0])
|
return items if len(items) > 1 else domain.cast(items[0])
|
||||||
|
|
||||||
default_sampler_cls = _Uniform
|
default_sampler_cls = _Uniform
|
||||||
@ -479,8 +484,18 @@ class Quantized(Sampler):
|
|||||||
):
|
):
|
||||||
if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
|
if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
|
||||||
random_state = _BackwardsCompatibleNumpyRng(random_state)
|
random_state = _BackwardsCompatibleNumpyRng(random_state)
|
||||||
values = self.sampler.sample(domain, spec, size, random_state=random_state)
|
|
||||||
|
if self.q == 1:
|
||||||
|
return self.sampler.sample(domain, spec, size, random_state=random_state)
|
||||||
|
|
||||||
|
quantized_domain = copy(domain)
|
||||||
|
quantized_domain.lower = np.ceil(domain.lower / self.q) * self.q
|
||||||
|
quantized_domain.upper = np.floor(domain.upper / self.q) * self.q
|
||||||
|
values = self.sampler.sample(
|
||||||
|
quantized_domain, spec, size, random_state=random_state
|
||||||
|
)
|
||||||
quantized = np.round(np.divide(values, self.q)) * self.q
|
quantized = np.round(np.divide(values, self.q)) * self.q
|
||||||
|
|
||||||
if not isinstance(quantized, np.ndarray):
|
if not isinstance(quantized, np.ndarray):
|
||||||
return domain.cast(quantized)
|
return domain.cast(quantized)
|
||||||
return list(quantized)
|
return list(quantized)
|
||||||
@ -586,7 +601,9 @@ def lograndint(lower: int, upper: int, base: float = 10):
|
|||||||
|
|
||||||
def qrandint(lower: int, upper: int, q: int = 1):
|
def qrandint(lower: int, upper: int, q: int = 1):
|
||||||
"""Sample an integer value uniformly between ``lower`` and ``upper``.
|
"""Sample an integer value uniformly between ``lower`` and ``upper``.
|
||||||
|
|
||||||
``lower`` is inclusive, ``upper`` is also inclusive (!).
|
``lower`` is inclusive, ``upper`` is also inclusive (!).
|
||||||
|
|
||||||
The value will be quantized, i.e. rounded to an integer increment of ``q``.
|
The value will be quantized, i.e. rounded to an integer increment of ``q``.
|
||||||
Quantization makes the upper bound inclusive.
|
Quantization makes the upper bound inclusive.
|
||||||
"""
|
"""
|
||||||
@ -614,12 +631,15 @@ def randn(mean: float = 0.0, sd: float = 1.0):
|
|||||||
|
|
||||||
def qrandn(mean: float, sd: float, q: float):
|
def qrandn(mean: float, sd: float, q: float):
|
||||||
"""Sample a float value normally with ``mean`` and ``sd``.
|
"""Sample a float value normally with ``mean`` and ``sd``.
|
||||||
|
|
||||||
The value will be quantized, i.e. rounded to an integer increment of ``q``.
|
The value will be quantized, i.e. rounded to an integer increment of ``q``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (float): Mean of the normal distribution.
|
mean: Mean of the normal distribution.
|
||||||
sd (float): SD of the normal distribution.
|
sd: SD of the normal distribution.
|
||||||
q (float): Quantization number. The result will be rounded to an
|
q: Quantization number. The result will be rounded to an
|
||||||
integer increment of this value.
|
integer increment of this value.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return Float(None, None).normal(mean, sd).quantized(q)
|
return Float(None, None).normal(mean, sd).quantized(q)
|
||||||
|
|
||||||
|
@ -38,10 +38,10 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"%pip install flaml[notebook]\n",
|
"%pip install flaml[notebook] openml==0.10.2\n",
|
||||||
"# from v0.6.6, catboost is made an optional dependency to build conda package.\n",
|
"# From v0.6.6, catboost is made an optional dependency to build conda package.\n",
|
||||||
"# to install catboost without installing the notebook option, you can run:\n",
|
"# To install catboost, you can run:\n",
|
||||||
"# %pip install flaml[catboost]"
|
"%pip install flaml[catboost]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -836,6 +836,15 @@
|
|||||||
"In this example, the above information for RGF is wrapped in a python class called *MyRegularizedGreedyForest* that exposes the hyperparameters."
|
"In this example, the above information for RGF is wrapped in a python class called *MyRegularizedGreedyForest* that exposes the hyperparameters."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install rgf-python"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 19,
|
||||||
@ -1259,11 +1268,8 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"interpreter": {
|
|
||||||
"hash": "5432eb6463ddd46aaa76ccf859b1fa421ab98224a755661a6688060ed6e23d59"
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "ds440flaml",
|
"display_name": "Python 3.9.15 64-bit",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -1277,7 +1283,12 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.12"
|
"version": "3.9.15"
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
8
setup.py
8
setup.py
@ -40,11 +40,8 @@ setuptools.setup(
|
|||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
extras_require={
|
extras_require={
|
||||||
"notebook": [
|
"notebook": [
|
||||||
"openml==0.10.2",
|
|
||||||
"jupyter",
|
"jupyter",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"rgf-python",
|
|
||||||
"catboost>=0.26",
|
|
||||||
],
|
],
|
||||||
"test": [
|
"test": [
|
||||||
"flake8>=3.8.4",
|
"flake8>=3.8.4",
|
||||||
@ -57,7 +54,7 @@ setuptools.setup(
|
|||||||
"catboost>=0.26",
|
"catboost>=0.26",
|
||||||
"rgf-python",
|
"rgf-python",
|
||||||
"optuna==2.8.0",
|
"optuna==2.8.0",
|
||||||
"openml",
|
"openml==0.10.2",
|
||||||
"statsmodels>=0.12.2",
|
"statsmodels>=0.12.2",
|
||||||
"psutil==5.8.0",
|
"psutil==5.8.0",
|
||||||
"dataclasses",
|
"dataclasses",
|
||||||
@ -67,7 +64,8 @@ setuptools.setup(
|
|||||||
"rouge_score",
|
"rouge_score",
|
||||||
"hcrystalball==0.1.10",
|
"hcrystalball==0.1.10",
|
||||||
"seqeval",
|
"seqeval",
|
||||||
"pytorch-forecasting>=0.9.0",
|
"pytorch-forecasting>=0.9.0,<=0.10.1",
|
||||||
|
"mlflow",
|
||||||
],
|
],
|
||||||
"catboost": ["catboost>=0.26"],
|
"catboost": ["catboost>=0.26"],
|
||||||
"blendsearch": ["optuna==2.8.0"],
|
"blendsearch": ["optuna==2.8.0"],
|
||||||
|
@ -108,10 +108,7 @@ def _test_nobudget():
|
|||||||
|
|
||||||
|
|
||||||
def test_mlflow():
|
def test_mlflow():
|
||||||
import subprocess
|
# subprocess.check_call([sys.executable, "-m", "pip", "install", "mlflow"])
|
||||||
import sys
|
|
||||||
|
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "mlflow"])
|
|
||||||
import mlflow
|
import mlflow
|
||||||
from flaml.data import load_openml_task
|
from flaml.data import load_openml_task
|
||||||
|
|
||||||
@ -152,9 +149,12 @@ def test_mlflow():
|
|||||||
print(automl.predict_proba(X_test))
|
print(automl.predict_proba(X_test))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
# subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "mlflow"])
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_mlflow_iris():
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
|
import mlflow
|
||||||
|
from flaml import AutoML
|
||||||
|
|
||||||
with mlflow.start_run():
|
with mlflow.start_run():
|
||||||
automl = AutoML()
|
automl = AutoML()
|
||||||
@ -167,6 +167,8 @@ def test_mlflow():
|
|||||||
X_train, y_train = load_iris(return_X_y=True)
|
X_train, y_train = load_iris(return_X_y=True)
|
||||||
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
|
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
|
||||||
|
|
||||||
|
# subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "mlflow"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_automl(600)
|
test_automl(600)
|
||||||
|
@ -74,7 +74,10 @@ def test_hf_data():
|
|||||||
del automl
|
del automl
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -113,7 +113,10 @@ def _test_switch_classificationhead(each_data, each_model_path):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -84,7 +84,10 @@ def test_custom_metric():
|
|||||||
del automl
|
del automl
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -5,7 +5,9 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
|
@pytest.mark.skipif(
|
||||||
|
sys.platform in ["darwin", "win32"], reason="do not run on mac os or windows"
|
||||||
|
)
|
||||||
def test_cv():
|
def test_cv():
|
||||||
from flaml import AutoML
|
from flaml import AutoML
|
||||||
import requests
|
import requests
|
||||||
@ -22,7 +24,10 @@ def test_cv():
|
|||||||
return
|
return
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -5,7 +5,9 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
|
@pytest.mark.skipif(
|
||||||
|
sys.platform in ["darwin", "win32"], reason="do not run on mac os or windows"
|
||||||
|
)
|
||||||
def test_mcc():
|
def test_mcc():
|
||||||
from flaml import AutoML
|
from flaml import AutoML
|
||||||
import requests
|
import requests
|
||||||
@ -49,7 +51,10 @@ def test_mcc():
|
|||||||
print("Accuracy: " + str(accuracy))
|
print("Accuracy: " + str(accuracy))
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -35,7 +35,10 @@ def test_regression():
|
|||||||
automl.predict(X_val)
|
automl.predict(X_val)
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -7,8 +7,8 @@ import shutil
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform == "darwin" or sys.version < "3.7",
|
sys.platform in ["darwin", "win32"] or sys.version < "3.7",
|
||||||
reason="do not run on mac os or py3.6",
|
reason="do not run on mac os, windows or py3.6",
|
||||||
)
|
)
|
||||||
def test_summarization():
|
def test_summarization():
|
||||||
# TODO: manual test for how effective postprocess_seq2seq_prediction_label is
|
# TODO: manual test for how effective postprocess_seq2seq_prediction_label is
|
||||||
@ -51,7 +51,10 @@ def test_summarization():
|
|||||||
automl.predict(X_test)
|
automl.predict(X_test)
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -11,8 +11,8 @@ from utils import (
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform == "darwin" or sys.version < "3.7",
|
sys.platform in ["darwin", "win32"] or sys.version < "3.7",
|
||||||
reason="do not run on mac os or py<3.7",
|
reason="do not run on mac os, windows or py<3.7",
|
||||||
)
|
)
|
||||||
def test_tokenclassification_idlabel():
|
def test_tokenclassification_idlabel():
|
||||||
from flaml import AutoML
|
from flaml import AutoML
|
||||||
@ -65,12 +65,15 @@ def test_tokenclassification_idlabel():
|
|||||||
assert val_loss == min_inter_result
|
assert val_loss == min_inter_result
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform == "darwin" or sys.version < "3.7",
|
sys.platform in ["darwin", "win32"] or sys.version < "3.7",
|
||||||
reason="do not run on mac os or py<3.7",
|
reason="do not run on mac os, windows or py<3.7",
|
||||||
)
|
)
|
||||||
def test_tokenclassification_tokenlabel():
|
def test_tokenclassification_tokenlabel():
|
||||||
from flaml import AutoML
|
from flaml import AutoML
|
||||||
@ -112,7 +115,10 @@ def test_tokenclassification_tokenlabel():
|
|||||||
assert val_loss == min_inter_result
|
assert val_loss == min_inter_result
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -3,6 +3,7 @@ import sys
|
|||||||
from flaml.default import portfolio
|
from flaml.default import portfolio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def pop_args(fit_kwargs):
|
def pop_args(fit_kwargs):
|
||||||
@ -18,6 +19,7 @@ def test_build_portfolio(path="./test/nlp/default", strategy="greedy"):
|
|||||||
portfolio.main()
|
portfolio.main()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="do not run on windows")
|
||||||
def test_starting_point_not_in_search_space():
|
def test_starting_point_not_in_search_space():
|
||||||
from flaml import AutoML
|
from flaml import AutoML
|
||||||
|
|
||||||
@ -84,9 +86,13 @@ def test_starting_point_not_in_search_space():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="do not run on windows")
|
||||||
def test_points_to_evaluate():
|
def test_points_to_evaluate():
|
||||||
from flaml import AutoML
|
from flaml import AutoML
|
||||||
|
|
||||||
@ -106,10 +112,14 @@ def test_points_to_evaluate():
|
|||||||
automl.fit(X_train, y_train, **automl_settings)
|
automl.fit(X_train, y_train, **automl_settings)
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
# TODO: implement _test_zero_shot_model
|
# TODO: implement _test_zero_shot_model
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="do not run on windows")
|
||||||
def test_zero_shot_nomodel():
|
def test_zero_shot_nomodel():
|
||||||
from flaml.default import preprocess_and_suggest_hyperparams
|
from flaml.default import preprocess_and_suggest_hyperparams
|
||||||
|
|
||||||
@ -141,7 +151,10 @@ def test_zero_shot_nomodel():
|
|||||||
model.fit(X_train, y_train, **fit_kwargs)
|
model.fit(X_train, y_train, **fit_kwargs)
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
|
||||||
|
|
||||||
def test_build_error_portfolio(path="./test/nlp/default", strategy="greedy"):
|
def test_build_error_portfolio(path="./test/nlp/default", strategy="greedy"):
|
||||||
@ -176,4 +189,7 @@ def test_build_error_portfolio(path="./test/nlp/default", strategy="greedy"):
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if os.path.exists("test/data/output/"):
|
if os.path.exists("test/data/output/"):
|
||||||
shutil.rmtree("test/data/output/")
|
try:
|
||||||
|
shutil.rmtree("test/data/output/")
|
||||||
|
except PermissionError:
|
||||||
|
print("PermissionError when deleting test/data/output/")
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
- 'regression': regression with tabular data.
|
- 'regression': regression with tabular data.
|
||||||
- 'ts_forecast': time series forecasting.
|
- 'ts_forecast': time series forecasting.
|
||||||
- 'ts_forecast_classification': time series forecasting for classification.
|
- 'ts_forecast_classification': time series forecasting for classification.
|
||||||
- 'ts_forecast_panel': time series forecasting for panel datasets (multiple time series).
|
<!-- - 'ts_forecast_panel': time series forecasting for panel datasets (multiple time series). -->
|
||||||
- 'rank': learning to rank.
|
- 'rank': learning to rank.
|
||||||
- 'seq-classification': sequence classification.
|
- 'seq-classification': sequence classification.
|
||||||
- 'seq-regression': sequence regression.
|
- 'seq-regression': sequence regression.
|
||||||
@ -120,7 +120,7 @@ The estimator list can contain one or more estimator names, each corresponding t
|
|||||||
- 'arima': ARIMA for task "ts_forecast". Hyperparameters: p, d, q.
|
- 'arima': ARIMA for task "ts_forecast". Hyperparameters: p, d, q.
|
||||||
- 'sarimax': SARIMAX for task "ts_forecast". Hyperparameters: p, d, q, P, D, Q, s.
|
- 'sarimax': SARIMAX for task "ts_forecast". Hyperparameters: p, d, q, P, D, Q, s.
|
||||||
- 'transformer': Huggingface transformer models for task "seq-classification", "seq-regression", "multichoice-classification", "token-classification" and "summarization". Hyperparameters: learning_rate, num_train_epochs, per_device_train_batch_size, warmup_ratio, weight_decay, adam_epsilon, seed.
|
- 'transformer': Huggingface transformer models for task "seq-classification", "seq-regression", "multichoice-classification", "token-classification" and "summarization". Hyperparameters: learning_rate, num_train_epochs, per_device_train_batch_size, warmup_ratio, weight_decay, adam_epsilon, seed.
|
||||||
- 'temporal_fusion_transform': TemporalFusionTransformerEstimator for task "ts_forecast_panel". Hyperparameters: gradient_clip_val, hidden_size, hidden_continuous_size, attention_head_size, dropout, learning_rate.
|
<!-- - 'temporal_fusion_transform': TemporalFusionTransformerEstimator for task "ts_forecast_panel". Hyperparameters: gradient_clip_val, hidden_size, hidden_continuous_size, attention_head_size, dropout, learning_rate. -->
|
||||||
* Custom estimator. Use custom estimator for:
|
* Custom estimator. Use custom estimator for:
|
||||||
- tuning an estimator that is not built-in;
|
- tuning an estimator that is not built-in;
|
||||||
- customizing search space for a built-in estimator.
|
- customizing search space for a built-in estimator.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user