log best model (#96)

* log best model
This commit is contained in:
Chi Wang 2021-06-02 13:11:41 -07:00 committed by GitHub
parent a2a37cb60f
commit 61d1263dfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -292,7 +292,7 @@ class AutoML:
classification), storing the best trained model.
'''
if self._trained_estimator:
return self._trained_estimator.model
return self._trained_estimator
else:
return None
@ -309,7 +309,7 @@ class AutoML:
if estimator_name in self._search_states:
state = self._search_states[estimator_name]
if hasattr(state, 'trained_estimator'):
return state.trained_estimator.model
return state.trained_estimator
return None
@property
@ -346,7 +346,7 @@ class AutoML:
if self._label_transformer:
return self._label_transformer.classes_.tolist()
if self._trained_estimator:
return self._trained_estimator.model.classes_.tolist()
return self._trained_estimator.classes_.tolist()
return None
def predict(self, X_test):
@ -1094,7 +1094,7 @@ class AutoML:
self._state.time_from_start)
if self._save_model_history:
self._model_history[
self._track_iter] = search_state.trained_estimator.model
self._track_iter] = search_state.trained_estimator
elif self._trained_estimator:
del self._trained_estimator
self._trained_estimator = None
@ -1217,6 +1217,8 @@ class AutoML:
else:
self._selected = self._trained_estimator = None
self.modelcount = 0
if self.model and mlflow is not None and mlflow.active_run():
mlflow.sklearn.log_model(self.model, 'best_model')
def __del__(self):
if hasattr(self, '_trained_estimator') and self._trained_estimator \

View File

@ -1 +1 @@
__version__ = "0.4.1"
__version__ = "0.4.2"