[Issue-801] - MLFeature & MLHyperParameters (#1071)

* Update model schema

* Update create model schema

* Disable additional params

* Update Model Resource & Repository

* Add tests for MLFeatures and MLHyperParameters
This commit is contained in:
Pere Miquel Brull 2021-11-08 07:44:15 +01:00 committed by GitHub
parent ff7fe1dd41
commit 7771e42e1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 270 additions and 22 deletions

View File

@ -44,9 +44,9 @@ import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityN
public class ModelRepository extends EntityRepository<Model> {
private static final Logger LOG = LoggerFactory.getLogger(ModelRepository.class);
private static final Fields MODEL_UPDATE_FIELDS = new Fields(ModelResource.FIELD_LIST,
"owner,dashboard,tags");
"owner,dashboard,mlHyperParameters,mlFeatures,tags");
private static final Fields MODEL_PATCH_FIELDS = new Fields(ModelResource.FIELD_LIST,
"owner,dashboard,tags");
"owner,dashboard,mlHyperParameters,mlFeatures,tags");
private final CollectionDAO dao;
public ModelRepository(CollectionDAO dao) {
@ -80,6 +80,8 @@ public class ModelRepository extends EntityRepository<Model> {
model.setDisplayName(model.getDisplayName());
model.setOwner(fields.contains("owner") ? getOwner(model) : null);
model.setDashboard(fields.contains("dashboard") ? getDashboard(model) : null);
model.setMlFeatures(fields.contains("mlFeatures") ? model.getMlFeatures(): null);
model.setMlHyperParameters(fields.contains("mlHyperParameters") ? model.getMlHyperParameters(): null);
model.setFollowers(fields.contains("followers") ? getFollowers(model) : null);
model.setTags(fields.contains("tags") ? getTags(model.getFullyQualifiedName()) : null);
model.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(),
@ -300,12 +302,22 @@ public class ModelRepository extends EntityRepository<Model> {
public void entitySpecificUpdate() throws IOException {
updateAlgorithm(original.getEntity(), updated.getEntity());
updateDashboard(original.getEntity(), updated.getEntity());
updateMlFeatures(original.getEntity(), updated.getEntity());
updateMlHyperParameters(original.getEntity(), updated.getEntity());
}
private void updateAlgorithm(Model origModel, Model updatedModel) throws JsonProcessingException {
recordChange("algorithm", origModel.getAlgorithm(), updatedModel.getAlgorithm());
}
private void updateMlFeatures(Model origModel, Model updatedModel) throws JsonProcessingException {
recordChange("mlFeatures", origModel.getMlFeatures(), updatedModel.getMlFeatures());
}
private void updateMlHyperParameters(Model origModel, Model updatedModel) throws JsonProcessingException {
recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters());
}
private void updateDashboard(Model origModel, Model updatedModel) throws JsonProcessingException {
// Remove existing dashboards
removeDashboard(origModel);

View File

@ -118,7 +118,7 @@ public class ModelResource {
}
}
static final String FIELDS = "owner,dashboard,algorithm,followers,tags,usageSummary";
static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary";
public static final List<String> FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "")
.split(","));
@ -288,7 +288,7 @@ public class ModelResource {
@DELETE
@Path("/{id}/followers/{userId}")
@Operation(summary = "Remove a follower", tags = "model",
@Operation(summary = "Remove a follower", tags = "models",
description = "Remove the user identified `userId` as a follower of the model.")
public Model deleteFollower(@Context UriInfo uriInfo,
@Context SecurityContext securityContext,
@ -306,7 +306,7 @@ public class ModelResource {
@DELETE
@Path("/{id}")
@Operation(summary = "Delete a Model", tags = "model",
@Operation(summary = "Delete a Model", tags = "models",
description = "Delete a model by `id`.",
responses = {
@ApiResponse(responseCode = "200", description = "OK"),
@ -321,8 +321,10 @@ public class ModelResource {
return new Model().withId(UUID.randomUUID()).withName(create.getName())
.withDisplayName(create.getDisplayName())
.withDescription(create.getDescription())
.withDashboard(create.getDashboard()) //ADDED
.withAlgorithm(create.getAlgorithm()) //ADDED
.withDashboard(create.getDashboard())
.withAlgorithm(create.getAlgorithm())
.withMlFeatures(create.getMlFeatures())
.withMlHyperParameters(create.getMlHyperParameters())
.withTags(create.getTags())
.withOwner(create.getOwner())
.withUpdatedBy(securityContext.getUserPrincipal().getName())

View File

@ -23,6 +23,22 @@
"description": "Algorithm used to train the model",
"type": "string"
},
"mlFeatures": {
"description": "Features used to train the ML Model.",
"type": "array",
"items": {
"$ref": "../../entity/data/model.json#/definitions/mlFeature"
},
"default" : null
},
"mlHyperParameters": {
"description": "Hyper Parameters used to train the ML Model.",
"type": "array",
"items": {
"$ref": "../../entity/data/model.json#/definitions/mlHyperParameter"
},
"default" : null
},
"dashboard" : {
"description": "Performance Dashboard URL to track metric evolution",
"$ref" : "../../type/entityReference.json"

View File

@ -4,7 +4,158 @@
"title": "Model",
"description": "This schema defines the Model entity. Models are algorithms trained on data to find patterns or make predictions.",
"type": "object",
"definitions": {
"featureType": {
"javaType": "org.openmetadata.catalog.type.MLFeatureDataType",
"description": "This enum defines the type of data stored in a ML Feature.",
"type": "string",
"enum": [
"numerical",
"categorical"
],
"javaEnums": [
{
"name": "Numerical"
},
{
"name": "Categorical"
}
]
},
"featureSourceDataType": {
"javaType": "org.openmetadata.catalog.type.FeatureSourceDataType",
"description": "This enum defines the type of data of a ML Feature source.",
"type": "string",
"enum": [
"integer",
"number",
"string",
"array",
"date",
"timestamp",
"object",
"boolean"
]
},
"featureName": {
"description": "Local name (not fully qualified name) of the ML Feature.",
"type": "string",
"minLength": 1,
"maxLength": 64,
"pattern": "^[^.]*$"
},
"featureSourceName": {
"description": "Local name (not fully qualified name) of a ML Feature source",
"type": "string",
"minLength": 1,
"maxLength": 64,
"pattern": "^[^.]*$"
},
"fullyQualifiedFeatureSourceName": {
"description": "Fully qualified name of the ML Feature Source that includes `serviceName.[databaseName].tableName/fileName/apiName.columnName[.nestedColumnName]`.",
"type": "string",
"minLength": 1,
"maxLength": 256
},
"fullyQualifiedFeatureName": {
"description": "Fully qualified name of the ML Feature that includes `modelName.featureName`.",
"type": "string",
"minLength": 1,
"maxLength": 256
},
"featureSource": {
"type": "object",
"javaType": "org.openmetadata.catalog.type.MLFeatureSource",
"description": "This schema defines the sources of a ML Feature.",
"additionalProperties": false,
"properties": {
"name": {
"$ref": "#/definitions/featureSourceName"
},
"dataType": {
"description": "Data type of the source (int, date etc.).",
"$ref": "#/definitions/featureSourceDataType"
},
"description": {
"description": "Description of the feature source.",
"type": "string"
},
"fullyQualifiedName": {
"$ref": "#/definitions/fullyQualifiedFeatureSourceName"
},
"tags": {
"description": "Tags associated with the feature source.",
"type": "array",
"items": {
"$ref": "../../type/tagLabel.json"
},
"default": null
}
}
},
"mlFeature": {
"type": "object",
"javaType": "org.openmetadata.catalog.type.MLFeature",
"description": "This schema defines the type for a ML Feature used in a Model.",
"additionalProperties": false,
"properties": {
"name": {
"$ref": "#/definitions/featureName"
},
"dataType": {
"description": "Data type of the column (numerical vs. categorical).",
"$ref": "#/definitions/featureType"
},
"description": {
"description": "Description of the ML Feature.",
"type": "string"
},
"fullyQualifiedName": {
"$ref": "#/definitions/fullyQualifiedFeatureName"
},
"featureSources": {
"description": "Columns used to create the ML Feature",
"type": "array",
"items": {
"$ref": "#/definitions/featureSource"
},
"default": null
},
"featureAlgorithm": {
"description": "Description of the algorithm used to compute the feature, e.g., PCA, bucketing...",
"type": "string"
},
"tags": {
"description": "Tags associated with the feature.",
"type": "array",
"items": {
"$ref": "../../type/tagLabel.json"
},
"default": null
}
}
},
"mlHyperParameter": {
"type": "object",
"javaType": "org.openmetadata.catalog.type.MLHyperParameter",
"description": "This schema defines the type for a ML HyperParameter used in a Model.",
"additionalProperties": false,
"properties": {
"name": {
"description": "Hyper parameter name",
"type": "string"
},
"value": {
"description": "Hyper parameter value",
"type": "string"
},
"description": {
"description": "Description of the Hyper Parameter.",
"type": "string"
}
}
}
},
"properties" : {
"id": {
"description": "Unique identifier of a model instance.",
@ -26,18 +177,6 @@
"description": "Display Name that identifies this model.",
"type": "string"
},
"version" : {
"description": "Metadata version of the entity.",
"$ref": "../../type/entityHistory.json#/definitions/entityVersion"
},
"updatedAt" : {
"description": "Last update time corresponding to the new version of the entity.",
"$ref": "../../type/basic.json#/definitions/dateTime"
},
"updatedBy" : {
"description": "User who made the update.",
"type": "string"
},
"description": {
"description": "Description of the model, what it is, and how to use it.",
"type": "string"
@ -46,6 +185,22 @@
"description": "Algorithm used to train the model.",
"type": "string"
},
"mlFeatures": {
"description": "Features used to train the ML Model.",
"type": "array",
"items": {
"$ref": "#/definitions/mlFeature"
},
"default": null
},
"mlHyperParameters": {
"description": "Hyper Parameters used to train the ML Model.",
"type": "array",
"items": {
"$ref": "#/definitions/mlHyperParameter"
},
"default": null
},
"dashboard" : {
"description": "Performance Dashboard URL to track metric evolution.",
"$ref" : "../../type/entityReference.json"
@ -75,6 +230,18 @@
"$ref": "../../type/usageDetails.json",
"default": null
},
"version" : {
"description": "Metadata version of the entity.",
"$ref": "../../type/entityHistory.json#/definitions/entityVersion"
},
"updatedAt" : {
"description": "Last update time corresponding to the new version of the entity.",
"$ref": "../../type/basic.json#/definitions/dateTime"
},
"updatedBy" : {
"description": "User who made the update.",
"type": "string"
},
"changeDescription": {
"description" : "Change that lead to this version of the entity.",
"$ref": "../../type/entityHistory.json#/definitions/changeDescription"

View File

@ -27,6 +27,11 @@ import org.openmetadata.catalog.api.services.CreateDashboardService;
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
import org.openmetadata.catalog.entity.data.Dashboard;
import org.openmetadata.catalog.entity.data.Model;
import org.openmetadata.catalog.type.FeatureSourceDataType;
import org.openmetadata.catalog.type.MLFeature;
import org.openmetadata.catalog.type.MLFeatureDataType;
import org.openmetadata.catalog.type.MLFeatureSource;
import org.openmetadata.catalog.type.MLHyperParameter;
import org.openmetadata.catalog.entity.services.DashboardService;
import org.openmetadata.catalog.entity.teams.Team;
import org.openmetadata.catalog.entity.teams.User;
@ -47,6 +52,8 @@ import org.slf4j.LoggerFactory;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response.Status;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@ -79,6 +86,39 @@ public class ModelResourceTest extends CatalogApplicationTest {
public static EntityReference SUPERSET_REFERENCE;
public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE;
public static List<MLFeature> ML_FEATURES = Arrays.asList(
new MLFeature()
.withName("age")
.withDataType(MLFeatureDataType.Numerical)
.withFeatureSources(
Collections.singletonList(
new MLFeatureSource()
.withName("age")
.withDataType(FeatureSourceDataType.INTEGER)
.withFullyQualifiedName("my_service.my_db.my_table.age")
)
),
new MLFeature()
.withName("persona")
.withDataType(MLFeatureDataType.Categorical)
.withFeatureSources(
Arrays.asList(
new MLFeatureSource()
.withName("age")
.withDataType(FeatureSourceDataType.INTEGER)
.withFullyQualifiedName("my_service.my_db.my_table.age"),
new MLFeatureSource()
.withName("education")
.withDataType(FeatureSourceDataType.STRING)
.withFullyQualifiedName("my_api.education")
)
)
.withFeatureAlgorithm("PCA")
);
public static List<MLHyperParameter> ML_HYPERPARAMS = Arrays.asList(
new MLHyperParameter().withName("regularisation").withValue("0.5"),
new MLHyperParameter().withName("random").withValue("hello")
);
@BeforeAll
@ -465,6 +505,15 @@ public class ModelResourceTest extends CatalogApplicationTest {
assertNotNull(model.getAlgorithm()); // Provided as default field
assertNull(model.getDashboard());
// .../models?fields=mlFeatures,mlHyperParameters
fields = "mlFeatures,mlHyperParameters";
model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) :
getModel(model.getId(), fields, adminAuthHeaders());
assertNotNull(model.getAlgorithm()); // Provided as default field
assertNotNull(model.getMlFeatures());
assertNotNull(model.getMlHyperParameters());
assertNull(model.getDashboard());
// .../models?fields=owner,algorithm
fields = "owner,algorithm";
model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) :
@ -577,11 +626,13 @@ public class ModelResourceTest extends CatalogApplicationTest {
}
public static CreateModel create(TestInfo test) {
return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM);
return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM)
.withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
}
public static CreateModel create(TestInfo test, int index) {
return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM);
return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM)
.withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
}
}