# Support spark dataframe as input dataset and spark models as estimators (#934)
# 2023-03-26 03:59:46 +08:00
import numpy as np
from typing import Union
from flaml.automl.spark import psSeries, F
from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    RegressionEvaluator,
    MulticlassClassificationEvaluator,
    MultilabelClassificationEvaluator,
    RankingEvaluator,
)


def ps_group_counts(groups: Union[psSeries, np.ndarray]) -> list:
    """Return the size of each group, ordered by each group's first appearance in `groups`."""
    if isinstance(groups, np.ndarray):
        # np.unique sorts the unique values; i holds each value's first index and c its count
        _, i, c = np.unique(groups, return_counts=True, return_index=True)
    else:
        # Sort the deduplicated values so that i lines up with the value-sorted counts in c,
        # matching the ordering that np.unique returns above
        i = groups.drop_duplicates().sort_values().index.values
        c = groups.value_counts().sort_index().to_numpy()
    # Reorder the counts by each group's first appearance
    return c[np.argsort(i)].tolist()
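

# Illustrative sketch (hypothetical helper, not part of FLAML's API): a pure-Python
# reference for the result of ps_group_counts, i.e. the size of each group listed in the
# order in which the group first appears. For example, groups [2, 2, 1, 1, 1, 3] give
# [2, 3, 1]. spark_metric_loss_score slices y_true and y_predict into consecutive chunks
# of these sizes, which only matches the groups when each group's rows are contiguous.
def _reference_group_counts(groups):
    counts = {}
    for g in groups:
        # dicts keep insertion order (Python 3.7+), which is first-appearance order here
        counts[g] = counts.get(g, 0) + 1
    return list(counts.values())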


def _process_df(df, label_col, prediction_col):
    # Wrap the scalar label and prediction columns into single-element arrays,
    # the array column type that RankingEvaluator expects
    df = df.withColumn(label_col, F.array([df[label_col]]))
    df = df.withColumn(prediction_col, F.array([df[prediction_col]]))
    return df


def _compute_label_from_probability(df, probability_col, prediction_col):
    # array_max finds the maximum value in the probability array;
    # array_position returns the 1-based position of its first occurrence,
    # so subtracting 1 gives the 0-based index of the most probable class
    max_index_expr = F.expr(f"array_position({probability_col}, array_max({probability_col}))-1")
    # Store that class index in prediction_col as a double
    df = df.withColumn(prediction_col, max_index_expr.cast("double"))
    return df
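

# Illustrative sketch (hypothetical helper, not part of FLAML's API): the per-row logic of
# _compute_label_from_probability in plain Python. Spark's array_position is 1-based and
# finds the first match, so "array_position(p, array_max(p)) - 1" is the 0-based index of
# the first maximal probability; e.g. [0.1, 0.7, 0.2] gives 1.0.
def _reference_label_from_probability(probability):
    # list.index also returns the first match, so ties resolve to the lowest class index
    return float(probability.index(max(probability)))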


def spark_metric_loss_score(
    metric_name: str,
    y_predict: psSeries,
    y_true: psSeries,
    sample_weight: psSeries = None,
    groups: psSeries = None,
) -> float:
    """
    Compute the loss score of a metric for spark models.

    Args:
        metric_name: str | the name of the metric, case-insensitive. One of "rmse", "mse", "mae",
            "r2", "var", "roc_auc", "pr_auc", "accuracy", "log_loss", "f1", "micro_f1",
            "macro_f1", "ap", "ndcg" or "ndcg@k" (e.g. "ndcg@5").
        y_predict: psSeries | the predicted values. For "log_loss", each value is an array
            of class probabilities.
        y_true: psSeries | the true values.
        sample_weight: psSeries | the sample weights. Default: None.
        groups: psSeries | the group of each row, used only by the "ndcg@k" metrics. Default: None.

    Returns:
        float | the loss score. A lower value indicates a better model.
    """
    label_col = "label"
    prediction_col = "prediction"
    kwargs = {}

    # Name the series so that the joined columns are called "prediction" and "label"
    y_predict.name = prediction_col
    y_true.name = label_col
    df = y_predict.to_frame().join(y_true)
    if sample_weight is not None:
        sample_weight.name = "weight"
        df = df.join(sample_weight)
        kwargs = {"weightCol": "weight"}

    df = df.to_spark()

    metric_name = metric_name.lower()
    # Lower is better for these metrics; every other metric is converted to a loss as 1 - score
    min_mode_metrics = ["log_loss", "rmse", "mse", "mae"]

    if metric_name == "rmse":
        evaluator = RegressionEvaluator(
            metricName="rmse",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "mse":
        evaluator = RegressionEvaluator(
            metricName="mse",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "mae":
        evaluator = RegressionEvaluator(
            metricName="mae",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "r2":
        evaluator = RegressionEvaluator(
            metricName="r2",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "var":
        evaluator = RegressionEvaluator(
            metricName="var",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "roc_auc":
        evaluator = BinaryClassificationEvaluator(
            metricName="areaUnderROC",
            labelCol=label_col,
            rawPredictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "pr_auc":
        evaluator = BinaryClassificationEvaluator(
            metricName="areaUnderPR",
            labelCol=label_col,
            rawPredictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "accuracy":
        evaluator = MulticlassClassificationEvaluator(
            metricName="accuracy",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "log_loss":
        # For log_loss, prediction_col holds class probabilities; derive the predicted label from them
        df = _compute_label_from_probability(df, prediction_col, prediction_col + "_label")
        evaluator = MulticlassClassificationEvaluator(
            metricName="logLoss",
            labelCol=label_col,
            predictionCol=prediction_col + "_label",
            probabilityCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "f1":
        evaluator = MulticlassClassificationEvaluator(
            metricName="f1",
            labelCol=label_col,
            predictionCol=prediction_col,
            **kwargs,
        )
    elif metric_name == "micro_f1":
        # MultilabelClassificationEvaluator has no weightCol param, so sample weights are not passed
        evaluator = MultilabelClassificationEvaluator(
            metricName="microF1Measure",
            labelCol=label_col,
            predictionCol=prediction_col,
        )
    elif metric_name == "macro_f1":
        # MultilabelClassificationEvaluator has no weightCol param, so sample weights are not passed
        evaluator = MultilabelClassificationEvaluator(
            metricName="f1MeasureByLabel",
            labelCol=label_col,
            predictionCol=prediction_col,
        )
    elif metric_name == "ap":
        evaluator = RankingEvaluator(
            metricName="meanAveragePrecision",
            labelCol=label_col,
            predictionCol=prediction_col,
        )
    elif "ndcg" in metric_name:
        # TODO: check whether the spark.ml ranker uses the same format as the
        # SynapseML ranker; the format of df may need to be adjusted
        if "@" in metric_name:
            k = int(metric_name.split("@", 1)[-1])
            if groups is None:
                evaluator = RankingEvaluator(
                    metricName="ndcgAtK",
                    labelCol=label_col,
                    predictionCol=prediction_col,
                    k=k,
                )
                df = _process_df(df, label_col, prediction_col)
                score = 1 - evaluator.evaluate(df)
            else:
                # Compute NDCG@k per group; slicing by the group sizes assumes
                # that the rows of each group are contiguous
                counts = ps_group_counts(groups)
                score = 0
                psum = 0
                for c in counts:
                    y_true_ = y_true[psum : psum + c]
                    y_predict_ = y_predict[psum : psum + c]
                    df = y_true_.to_frame().join(y_predict_).to_spark()
                    df = _process_df(df, label_col, prediction_col)
                    evaluator = RankingEvaluator(
                        metricName="ndcgAtK",
                        labelCol=label_col,
                        predictionCol=prediction_col,
                        k=k,
                    )
                    score -= evaluator.evaluate(df)
                    psum += c
                # The loss is 1 minus the mean per-group NDCG@k
                score /= len(counts)
                score += 1
        else:
            evaluator = RankingEvaluator(metricName="ndcgAtK", labelCol=label_col, predictionCol=prediction_col)
            df = _process_df(df, label_col, prediction_col)
            score = 1 - evaluator.evaluate(df)
        return score
    else:
        raise ValueError(f"Unknown metric name: {metric_name} for spark models.")

    return evaluator.evaluate(df) if metric_name in min_mode_metrics else 1 - evaluator.evaluate(df)
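

# Illustrative sketch (hypothetical helpers, not part of FLAML's API): how
# spark_metric_loss_score turns evaluator scores into a loss. Metrics listed in
# min_mode_metrics are returned unchanged and every other metric becomes 1 - score.
# With groups, the NDCG loss is 1 minus the unweighted mean of the per-group scores;
# here a plain list of floats stands in for the RankingEvaluator.evaluate() results.
def _reference_to_loss(metric_name, score):
    min_mode_metrics = ("log_loss", "rmse", "mse", "mae")
    return score if metric_name.lower() in min_mode_metrics else 1 - score


def _reference_grouped_ndcg_loss(group_scores):
    # Same steps as the loop above: subtract each group's score, average, then add 1
    loss = 0.0
    for s in group_scores:
        loss -= s
    loss /= len(group_scores)
    return loss + 1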