diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java index 3d47057ac17..efa800278d2 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/ModelRepository.java @@ -44,9 +44,9 @@ import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityN public class ModelRepository extends EntityRepository { private static final Logger LOG = LoggerFactory.getLogger(ModelRepository.class); private static final Fields MODEL_UPDATE_FIELDS = new Fields(ModelResource.FIELD_LIST, - "owner,dashboard,tags"); + "owner,dashboard,mlHyperParameters,mlFeatures,tags"); private static final Fields MODEL_PATCH_FIELDS = new Fields(ModelResource.FIELD_LIST, - "owner,dashboard,tags"); + "owner,dashboard,mlHyperParameters,mlFeatures,tags"); private final CollectionDAO dao; public ModelRepository(CollectionDAO dao) { @@ -80,6 +80,8 @@ public class ModelRepository extends EntityRepository { model.setDisplayName(model.getDisplayName()); model.setOwner(fields.contains("owner") ? getOwner(model) : null); model.setDashboard(fields.contains("dashboard") ? getDashboard(model) : null); + model.setMlFeatures(fields.contains("mlFeatures") ? model.getMlFeatures(): null); + model.setMlHyperParameters(fields.contains("mlHyperParameters") ? model.getMlHyperParameters(): null); model.setFollowers(fields.contains("followers") ? getFollowers(model) : null); model.setTags(fields.contains("tags") ? getTags(model.getFullyQualifiedName()) : null); model.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(), @@ -300,12 +302,22 @@ public class ModelRepository extends EntityRepository { public void entitySpecificUpdate() throws IOException { updateAlgorithm(original.getEntity(), updated.getEntity()); updateDashboard(original.getEntity(), updated.getEntity()); + updateMlFeatures(original.getEntity(), updated.getEntity()); + updateMlHyperParameters(original.getEntity(), updated.getEntity()); } private void updateAlgorithm(Model origModel, Model updatedModel) throws JsonProcessingException { recordChange("algorithm", origModel.getAlgorithm(), updatedModel.getAlgorithm()); } + private void updateMlFeatures(Model origModel, Model updatedModel) throws JsonProcessingException { + recordChange("mlFeatures", origModel.getMlFeatures(), updatedModel.getMlFeatures()); + } + + private void updateMlHyperParameters(Model origModel, Model updatedModel) throws JsonProcessingException { + recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters()); + } + private void updateDashboard(Model origModel, Model updatedModel) throws JsonProcessingException { // Remove existing dashboards removeDashboard(origModel); diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java index a4fbab556f6..fcef18dd41f 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/models/ModelResource.java @@ -118,7 +118,7 @@ public class ModelResource { } } - static final String FIELDS = "owner,dashboard,algorithm,followers,tags,usageSummary"; + static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary"; public static final List FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "") .split(",")); @@ -288,7 +288,7 @@ public class ModelResource { @DELETE @Path("/{id}/followers/{userId}") - @Operation(summary = "Remove a follower", tags = "model", + @Operation(summary = "Remove a follower", tags = "models", description = "Remove the user identified `userId` as a follower of the model.") public Model deleteFollower(@Context UriInfo uriInfo, @Context SecurityContext securityContext, @@ -306,7 +306,7 @@ public class ModelResource { @DELETE @Path("/{id}") - @Operation(summary = "Delete a Model", tags = "model", + @Operation(summary = "Delete a Model", tags = "models", description = "Delete a model by `id`.", responses = { @ApiResponse(responseCode = "200", description = "OK"), @@ -321,8 +321,10 @@ public class ModelResource { return new Model().withId(UUID.randomUUID()).withName(create.getName()) .withDisplayName(create.getDisplayName()) .withDescription(create.getDescription()) - .withDashboard(create.getDashboard()) //ADDED - .withAlgorithm(create.getAlgorithm()) //ADDED + .withDashboard(create.getDashboard()) + .withAlgorithm(create.getAlgorithm()) + .withMlFeatures(create.getMlFeatures()) + .withMlHyperParameters(create.getMlHyperParameters()) .withTags(create.getTags()) .withOwner(create.getOwner()) .withUpdatedBy(securityContext.getUserPrincipal().getName()) diff --git a/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json b/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json index c1d04af428a..46b4dfe1454 100644 --- a/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json +++ b/catalog-rest-service/src/main/resources/json/schema/api/data/createModel.json @@ -23,6 +23,22 @@ "description": "Algorithm used to train the model", "type": "string" }, + "mlFeatures": { + "description": "Features used to train the ML Model.", + "type": "array", + "items": { + "$ref": "../../entity/data/model.json#/definitions/mlFeature" + }, + "default" : null + }, + "mlHyperParameters": { + "description": "Hyper Parameters used to train the ML Model.", + "type": "array", + "items": { + "$ref": "../../entity/data/model.json#/definitions/mlHyperParameter" + }, + "default" : null + }, "dashboard" : { "description": "Performance Dashboard URL to track metric evolution", "$ref" : "../../type/entityReference.json" diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json b/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json index 09a06830cfb..b87d77a0955 100644 --- a/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json +++ b/catalog-rest-service/src/main/resources/json/schema/entity/data/model.json @@ -4,7 +4,158 @@ "title": "Model", "description": "This schema defines the Model entity. Models are algorithms trained on data to find patterns or make predictions.", "type": "object", - + "definitions": { + "featureType": { + "javaType": "org.openmetadata.catalog.type.MLFeatureDataType", + "description": "This enum defines the type of data stored in a ML Feature.", + "type": "string", + "enum": [ + "numerical", + "categorical" + ], + "javaEnums": [ + { + "name": "Numerical" + }, + { + "name": "Categorical" + } + ] + }, + "featureSourceDataType": { + "javaType": "org.openmetadata.catalog.type.FeatureSourceDataType", + "description": "This enum defines the type of data of a ML Feature source.", + "type": "string", + "enum": [ + "integer", + "number", + "string", + "array", + "date", + "timestamp", + "object", + "boolean" + ] + }, + "featureName": { + "description": "Local name (not fully qualified name) of the ML Feature.", + "type": "string", + "minLength": 1, + "maxLength": 64, + "pattern": "^[^.]*$" + }, + "featureSourceName": { + "description": "Local name (not fully qualified name) of a ML Feature source", + "type": "string", + "minLength": 1, + "maxLength": 64, + "pattern": "^[^.]*$" + }, + "fullyQualifiedFeatureSourceName": { + "description": "Fully qualified name of the ML Feature Source that includes `serviceName.[databaseName].tableName/fileName/apiName.columnName[.nestedColumnName]`.", + "type": "string", + "minLength": 1, + "maxLength": 256 + }, + "fullyQualifiedFeatureName": { + "description": "Fully qualified name of the ML Feature that includes `modelName.featureName`.", + "type": "string", + "minLength": 1, + "maxLength": 256 + }, + "featureSource": { + "type": "object", + "javaType": "org.openmetadata.catalog.type.MLFeatureSource", + "description": "This schema defines the sources of a ML Feature.", + "additionalProperties": false, + "properties": { + "name": { + "$ref": "#/definitions/featureSourceName" + }, + "dataType": { + "description": "Data type of the source (int, date etc.).", + "$ref": "#/definitions/featureSourceDataType" + }, + "description": { + "description": "Description of the feature source.", + "type": "string" + }, + "fullyQualifiedName": { + "$ref": "#/definitions/fullyQualifiedFeatureSourceName" + }, + "tags": { + "description": "Tags associated with the feature source.", + "type": "array", + "items": { + "$ref": "../../type/tagLabel.json" + }, + "default": null + } + } + }, + "mlFeature": { + "type": "object", + "javaType": "org.openmetadata.catalog.type.MLFeature", + "description": "This schema defines the type for a ML Feature used in a Model.", + "additionalProperties": false, + "properties": { + "name": { + "$ref": "#/definitions/featureName" + }, + "dataType": { + "description": "Data type of the column (numerical vs. categorical).", + "$ref": "#/definitions/featureType" + }, + "description": { + "description": "Description of the ML Feature.", + "type": "string" + }, + "fullyQualifiedName": { + "$ref": "#/definitions/fullyQualifiedFeatureName" + }, + "featureSources": { + "description": "Columns used to create the ML Feature", + "type": "array", + "items": { + "$ref": "#/definitions/featureSource" + }, + "default": null + }, + "featureAlgorithm": { + "description": "Description of the algorithm used to compute the feature, e.g., PCA, bucketing...", + "type": "string" + }, + "tags": { + "description": "Tags associated with the feature.", + "type": "array", + "items": { + "$ref": "../../type/tagLabel.json" + }, + "default": null + } + } + }, + "mlHyperParameter": { + "type": "object", + "javaType": "org.openmetadata.catalog.type.MLHyperParameter", + "description": "This schema defines the type for a ML HyperParameter used in a Model.", + "additionalProperties": false, + "properties": { + "name": { + "description": "Hyper parameter name", + "type": "string" + }, + "value": { + "description": "Hyper parameter value", + "type": "string" + }, + "description": { + "description": "Description of the Hyper Parameter.", + "type": "string" + } + } + } + }, "properties" : { "id": { "description": "Unique identifier of a model instance.", @@ -26,18 +177,6 @@ "description": "Display Name that identifies this model.", "type": "string" }, - "version" : { - "description": "Metadata version of the entity.", - "$ref": "../../type/entityHistory.json#/definitions/entityVersion" - }, - "updatedAt" : { - "description": "Last update time corresponding to the new version of the entity.", - "$ref": "../../type/basic.json#/definitions/dateTime" - }, - "updatedBy" : { - "description": "User who made the update.", - "type": "string" - }, "description": { "description": "Description of the model, what it is, and how to use it.", "type": "string" @@ -46,6 +185,22 @@ "description": "Algorithm used to train the model.", "type": "string" }, + "mlFeatures": { + "description": "Features used to train the ML Model.", + "type": "array", + "items": { + "$ref": "#/definitions/mlFeature" + }, + "default": null + }, + "mlHyperParameters": { + "description": "Hyper Parameters used to train the ML Model.", + "type": "array", + "items": { + "$ref": "#/definitions/mlHyperParameter" + }, + "default": null + }, "dashboard" : { "description": "Performance Dashboard URL to track metric evolution.", "$ref" : "../../type/entityReference.json" @@ -75,6 +230,18 @@ "$ref": "../../type/usageDetails.json", "default": null }, + "version" : { + "description": "Metadata version of the entity.", + "$ref": "../../type/entityHistory.json#/definitions/entityVersion" + }, + "updatedAt" : { + "description": "Last update time corresponding to the new version of the entity.", + "$ref": "../../type/basic.json#/definitions/dateTime" + }, + "updatedBy" : { + "description": "User who made the update.", + "type": "string" + }, "changeDescription": { "description" : "Change that lead to this version of the entity.", "$ref": "../../type/entityHistory.json#/definitions/changeDescription" diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java index 52120628182..9d1ba6ef586 100644 --- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java +++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/models/ModelResourceTest.java @@ -27,6 +27,11 @@ import org.openmetadata.catalog.api.services.CreateDashboardService; import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType; import org.openmetadata.catalog.entity.data.Dashboard; import org.openmetadata.catalog.entity.data.Model; +import org.openmetadata.catalog.type.FeatureSourceDataType; +import org.openmetadata.catalog.type.MLFeature; +import org.openmetadata.catalog.type.MLFeatureDataType; +import org.openmetadata.catalog.type.MLFeatureSource; +import org.openmetadata.catalog.type.MLHyperParameter; import org.openmetadata.catalog.entity.services.DashboardService; import org.openmetadata.catalog.entity.teams.Team; import org.openmetadata.catalog.entity.teams.User; @@ -47,6 +52,8 @@ import org.slf4j.LoggerFactory; import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.Response.Status; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.UUID; @@ -79,6 +86,39 @@ public class ModelResourceTest extends CatalogApplicationTest { public static EntityReference SUPERSET_REFERENCE; public static Dashboard DASHBOARD; public static EntityReference DASHBOARD_REFERENCE; + public static List ML_FEATURES = Arrays.asList( + new MLFeature() + .withName("age") + .withDataType(MLFeatureDataType.Numerical) + .withFeatureSources( + Collections.singletonList( + new MLFeatureSource() + .withName("age") + .withDataType(FeatureSourceDataType.INTEGER) + .withFullyQualifiedName("my_service.my_db.my_table.age") + ) + ), + new MLFeature() + .withName("persona") + .withDataType(MLFeatureDataType.Categorical) + .withFeatureSources( + Arrays.asList( + new MLFeatureSource() + .withName("age") + .withDataType(FeatureSourceDataType.INTEGER) + .withFullyQualifiedName("my_service.my_db.my_table.age"), + new MLFeatureSource() + .withName("education") + .withDataType(FeatureSourceDataType.STRING) + .withFullyQualifiedName("my_api.education") + ) + ) + .withFeatureAlgorithm("PCA") + ); + public static List ML_HYPERPARAMS = Arrays.asList( + new MLHyperParameter().withName("regularisation").withValue("0.5"), + new MLHyperParameter().withName("random").withValue("hello") + ); @BeforeAll @@ -465,6 +505,15 @@ public class ModelResourceTest extends CatalogApplicationTest { assertNotNull(model.getAlgorithm()); // Provided as default field assertNull(model.getDashboard()); + // .../models?fields=mlFeatures,mlHyperParameters + fields = "mlFeatures,mlHyperParameters"; + model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) : + getModel(model.getId(), fields, adminAuthHeaders()); + assertNotNull(model.getAlgorithm()); // Provided as default field + assertNotNull(model.getMlFeatures()); + assertNotNull(model.getMlHyperParameters()); + assertNull(model.getDashboard()); + // .../models?fields=owner,algorithm fields = "owner,algorithm"; model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) : @@ -577,11 +626,13 @@ public class ModelResourceTest extends CatalogApplicationTest { } public static CreateModel create(TestInfo test) { - return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM); + return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM) + .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS); } public static CreateModel create(TestInfo test, int index) { - return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM); + return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM) + .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS); } }