mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-02 18:10:13 +00:00
parent
a2a37cb60f
commit
61d1263dfd
@ -292,7 +292,7 @@ class AutoML:
|
||||
classification), storing the best trained model.
|
||||
'''
|
||||
if self._trained_estimator:
|
||||
return self._trained_estimator.model
|
||||
return self._trained_estimator
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -309,7 +309,7 @@ class AutoML:
|
||||
if estimator_name in self._search_states:
|
||||
state = self._search_states[estimator_name]
|
||||
if hasattr(state, 'trained_estimator'):
|
||||
return state.trained_estimator.model
|
||||
return state.trained_estimator
|
||||
return None
|
||||
|
||||
@property
|
||||
@ -346,7 +346,7 @@ class AutoML:
|
||||
if self._label_transformer:
|
||||
return self._label_transformer.classes_.tolist()
|
||||
if self._trained_estimator:
|
||||
return self._trained_estimator.model.classes_.tolist()
|
||||
return self._trained_estimator.classes_.tolist()
|
||||
return None
|
||||
|
||||
def predict(self, X_test):
|
||||
@ -1094,7 +1094,7 @@ class AutoML:
|
||||
self._state.time_from_start)
|
||||
if self._save_model_history:
|
||||
self._model_history[
|
||||
self._track_iter] = search_state.trained_estimator.model
|
||||
self._track_iter] = search_state.trained_estimator
|
||||
elif self._trained_estimator:
|
||||
del self._trained_estimator
|
||||
self._trained_estimator = None
|
||||
@ -1217,6 +1217,8 @@ class AutoML:
|
||||
else:
|
||||
self._selected = self._trained_estimator = None
|
||||
self.modelcount = 0
|
||||
if self.model and mlflow is not None and mlflow.active_run():
|
||||
mlflow.sklearn.log_model(self.model, 'best_model')
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, '_trained_estimator') and self._trained_estimator \
|
||||
|
||||
@ -1 +1 @@
|
||||
__version__ = "0.4.1"
|
||||
__version__ = "0.4.2"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user