mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-11 02:26:49 +00:00
[Issue-801] - MLFeature & MLHyperParameters (#1071)
* Update model schema * Update create model schema * Disable additional params * Update Model Resource & Repository * Add tests for MLFeatures and MLHyperParameters
This commit is contained in:
parent
ff7fe1dd41
commit
7771e42e1b
@ -44,9 +44,9 @@ import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityN
|
||||
public class ModelRepository extends EntityRepository<Model> {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(ModelRepository.class);
|
||||
private static final Fields MODEL_UPDATE_FIELDS = new Fields(ModelResource.FIELD_LIST,
|
||||
"owner,dashboard,tags");
|
||||
"owner,dashboard,mlHyperParameters,mlFeatures,tags");
|
||||
private static final Fields MODEL_PATCH_FIELDS = new Fields(ModelResource.FIELD_LIST,
|
||||
"owner,dashboard,tags");
|
||||
"owner,dashboard,mlHyperParameters,mlFeatures,tags");
|
||||
private final CollectionDAO dao;
|
||||
|
||||
public ModelRepository(CollectionDAO dao) {
|
||||
@ -80,6 +80,8 @@ public class ModelRepository extends EntityRepository<Model> {
|
||||
model.setDisplayName(model.getDisplayName());
|
||||
model.setOwner(fields.contains("owner") ? getOwner(model) : null);
|
||||
model.setDashboard(fields.contains("dashboard") ? getDashboard(model) : null);
|
||||
model.setMlFeatures(fields.contains("mlFeatures") ? model.getMlFeatures(): null);
|
||||
model.setMlHyperParameters(fields.contains("mlHyperParameters") ? model.getMlHyperParameters(): null);
|
||||
model.setFollowers(fields.contains("followers") ? getFollowers(model) : null);
|
||||
model.setTags(fields.contains("tags") ? getTags(model.getFullyQualifiedName()) : null);
|
||||
model.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(),
|
||||
@ -300,12 +302,22 @@ public class ModelRepository extends EntityRepository<Model> {
|
||||
public void entitySpecificUpdate() throws IOException {
|
||||
updateAlgorithm(original.getEntity(), updated.getEntity());
|
||||
updateDashboard(original.getEntity(), updated.getEntity());
|
||||
updateMlFeatures(original.getEntity(), updated.getEntity());
|
||||
updateMlHyperParameters(original.getEntity(), updated.getEntity());
|
||||
}
|
||||
|
||||
private void updateAlgorithm(Model origModel, Model updatedModel) throws JsonProcessingException {
|
||||
recordChange("algorithm", origModel.getAlgorithm(), updatedModel.getAlgorithm());
|
||||
}
|
||||
|
||||
private void updateMlFeatures(Model origModel, Model updatedModel) throws JsonProcessingException {
|
||||
recordChange("mlFeatures", origModel.getMlFeatures(), updatedModel.getMlFeatures());
|
||||
}
|
||||
|
||||
private void updateMlHyperParameters(Model origModel, Model updatedModel) throws JsonProcessingException {
|
||||
recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters());
|
||||
}
|
||||
|
||||
private void updateDashboard(Model origModel, Model updatedModel) throws JsonProcessingException {
|
||||
// Remove existing dashboards
|
||||
removeDashboard(origModel);
|
||||
|
@ -118,7 +118,7 @@ public class ModelResource {
|
||||
}
|
||||
}
|
||||
|
||||
static final String FIELDS = "owner,dashboard,algorithm,followers,tags,usageSummary";
|
||||
static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary";
|
||||
public static final List<String> FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "")
|
||||
.split(","));
|
||||
|
||||
@ -288,7 +288,7 @@ public class ModelResource {
|
||||
|
||||
@DELETE
|
||||
@Path("/{id}/followers/{userId}")
|
||||
@Operation(summary = "Remove a follower", tags = "model",
|
||||
@Operation(summary = "Remove a follower", tags = "models",
|
||||
description = "Remove the user identified `userId` as a follower of the model.")
|
||||
public Model deleteFollower(@Context UriInfo uriInfo,
|
||||
@Context SecurityContext securityContext,
|
||||
@ -306,7 +306,7 @@ public class ModelResource {
|
||||
|
||||
@DELETE
|
||||
@Path("/{id}")
|
||||
@Operation(summary = "Delete a Model", tags = "model",
|
||||
@Operation(summary = "Delete a Model", tags = "models",
|
||||
description = "Delete a model by `id`.",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "OK"),
|
||||
@ -321,8 +321,10 @@ public class ModelResource {
|
||||
return new Model().withId(UUID.randomUUID()).withName(create.getName())
|
||||
.withDisplayName(create.getDisplayName())
|
||||
.withDescription(create.getDescription())
|
||||
.withDashboard(create.getDashboard()) //ADDED
|
||||
.withAlgorithm(create.getAlgorithm()) //ADDED
|
||||
.withDashboard(create.getDashboard())
|
||||
.withAlgorithm(create.getAlgorithm())
|
||||
.withMlFeatures(create.getMlFeatures())
|
||||
.withMlHyperParameters(create.getMlHyperParameters())
|
||||
.withTags(create.getTags())
|
||||
.withOwner(create.getOwner())
|
||||
.withUpdatedBy(securityContext.getUserPrincipal().getName())
|
||||
|
@ -23,6 +23,22 @@
|
||||
"description": "Algorithm used to train the model",
|
||||
"type": "string"
|
||||
},
|
||||
"mlFeatures": {
|
||||
"description": "Features used to train the ML Model.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "../../entity/data/model.json#/definitions/mlFeature"
|
||||
},
|
||||
"default" : null
|
||||
},
|
||||
"mlHyperParameters": {
|
||||
"description": "Hyper Parameters used to train the ML Model.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "../../entity/data/model.json#/definitions/mlHyperParameter"
|
||||
},
|
||||
"default" : null
|
||||
},
|
||||
"dashboard" : {
|
||||
"description": "Performance Dashboard URL to track metric evolution",
|
||||
"$ref" : "../../type/entityReference.json"
|
||||
|
@ -4,7 +4,158 @@
|
||||
"title": "Model",
|
||||
"description": "This schema defines the Model entity. Models are algorithms trained on data to find patterns or make predictions.",
|
||||
"type": "object",
|
||||
|
||||
"definitions": {
|
||||
"featureType": {
|
||||
"javaType": "org.openmetadata.catalog.type.MLFeatureDataType",
|
||||
"description": "This enum defines the type of data stored in a ML Feature.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"numerical",
|
||||
"categorical"
|
||||
],
|
||||
"javaEnums": [
|
||||
{
|
||||
"name": "Numerical"
|
||||
},
|
||||
{
|
||||
"name": "Categorical"
|
||||
}
|
||||
]
|
||||
},
|
||||
"featureSourceDataType": {
|
||||
"javaType": "org.openmetadata.catalog.type.FeatureSourceDataType",
|
||||
"description": "This enum defines the type of data of a ML Feature source.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"integer",
|
||||
"number",
|
||||
"string",
|
||||
"array",
|
||||
"date",
|
||||
"timestamp",
|
||||
"object",
|
||||
"boolean"
|
||||
]
|
||||
},
|
||||
"featureName": {
|
||||
"description": "Local name (not fully qualified name) of the ML Feature.",
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 64,
|
||||
"pattern": "^[^.]*$"
|
||||
},
|
||||
"featureSourceName": {
|
||||
"description": "Local name (not fully qualified name) of a ML Feature source",
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 64,
|
||||
"pattern": "^[^.]*$"
|
||||
},
|
||||
"fullyQualifiedFeatureSourceName": {
|
||||
"description": "Fully qualified name of the ML Feature Source that includes `serviceName.[databaseName].tableName/fileName/apiName.columnName[.nestedColumnName]`.",
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 256
|
||||
},
|
||||
"fullyQualifiedFeatureName": {
|
||||
"description": "Fully qualified name of the ML Feature that includes `modelName.featureName`.",
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 256
|
||||
},
|
||||
"featureSource": {
|
||||
"type": "object",
|
||||
"javaType": "org.openmetadata.catalog.type.MLFeatureSource",
|
||||
"description": "This schema defines the sources of a ML Feature.",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"name": {
|
||||
"$ref": "#/definitions/featureSourceName"
|
||||
},
|
||||
"dataType": {
|
||||
"description": "Data type of the source (int, date etc.).",
|
||||
"$ref": "#/definitions/featureSourceDataType"
|
||||
},
|
||||
"description": {
|
||||
"description": "Description of the feature source.",
|
||||
"type": "string"
|
||||
},
|
||||
"fullyQualifiedName": {
|
||||
"$ref": "#/definitions/fullyQualifiedFeatureSourceName"
|
||||
},
|
||||
"tags": {
|
||||
"description": "Tags associated with the feature source.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "../../type/tagLabel.json"
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"mlFeature": {
|
||||
"type": "object",
|
||||
"javaType": "org.openmetadata.catalog.type.MLFeature",
|
||||
"description": "This schema defines the type for a ML Feature used in a Model.",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"name": {
|
||||
"$ref": "#/definitions/featureName"
|
||||
},
|
||||
"dataType": {
|
||||
"description": "Data type of the column (numerical vs. categorical).",
|
||||
"$ref": "#/definitions/featureType"
|
||||
},
|
||||
"description": {
|
||||
"description": "Description of the ML Feature.",
|
||||
"type": "string"
|
||||
},
|
||||
"fullyQualifiedName": {
|
||||
"$ref": "#/definitions/fullyQualifiedFeatureName"
|
||||
},
|
||||
"featureSources": {
|
||||
"description": "Columns used to create the ML Feature",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/featureSource"
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
"featureAlgorithm": {
|
||||
"description": "Description of the algorithm used to compute the feature, e.g., PCA, bucketing...",
|
||||
"type": "string"
|
||||
},
|
||||
"tags": {
|
||||
"description": "Tags associated with the feature.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "../../type/tagLabel.json"
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"mlHyperParameter": {
|
||||
"type": "object",
|
||||
"javaType": "org.openmetadata.catalog.type.MLHyperParameter",
|
||||
"description": "This schema defines the type for a ML HyperParameter used in a Model.",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"name": {
|
||||
"description": "Hyper parameter name",
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"description": "Hyper parameter value",
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"description": "Description of the Hyper Parameter.",
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"properties" : {
|
||||
"id": {
|
||||
"description": "Unique identifier of a model instance.",
|
||||
@ -26,18 +177,6 @@
|
||||
"description": "Display Name that identifies this model.",
|
||||
"type": "string"
|
||||
},
|
||||
"version" : {
|
||||
"description": "Metadata version of the entity.",
|
||||
"$ref": "../../type/entityHistory.json#/definitions/entityVersion"
|
||||
},
|
||||
"updatedAt" : {
|
||||
"description": "Last update time corresponding to the new version of the entity.",
|
||||
"$ref": "../../type/basic.json#/definitions/dateTime"
|
||||
},
|
||||
"updatedBy" : {
|
||||
"description": "User who made the update.",
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"description": "Description of the model, what it is, and how to use it.",
|
||||
"type": "string"
|
||||
@ -46,6 +185,22 @@
|
||||
"description": "Algorithm used to train the model.",
|
||||
"type": "string"
|
||||
},
|
||||
"mlFeatures": {
|
||||
"description": "Features used to train the ML Model.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/mlFeature"
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
"mlHyperParameters": {
|
||||
"description": "Hyper Parameters used to train the ML Model.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/mlHyperParameter"
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
"dashboard" : {
|
||||
"description": "Performance Dashboard URL to track metric evolution.",
|
||||
"$ref" : "../../type/entityReference.json"
|
||||
@ -75,6 +230,18 @@
|
||||
"$ref": "../../type/usageDetails.json",
|
||||
"default": null
|
||||
},
|
||||
"version" : {
|
||||
"description": "Metadata version of the entity.",
|
||||
"$ref": "../../type/entityHistory.json#/definitions/entityVersion"
|
||||
},
|
||||
"updatedAt" : {
|
||||
"description": "Last update time corresponding to the new version of the entity.",
|
||||
"$ref": "../../type/basic.json#/definitions/dateTime"
|
||||
},
|
||||
"updatedBy" : {
|
||||
"description": "User who made the update.",
|
||||
"type": "string"
|
||||
},
|
||||
"changeDescription": {
|
||||
"description" : "Change that lead to this version of the entity.",
|
||||
"$ref": "../../type/entityHistory.json#/definitions/changeDescription"
|
||||
|
@ -27,6 +27,11 @@ import org.openmetadata.catalog.api.services.CreateDashboardService;
|
||||
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
|
||||
import org.openmetadata.catalog.entity.data.Dashboard;
|
||||
import org.openmetadata.catalog.entity.data.Model;
|
||||
import org.openmetadata.catalog.type.FeatureSourceDataType;
|
||||
import org.openmetadata.catalog.type.MLFeature;
|
||||
import org.openmetadata.catalog.type.MLFeatureDataType;
|
||||
import org.openmetadata.catalog.type.MLFeatureSource;
|
||||
import org.openmetadata.catalog.type.MLHyperParameter;
|
||||
import org.openmetadata.catalog.entity.services.DashboardService;
|
||||
import org.openmetadata.catalog.entity.teams.Team;
|
||||
import org.openmetadata.catalog.entity.teams.User;
|
||||
@ -47,6 +52,8 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.ws.rs.client.WebTarget;
|
||||
import javax.ws.rs.core.Response.Status;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
@ -79,6 +86,39 @@ public class ModelResourceTest extends CatalogApplicationTest {
|
||||
public static EntityReference SUPERSET_REFERENCE;
|
||||
public static Dashboard DASHBOARD;
|
||||
public static EntityReference DASHBOARD_REFERENCE;
|
||||
public static List<MLFeature> ML_FEATURES = Arrays.asList(
|
||||
new MLFeature()
|
||||
.withName("age")
|
||||
.withDataType(MLFeatureDataType.Numerical)
|
||||
.withFeatureSources(
|
||||
Collections.singletonList(
|
||||
new MLFeatureSource()
|
||||
.withName("age")
|
||||
.withDataType(FeatureSourceDataType.INTEGER)
|
||||
.withFullyQualifiedName("my_service.my_db.my_table.age")
|
||||
)
|
||||
),
|
||||
new MLFeature()
|
||||
.withName("persona")
|
||||
.withDataType(MLFeatureDataType.Categorical)
|
||||
.withFeatureSources(
|
||||
Arrays.asList(
|
||||
new MLFeatureSource()
|
||||
.withName("age")
|
||||
.withDataType(FeatureSourceDataType.INTEGER)
|
||||
.withFullyQualifiedName("my_service.my_db.my_table.age"),
|
||||
new MLFeatureSource()
|
||||
.withName("education")
|
||||
.withDataType(FeatureSourceDataType.STRING)
|
||||
.withFullyQualifiedName("my_api.education")
|
||||
)
|
||||
)
|
||||
.withFeatureAlgorithm("PCA")
|
||||
);
|
||||
public static List<MLHyperParameter> ML_HYPERPARAMS = Arrays.asList(
|
||||
new MLHyperParameter().withName("regularisation").withValue("0.5"),
|
||||
new MLHyperParameter().withName("random").withValue("hello")
|
||||
);
|
||||
|
||||
|
||||
@BeforeAll
|
||||
@ -465,6 +505,15 @@ public class ModelResourceTest extends CatalogApplicationTest {
|
||||
assertNotNull(model.getAlgorithm()); // Provided as default field
|
||||
assertNull(model.getDashboard());
|
||||
|
||||
// .../models?fields=mlFeatures,mlHyperParameters
|
||||
fields = "mlFeatures,mlHyperParameters";
|
||||
model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) :
|
||||
getModel(model.getId(), fields, adminAuthHeaders());
|
||||
assertNotNull(model.getAlgorithm()); // Provided as default field
|
||||
assertNotNull(model.getMlFeatures());
|
||||
assertNotNull(model.getMlHyperParameters());
|
||||
assertNull(model.getDashboard());
|
||||
|
||||
// .../models?fields=owner,algorithm
|
||||
fields = "owner,algorithm";
|
||||
model = byName ? getModelByName(model.getFullyQualifiedName(), fields, adminAuthHeaders()) :
|
||||
@ -577,11 +626,13 @@ public class ModelResourceTest extends CatalogApplicationTest {
|
||||
}
|
||||
|
||||
public static CreateModel create(TestInfo test) {
|
||||
return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM);
|
||||
return new CreateModel().withName(getModelName(test)).withAlgorithm(ALGORITHM)
|
||||
.withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
|
||||
}
|
||||
|
||||
public static CreateModel create(TestInfo test, int index) {
|
||||
return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM);
|
||||
return new CreateModel().withName(getModelName(test, index)).withAlgorithm(ALGORITHM)
|
||||
.withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user