From 7771e42e1be7aaaf7a16af471fc8808114d6591c Mon Sep 17 00:00:00 2001
From: Pere Miquel Brull
Date: Mon, 8 Nov 2021 07:44:15 +0100
Subject: [PATCH] [Issue-801] - MLFeature & MLHyperParameters (#1071)
* Update model schema
* Update create model schema
* Disable additional params
* Update Model Resource & Repository
* Add tests for MLFeatures and MLHyperParameters
---
.../catalog/jdbi3/ModelRepository.java | 16 +-
.../resources/models/ModelResource.java | 12 +-
.../json/schema/api/data/createModel.json | 16 ++
.../json/schema/entity/data/model.json | 193 ++++++++++++++++--
.../resources/models/ModelResourceTest.java | 55 ++++-
5 files changed, 270 insertions(+), 22 deletions(-)
diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java
index 3d47057ac17..efa800278d2 100644
--- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java
+++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java
@@ -44,9 +44,9 @@ import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityN
public class ModelRepository extends EntityRepository {
private static final Logger LOG = LoggerFactory.getLogger(ModelRepository.class);
private static final Fields MODEL_UPDATE_FIELDS = new Fields(ModelResource.FIELD_LIST,
- "owner,dashboard,tags");
+ "owner,dashboard,mlHyperParameters,mlFeatures,tags");
private static final Fields MODEL_PATCH_FIELDS = new Fields(ModelResource.FIELD_LIST,
- "owner,dashboard,tags");
+ "owner,dashboard,mlHyperParameters,mlFeatures,tags");
private final CollectionDAO dao;
public ModelRepository(CollectionDAO dao) {
@@ -80,6 +80,8 @@ public class ModelRepository extends EntityRepository {
model.setDisplayName(model.getDisplayName());
model.setOwner(fields.contains("owner") ? getOwner(model) : null);
model.setDashboard(fields.contains("dashboard") ? getDashboard(model) : null);
+ model.setMlFeatures(fields.contains("mlFeatures") ? model.getMlFeatures(): null);
+ model.setMlHyperParameters(fields.contains("mlHyperParameters") ? model.getMlHyperParameters(): null);
model.setFollowers(fields.contains("followers") ? getFollowers(model) : null);
model.setTags(fields.contains("tags") ? getTags(model.getFullyQualifiedName()) : null);
model.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(),
@@ -300,12 +302,22 @@ public class ModelRepository extends EntityRepository {
public void entitySpecificUpdate() throws IOException {
updateAlgorithm(original.getEntity(), updated.getEntity());
updateDashboard(original.getEntity(), updated.getEntity());
+ updateMlFeatures(original.getEntity(), updated.getEntity());
+ updateMlHyperParameters(original.getEntity(), updated.getEntity());
}
private void updateAlgorithm(Model origModel, Model updatedModel) throws JsonProcessingException {
recordChange("algorithm", origModel.getAlgorithm(), updatedModel.getAlgorithm());
}
+ private void updateMlFeatures(Model origModel, Model updatedModel) throws JsonProcessingException {
+ recordChange("mlFeatures", origModel.getMlFeatures(), updatedModel.getMlFeatures());
+ }
+
+ private void updateMlHyperParameters(Model origModel, Model updatedModel) throws JsonProcessingException {
+ recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters());
+ }
+
private void updateDashboard(Model origModel, Model updatedModel) throws JsonProcessingException {
// Remove existing dashboards
removeDashboard(origModel);
diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java
index a4fbab556f6..fcef18dd41f 100644
--- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java
+++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java
@@ -118,7 +118,7 @@ public class ModelResource {
}
}
- static final String FIELDS = "owner,dashboard,algorithm,followers,tags,usageSummary";
+ static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary";
public static final List FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "")
.split(","));
@@ -288,7 +288,7 @@ public class ModelResource {
@DELETE
@Path("/{id}/followers/{userId}")
- @Operation(summary = "Remove a follower", tags = "model",
+ @Operation(summary = "Remove a follower", tags = "models",
description = "Remove the user identified `userId` as a follower of the model.")
public Model deleteFollower(@Context UriInfo uriInfo,
@Context SecurityContext securityContext,
@@ -306,7 +306,7 @@ public class ModelResource {
@DELETE
@Path("/{id}")
- @Operation(summary = "Delete a Model", tags = "model",
+ @Operation(summary = "Delete a Model", tags = "models",
description = "Delete a model by `id`.",
responses = {
@ApiResponse(responseCode = "200", description = "OK"),
@@ -321,8 +321,10 @@ public class ModelResource {
return new Model().withId(UUID.randomUUID()).withName(create.getName())
.withDisplayName(create.getDisplayName())
.withDescription(create.getDescription())
- .withDashboard(create.getDashboard()) //ADDED
- .withAlgorithm(create.getAlgorithm()) //ADDED
+ .withDashboard(create.getDashboard())
+ .withAlgorithm(create.getAlgorithm())
+ .withMlFeatures(create.getMlFeatures())
+ .withMlHyperParameters(create.getMlHyperParameters())
.withTags(create.getTags())
.withOwner(create.getOwner())
.withUpdatedBy(securityContext.getUserPrincipal().getName())
diff --git a/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json b/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json
index c1d04af428a..46b4dfe1454 100644
--- a/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json
+++ b/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json
@@ -23,6 +23,22 @@
"description": "Algorithm used to train the model",
"type": "string"
},
+ "mlFeatures": {
+ "description": "Features used to train the ML Model.",
+ "type": "array",
+ "items": {
+ "$ref": "../../entity/data/model.json#/definitions/mlFeature"
+ },
+ "default" : null
+ },
+ "mlHyperParameters": {
+ "description": "Hyper Parameters used to train the ML Model.",
+ "type": "array",
+ "items": {
+ "$ref": "../../entity/data/model.json#/definitions/mlHyperParameter"
+ },
+ "default" : null
+ },
"dashboard" : {
"description": "Performance Dashboard URL to track metric evolution",
"$ref" : "../../type/entityReference.json"
diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json b/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json
index 09a06830cfb..b87d77a0955 100644
--- a/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json
+++ b/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json
@@ -4,7 +4,158 @@
"title": "Model",
"description": "This schema defines the Model entity. Models are algorithms trained on data to find patterns or make predictions.",
"type": "object",
-
+ "definitions": {
+ "featureType": {
+ "javaType": "org.openmetadata.catalog.type.MLFeatureDataType",
+ "description": "This enum defines the type of data stored in a ML Feature.",
+ "type": "string",
+ "enum": [
+ "numerical",
+ "categorical"
+ ],
+ "javaEnums": [
+ {
+ "name": "Numerical"
+ },
+ {
+ "name": "Categorical"
+ }
+ ]
+ },
+ "featureSourceDataType": {
+ "javaType": "org.openmetadata.catalog.type.FeatureSourceDataType",
+ "description": "This enum defines the type of data of a ML Feature source.",
+ "type": "string",
+ "enum": [
+ "integer",
+ "number",
+ "string",
+ "array",
+ "date",
+ "timestamp",
+ "object",
+ "boolean"
+ ]
+ },
+ "featureName": {
+ "description": "Local name (not fully qualified name) of the ML Feature.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 64,
+ "pattern": "^[^.]*$"
+ },
+ "featureSourceName": {
+ "description": "Local name (not fully qualified name) of a ML Feature source",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 64,
+ "pattern": "^[^.]*$"
+ },
+ "fullyQualifiedFeatureSourceName": {
+ "description": "Fully qualified name of the ML Feature Source that includes `serviceName.[databaseName].tableName/fileName/apiName.columnName[.nestedColumnName]`.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 256
+ },
+ "fullyQualifiedFeatureName": {
+ "description": "Fully qualified name of the ML Feature that includes `modelName.featureName`.",
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 256
+ },
+ "featureSource": {
+ "type": "object",
+ "javaType": "org.openmetadata.catalog.type.MLFeatureSource",
+ "description": "This schema defines the sources of a ML Feature.",
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "$ref": "#/definitions/featureSourceName"
+ },
+ "dataType": {
+ "description": "Data type of the source (int, date etc.).",
+ "$ref": "#/definitions/featureSourceDataType"
+ },
+ "description": {
+ "description": "Description of the feature source.",
+ "type": "string"
+ },
+ "fullyQualifiedName": {
+ "$ref": "#/definitions/fullyQualifiedFeatureSourceName"
+ },
+ "tags": {
+ "description": "Tags associated with the feature source.",
+ "type": "array",
+ "items": {
+ "$ref": "../../type/tagLabel.json"
+ },
+ "default": null
+ }
+ }
+ },
+ "mlFeature": {
+ "type": "object",
+ "javaType": "org.openmetadata.catalog.type.MLFeature",
+ "description": "This schema defines the type for a ML Feature used in a Model.",
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "$ref": "#/definitions/featureName"
+ },
+ "dataType": {
+ "description": "Data type of the column (numerical vs. categorical).",
+ "$ref": "#/definitions/featureType"
+ },
+ "description": {
+ "description": "Description of the ML Feature.",
+ "type": "string"
+ },
+ "fullyQualifiedName": {
+ "$ref": "#/definitions/fullyQualifiedFeatureName"
+ },
+ "featureSources": {
+ "description": "Columns used to create the ML Feature",
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/featureSource"
+ },
+ "default": null
+ },
+ "featureAlgorithm": {
+ "description": "Description of the algorithm used to compute the feature, e.g., PCA, bucketing...",
+ "type": "string"
+ },
+ "tags": {
+ "description": "Tags associated with the feature.",
+ "type": "array",
+ "items": {
+ "$ref": "../../type/tagLabel.json"
+ },
+ "default": null
+ }
+ }
+ },
+ "mlHyperParameter": {
+ "type": "object",
+ "javaType": "org.openmetadata.catalog.type.MLHyperParameter",
+ "description": "This schema defines the type for a ML HyperParameter used in a Model.",
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "description": "Hyper parameter name",
+ "type": "string"
+ },
+ "value": {
+ "description": "Hyper parameter value",
+ "type": "string"
+ },
+ "description": {
+ "description": "Description of the Hyper Parameter.",
+ "type": "string"
+ }
+ }
+ }
+ },
"properties" : {
"id": {
"description": "Unique identifier of a model instance.",
@@ -26,18 +177,6 @@
"description": "Display Name that identifies this model.",
"type": "string"
},
- "version" : {
- "description": "Metadata version of the entity.",
- "$ref": "../../type/entityHistory.json#/definitions/entityVersion"
- },
- "updatedAt" : {
- "description": "Last update time corresponding to the new version of the entity.",
- "$ref": "../../type/basic.json#/definitions/dateTime"
- },
- "updatedBy" : {
- "description": "User who made the update.",
- "type": "string"
- },
"description": {
"description": "Description of the model, what it is, and how to use it.",
"type": "string"
@@ -46,6 +185,22 @@
"description": "Algorithm used to train the model.",
"type": "string"
},
+ "mlFeatures": {
+ "description": "Features used to train the ML Model.",
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/mlFeature"
+ },
+ "default": null
+ },
+ "mlHyperParameters": {
+ "description": "Hyper Parameters used to train the ML Model.",
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/mlHyperParameter"
+ },
+ "default": null
+ },
"dashboard" : {
"description": "Performance Dashboard URL to track metric evolution.",
"$ref" : "../../type/entityReference.json"
@@ -75,6 +230,18 @@
"$ref": "../../type/usageDetails.json",
"default": null
},
+ "version" : {
+ "description": "Metadata version of the entity.",
+ "$ref": "../../type/entityHistory.json#/definitions/entityVersion"
+ },
+ "updatedAt" : {
+ "description": "Last update time corresponding to the new version of the entity.",
+ "$ref": "../../type/basic.json#/definitions/dateTime"
+ },
+ "updatedBy" : {
+ "description": "User who made the update.",
+ "type": "string"
+ },
"changeDescription": {
"description" : "Change that lead to this version of the entity.",
"$ref": "../../type/entityHistory.json#/definitions/changeDescription"
diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java
index 52120628182..9d1ba6ef586 100644
--- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java
+++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java
@@ -27,6 +27,11 @@ import org.openmetadata.catalog.api.services.CreateDashboardService;
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
import org.openmetadata.catalog.entity.data.Dashboard;
import org.openmetadata.catalog.entity.data.Model;
+import org.openmetadata.catalog.type.FeatureSourceDataType;
+import org.openmetadata.catalog.type.MLFeature;
+import org.openmetadata.catalog.type.MLFeatureDataType;
+import org.openmetadata.catalog.type.MLFeatureSource;
+import org.openmetadata.catalog.type.MLHyperParameter;
import org.openmetadata.catalog.entity.services.DashboardService;
import org.openmetadata.catalog.entity.teams.Team;
import org.openmetadata.catalog.entity.teams.User;
@@ -47,6 +52,8 @@ import org.slf4j.LoggerFactory;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response.Status;
+import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@@ -79,6 +86,39 @@ public class ModelResourceTest extends CatalogApplicationTest {
public static EntityReference SUPERSET_REFERENCE;
public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE;
+ public static List ML_FEATURES = Arrays.asList(
+ new MLFeature()
+ .withName("age")
+ .withDataType(MLFeatureDataType.Numerical)
+ .withFeatureSources(
+ Collections.singletonList(
+ new MLFeatureSource()
+ .withName("age")
+ .withDataType(FeatureSourceDataType.INTEGER)
+ .withFullyQualifiedName("my_service.my_db.my_table.age")
+ )
+ ),
+ new MLFeature()
+ .withName("persona")
+ .withDataType(MLFeatureDataType.Categorical)
+ .withFeatureSources(
+ Arrays.asList(
+ new MLFeatureSource()
+ .withName("age")
+ .withDataType(FeatureSourceDataType.INTEGER)
+ .withFullyQualifiedName("my_service.my_db.my_table.age"),
+ new MLFeatureSource()
+ .withName("education")
+ .withDataType(FeatureSourceDataType.STRING)
+ .withFullyQualifiedName("my_api.education")
+ )
+ )
+ .withFeatureAlgorithm("PCA")
+ );
+ public static List ML_HYPERPARAMS = Arrays.asList(
+ new MLHyperParameter().withName("regularisation").withValue("0.5"),
+ new MLHyperParameter().withName("random").withValue("hello")
+ );
@BeforeAll
@@ -465,6 +505,15 @@ public class ModelResourceTest extends CatalogApplicationTest {
assertNotNull(model.getAlgorithm()); // Provided as default field
assertNull(model.getDashboard());
+ // .../models?fields=mlFeatures,mlHyperParameters
+ fields = "mlFeatures,mlHyperParameters";
+ model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) :
+ getModel(model.getId(), fields, adminAuthHeaders());
+ assertNotNull(model.getAlgorithm()); // Provided as default field
+ assertNotNull(model.getMlFeatures());
+ assertNotNull(model.getMlHyperParameters());
+ assertNull(model.getDashboard());
+
// .../models?fields=owner,algorithm
fields = "owner,algorithm";
model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) :
@@ -577,11 +626,13 @@ public class ModelResourceTest extends CatalogApplicationTest {
}
public static CreateModel create(TestInfo test) {
- return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM);
+ return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM)
+ .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
}
public static CreateModel create(TestInfo test, int index) {
- return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM);
+ return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM)
+ .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
}
}