mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-10-31 10:39:30 +00:00 
			
		
		
		
	 48ebcffbd0
			
		
	
	
		48ebcffbd0
		
			
		
	
	
	
	
		
			
			* Update path * Prepare sonar properties * Prepare coverage recipes * Add coverage * Simplify pytest * Organise integration tests * Update path * Use setup instead of reqs * Update recipes * Fix PR event to target * Update event_name * Prepare sonar * Run tests & sonar * Use sonarcloud host * Fix compose * Use ingestion token
		
			
				
	
	
		
			104 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			104 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Example extracted from https://www.mlflow.org/docs/latest/tutorials-and-examples/tutorial.html
 | |
| 
 | |
| To run this you need `mlflow`, `scikit-learn`, `pandas` and `numpy` installed in your environment.
 | |
| """
 | |
| 
 | |
| import logging
 | |
| import os
 | |
| import sys
 | |
| import warnings
 | |
| from urllib.parse import urlparse
 | |
| 
 | |
| import mlflow.sklearn
 | |
| import numpy as np
 | |
| import pandas as pd
 | |
| from mlflow.models import infer_signature
 | |
| from sklearn.linear_model import ElasticNet
 | |
| from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
 | |
| from sklearn.model_selection import train_test_split
 | |
| 
 | |
| logging.basicConfig(level=logging.WARN)
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
def eval_metrics(actual, pred):
    """Score predictions against the observed values.

    Args:
        actual: Ground-truth target values.
        pred: Predicted values, aligned with ``actual``.

    Returns:
        A ``(rmse, mae, r2)`` tuple: the square root of the mean squared
        error, the mean absolute error, and the R2 score, in that order.
    """
    return (
        np.sqrt(mean_squared_error(actual, pred)),
        mean_absolute_error(actual, pred),
        r2_score(actual, pred),
    )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Train an ElasticNet regressor on the red-wine quality dataset and record
    # its parameters, metrics and model artifact with MLflow.
    #
    # Usage: <script> [alpha] [l1_ratio]   (both default to 0.5)

    # Point the MLflow client at the tracking server used by this example.
    mlflow_uri = "http://localhost:5000"
    mlflow.set_tracking_uri(mlflow_uri)

    # S3-compatible artifact store settings. These deliberately overwrite any
    # values already present in the environment so the run always targets the
    # local endpoint below (presumably a local MinIO instance, going by the
    # access key -- confirm against the compose setup).
    os.environ["AWS_ACCESS_KEY_ID"] = "minio"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "password"
    os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://localhost:9001"

    warnings.filterwarnings("ignore")
    # Seed NumPy's global RNG, which train_test_split falls back to when no
    # random_state is given, so the split is reproducible.
    np.random.seed(40)

    # Read the wine-quality csv file from the URL
    csv_url = "http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
    try:
        data = pd.read_csv(csv_url, sep=";")
    except Exception as e:
        logger.exception(
            "Unable to download training & test CSV, check your internet connection. Error: %s",
            e,
        )
        # Nothing below can run without the dataset. Previously execution
        # fell through and failed with a NameError on the unbound `data`.
        sys.exit(1)

    # Split the data into training and test sets. (0.75, 0.25) split.
    train, test = train_test_split(data)

    # The predicted column is "quality" which is a scalar from [3, 9]
    train_x = train.drop(["quality"], axis=1)
    test_x = test.drop(["quality"], axis=1)
    train_y = train[["quality"]]
    test_y = test[["quality"]]

    # Optional positional CLI arguments: alpha, then l1_ratio.
    alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
    l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

    with mlflow.start_run():
        lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
        lr.fit(train_x, train_y)

        # Input/output schema inferred from the training features and the
        # model's predictions on them; stored alongside the logged model.
        signature = infer_signature(train_x, lr.predict(train_x))

        predicted_qualities = lr.predict(test_x)

        rmse, mae, r2 = eval_metrics(test_y, predicted_qualities)

        print(f"Elasticnet model (alpha={alpha:f}, l1_ratio={l1_ratio:f}):")
        print(f"  RMSE: {rmse}")
        print(f"  MAE: {mae}")
        print(f"  R2: {r2}")

        mlflow.log_param("alpha", alpha)
        mlflow.log_param("l1_ratio", l1_ratio)
        mlflow.log_metric("rmse", rmse)
        mlflow.log_metric("r2", r2)
        mlflow.log_metric("mae", mae)

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

        # Model registry does not work with file store
        if tracking_url_type_store != "file":
            # Register the model
            # There are other ways to use the Model Registry, which depends on the use case,
            # please refer to the doc for more information:
            # https://mlflow.org/docs/latest/model-registry.html#api-workflow
            mlflow.sklearn.log_model(
                lr,
                "model",
                registered_model_name="ElasticnetWineModel",
                signature=signature,
            )
        else:
            # No registry available: log the model only, but still attach the
            # signature so both branches store the same schema information.
            mlflow.sklearn.log_model(lr, "model", signature=signature)
 |