From c2efcb110762efc0b69ea9da66064d34970a53e3 Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Mon, 29 Nov 2021 19:08:00 +0100 Subject: [PATCH] [issue-1284] - Add MlStore and Server to MlModel (#1439) * Add mlStore and server to MlModel entity * Add mlStore and server to Create MlModel entity * Add MlStore and Server to MlModel properties --- .../catalog/jdbi3/MlModelRepository.java | 14 ++++- .../resources/mlmodels/MlModelResource.java | 5 +- .../json/schema/api/data/createMlModel.json | 8 +++ .../json/schema/entity/data/mlmodel.json | 24 ++++++++ .../mlmodels/MlModelResourceTest.java | 61 +++++++++++++++++-- 5 files changed, 105 insertions(+), 7 deletions(-) diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java index 0d8ac7156c6..9769617151c 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java @@ -43,7 +43,7 @@ import java.util.UUID; public class MlModelRepository extends EntityRepository { private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class); private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST, - "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags"); + "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,mlStore,server,tags"); private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST, "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags"); private final CollectionDAO dao; @@ -77,6 +77,8 @@ public class MlModelRepository extends EntityRepository { mlModel.setDashboard(fields.contains("dashboard") ? getDashboard(mlModel) : null); mlModel.setMlFeatures(fields.contains("mlFeatures") ? mlModel.getMlFeatures(): null); mlModel.setMlHyperParameters(fields.contains("mlHyperParameters") ? mlModel.getMlHyperParameters(): null); + mlModel.setMlStore(fields.contains("mlStore") ? mlModel.getMlStore(): null); + mlModel.setServer(fields.contains("server") ? mlModel.getServer(): null); mlModel.setFollowers(fields.contains("followers") ? getFollowers(mlModel) : null); mlModel.setTags(fields.contains("tags") ? getTags(mlModel.getFullyQualifiedName()) : null); mlModel.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(), @@ -338,6 +340,8 @@ public class MlModelRepository extends EntityRepository { updateDashboard(origMlModel, updatedMlModel); updateMlFeatures(origMlModel, updatedMlModel); updateMlHyperParameters(origMlModel, updatedMlModel); + updateMlStore(origMlModel, updatedMlModel); + updateServer(origMlModel, updatedMlModel); } private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { @@ -352,6 +356,14 @@ public class MlModelRepository extends EntityRepository { recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters()); } + private void updateMlStore(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { + recordChange("mlStore", origModel.getMlStore(), updatedModel.getMlStore(), true); + } + + private void updateServer(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { + recordChange("server", origModel.getServer(), updatedModel.getServer()); + } + private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { String modelId = updatedModel.getId().toString(); diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java index eaa1917c4ca..6731802eb9f 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java @@ -110,7 +110,8 @@ public class MlModelResource { } } - static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary"; + static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,mlStore,server," + + "followers,tags,usageSummary"; public static final List FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "") .split(",")); @@ -357,6 +358,8 @@ public class MlModelResource { .withAlgorithm(create.getAlgorithm()) .withMlFeatures(create.getMlFeatures()) .withMlHyperParameters(create.getMlHyperParameters()) + .withMlStore(create.getMlStore()) + .withServer(create.getServer()) .withTags(create.getTags()) .withOwner(create.getOwner()) .withUpdatedBy(securityContext.getUserPrincipal().getName()) diff --git a/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json b/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json index 20c5f095837..7c89c712bb7 100644 --- a/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json +++ b/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json @@ -43,6 +43,14 @@ "description": "Performance Dashboard URL to track metric evolution", "$ref" : "../../type/entityReference.json" }, + "mlStore": { + "description": "Location containing the ML Model. It can be a storage layer and/or a container repository.", + "$ref": "../../entity/data/mlmodel.json#/definitions/mlStore" + }, + "server": { + "description": "Endpoint that makes the ML Model available, e.g,. a REST API serving the data or computing predictions.", + "$ref": "../../type/basic.json#/definitions/href" + }, "tags": { "description": "Tags for this ML Model", "type": "array", diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json index b6563df3f83..ff6a26ba0f4 100644 --- a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json +++ b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json @@ -155,6 +155,22 @@ "type": "string" } } + }, + "mlStore": { + "type": "object", + "javaType": "org.openmetadata.catalog.type.MlStore", + "description": "Location containing the ML Model. It can be a storage layer and/or a container repository.", + "additionalProperties": false, + "properties": { + "storage": { + "description": "Storage Layer containing the ML Model data.", + "$ref": "../../type/basic.json#/definitions/href" + }, + "imageRepository": { + "description": "Container Repository with the ML Model image.", + "$ref": "../../type/basic.json#/definitions/href" + } + } } }, "properties" : { @@ -206,6 +222,14 @@ "description": "Performance Dashboard URL to track metric evolution.", "$ref" : "../../type/entityReference.json" }, + "mlStore": { + "description": "Location containing the ML Model. It can be a storage layer and/or a container repository.", + "$ref": "#/definitions/mlStore" + }, + "server": { + "description": "Endpoint that makes the ML Model available, e.g,. a REST API serving the data or computing predictions.", + "$ref": "../../type/basic.json#/definitions/href" + }, "href": { "description": "Link to the resource corresponding to this entity.", "$ref": "../../type/basic.json#/definitions/href" diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java index 8bc4db626dc..c3cc42313f8 100644 --- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java +++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java @@ -36,6 +36,7 @@ import org.openmetadata.catalog.type.FeatureSourceDataType; import org.openmetadata.catalog.type.FieldChange; import org.openmetadata.catalog.type.MlFeature; import org.openmetadata.catalog.type.MlFeatureDataType; +import org.openmetadata.catalog.type.MlStore; import org.openmetadata.catalog.type.MlFeatureSource; import org.openmetadata.catalog.type.MlHyperParameter; import org.openmetadata.catalog.entity.services.DashboardService; @@ -52,6 +53,7 @@ import org.openmetadata.catalog.util.JsonUtils; import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.Response.Status; import java.io.IOException; +import java.net.URI; import java.net.URISyntaxException; import java.util.Arrays; import java.util.Collections; @@ -83,6 +85,12 @@ public class MlModelResourceTest extends EntityResourceTest { public static String ALGORITHM = "regression"; public static Dashboard DASHBOARD; public static EntityReference DASHBOARD_REFERENCE; + + public static URI SERVER = URI.create("http://localhost.com/mlModel"); + public static MlStore ML_STORE = new MlStore() + .withStorage(URI.create("s3://my-bucket.com/mlModel")) + .withImageRepository(URI.create("https://12345.dkr.ecr.region.amazonaws.com")); + public static List ML_FEATURES = Arrays.asList( new MlFeature() .withName("age") @@ -196,6 +204,18 @@ public class MlModelResourceTest extends EntityResourceTest { createAndCheckEntity(create, adminAuthHeaders()); } + @Test + public void post_MlModelWitMlStore_200_ok(TestInfo test) throws IOException { + CreateMlModel create = create(test).withMlStore(ML_STORE); + createAndCheckEntity(create, adminAuthHeaders()); + } + + @Test + public void post_MlModelWitServer_200_ok(TestInfo test) throws IOException { + CreateMlModel create = create(test).withServer(SERVER); + createAndCheckEntity(create, adminAuthHeaders()); + } + @Test public void post_MlModel_as_non_admin_401(TestInfo test) { CreateMlModel create = create(test); @@ -257,6 +277,30 @@ public class MlModelResourceTest extends EntityResourceTest { ); } + @Test + public void put_MlModelAddServer_200(TestInfo test) throws IOException { + CreateMlModel request = create(test); + MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + ChangeDescription change = getChangeDescription(model.getVersion()); + change.getFieldsAdded().add(new FieldChange().withName("server").withNewValue(SERVER)); + + updateAndCheckEntity( + request.withServer(SERVER), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change + ); + } + + @Test + public void put_MlModelAddMlStore_200(TestInfo test) throws IOException { + CreateMlModel request = create(test); + MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + ChangeDescription change = getChangeDescription(model.getVersion()); + change.getFieldsAdded().add(new FieldChange().withName("mlStore").withNewValue(ML_STORE)); + + updateAndCheckEntity( + request.withMlStore(ML_STORE), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change + ); + } + @Test public void get_nonExistentMlModel_404_notFound() { HttpResponseException exception = assertThrows(HttpResponseException.class, () -> @@ -469,23 +513,30 @@ public class MlModelResourceTest extends EntityResourceTest { if (expected == actual) { return; } - if (fieldName.contains("mlFeatures") && !fieldName.endsWith("tags") && !fieldName.endsWith("description")) { + if (fieldName.contains("mlFeatures")) { List expectedFeatures = (List) expected; List actualFeatures = JsonUtils.readObjects(actual.toString(), MlFeature.class); assertEquals(expectedFeatures, actualFeatures); - } else if (fieldName.contains("mlHyperParameters") && !fieldName.endsWith("tags") - && !fieldName.endsWith("description")) { + } else if (fieldName.contains("mlHyperParameters")) { List expectedConstraints = (List) expected; List actualConstraints = JsonUtils.readObjects(actual.toString(), MlHyperParameter.class); assertEquals(expectedConstraints, actualConstraints); - } else if (fieldName.endsWith("algorithm")) { + } else if (fieldName.contains("algorithm")) { String expectedAlgorithm = (String) expected; String actualAlgorithm = actual.toString(); assertEquals(expectedAlgorithm, actualAlgorithm); - } else if (fieldName.endsWith("dashboard")) { + } else if (fieldName.contains("dashboard")) { EntityReference expectedDashboard = (EntityReference) expected; EntityReference actualDashboard = JsonUtils.readValue(actual.toString(), EntityReference.class); assertEquals(expectedDashboard, actualDashboard); + } else if (fieldName.contains("server")) { + URI expectedServer = (URI) expected; + URI actualServer = URI.create(actual.toString()); + assertEquals(expectedServer, actualServer); + } else if (fieldName.contains("mlStore")) { + MlStore expectedMlStore = (MlStore) expected; + MlStore actualMlStore = JsonUtils.readValue(actual.toString(), MlStore.class); + assertEquals(expectedMlStore, actualMlStore); } else { assertCommonFieldChange(fieldName, expected, actual); }