From c2efcb110762efc0b69ea9da66064d34970a53e3 Mon Sep 17 00:00:00 2001
From: Pere Miquel Brull
Date: Mon, 29 Nov 2021 19:08:00 +0100
Subject: [PATCH] [issue-1284] - Add MlStore and Server to MlModel (#1439)
* Add mlStore and server to MlModel entity
* Add mlStore and server to Create MlModel entity
* Add MlStore and Server to MlModel properties
---
.../catalog/jdbi3/MlModelRepository.java | 14 ++++-
.../resources/mlmodels/MlModelResource.java | 5 +-
.../json/schema/api/data/createMlModel.json | 8 +++
.../json/schema/entity/data/mlmodel.json | 24 ++++++++
.../mlmodels/MlModelResourceTest.java | 61 +++++++++++++++++--
5 files changed, 105 insertions(+), 7 deletions(-)
diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java
index 0d8ac7156c6..9769617151c 100644
--- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java
+++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java
@@ -43,7 +43,7 @@ import java.util.UUID;
public class MlModelRepository extends EntityRepository {
private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class);
private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST,
- "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags");
+ "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,mlStore,server,tags");
private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST,
"owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags");
private final CollectionDAO dao;
@@ -77,6 +77,8 @@ public class MlModelRepository extends EntityRepository {
mlModel.setDashboard(fields.contains("dashboard") ? getDashboard(mlModel) : null);
mlModel.setMlFeatures(fields.contains("mlFeatures") ? mlModel.getMlFeatures(): null);
mlModel.setMlHyperParameters(fields.contains("mlHyperParameters") ? mlModel.getMlHyperParameters(): null);
+ mlModel.setMlStore(fields.contains("mlStore") ? mlModel.getMlStore(): null);
+ mlModel.setServer(fields.contains("server") ? mlModel.getServer(): null);
mlModel.setFollowers(fields.contains("followers") ? getFollowers(mlModel) : null);
mlModel.setTags(fields.contains("tags") ? getTags(mlModel.getFullyQualifiedName()) : null);
mlModel.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(),
@@ -338,6 +340,8 @@ public class MlModelRepository extends EntityRepository {
updateDashboard(origMlModel, updatedMlModel);
updateMlFeatures(origMlModel, updatedMlModel);
updateMlHyperParameters(origMlModel, updatedMlModel);
+ updateMlStore(origMlModel, updatedMlModel);
+ updateServer(origMlModel, updatedMlModel);
}
private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
@@ -352,6 +356,14 @@ public class MlModelRepository extends EntityRepository {
recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters());
}
+ private void updateMlStore(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
+ recordChange("mlStore", origModel.getMlStore(), updatedModel.getMlStore(), true);
+ }
+
+ private void updateServer(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
+ recordChange("server", origModel.getServer(), updatedModel.getServer());
+ }
+
private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
String modelId = updatedModel.getId().toString();
diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java
index eaa1917c4ca..6731802eb9f 100644
--- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java
+++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java
@@ -110,7 +110,8 @@ public class MlModelResource {
}
}
- static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary";
+ static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,mlStore,server," +
+ "followers,tags,usageSummary";
public static final List FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "")
.split(","));
@@ -357,6 +358,8 @@ public class MlModelResource {
.withAlgorithm(create.getAlgorithm())
.withMlFeatures(create.getMlFeatures())
.withMlHyperParameters(create.getMlHyperParameters())
+ .withMlStore(create.getMlStore())
+ .withServer(create.getServer())
.withTags(create.getTags())
.withOwner(create.getOwner())
.withUpdatedBy(securityContext.getUserPrincipal().getName())
diff --git a/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json b/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json
index 20c5f095837..7c89c712bb7 100644
--- a/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json
+++ b/catalog-rest-service/src/main/resources/json/schema/api/data/createMlModel.json
@@ -43,6 +43,14 @@
"description": "Performance Dashboard URL to track metric evolution",
"$ref" : "../../type/entityReference.json"
},
+ "mlStore": {
+ "description": "Location containing the ML Model. It can be a storage layer and/or a container repository.",
+ "$ref": "../../entity/data/mlmodel.json#/definitions/mlStore"
+ },
+ "server": {
+ "description": "Endpoint that makes the ML Model available, e.g,. a REST API serving the data or computing predictions.",
+ "$ref": "../../type/basic.json#/definitions/href"
+ },
"tags": {
"description": "Tags for this ML Model",
"type": "array",
diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json
index b6563df3f83..ff6a26ba0f4 100644
--- a/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json
+++ b/catalog-rest-service/src/main/resources/json/schema/entity/data/mlmodel.json
@@ -155,6 +155,22 @@
"type": "string"
}
}
+ },
+ "mlStore": {
+ "type": "object",
+ "javaType": "org.openmetadata.catalog.type.MlStore",
+ "description": "Location containing the ML Model. It can be a storage layer and/or a container repository.",
+ "additionalProperties": false,
+ "properties": {
+ "storage": {
+ "description": "Storage Layer containing the ML Model data.",
+ "$ref": "../../type/basic.json#/definitions/href"
+ },
+ "imageRepository": {
+ "description": "Container Repository with the ML Model image.",
+ "$ref": "../../type/basic.json#/definitions/href"
+ }
+ }
}
},
"properties" : {
@@ -206,6 +222,14 @@
"description": "Performance Dashboard URL to track metric evolution.",
"$ref" : "../../type/entityReference.json"
},
+ "mlStore": {
+ "description": "Location containing the ML Model. It can be a storage layer and/or a container repository.",
+ "$ref": "#/definitions/mlStore"
+ },
+ "server": {
+ "description": "Endpoint that makes the ML Model available, e.g,. a REST API serving the data or computing predictions.",
+ "$ref": "../../type/basic.json#/definitions/href"
+ },
"href": {
"description": "Link to the resource corresponding to this entity.",
"$ref": "../../type/basic.json#/definitions/href"
diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java
index 8bc4db626dc..c3cc42313f8 100644
--- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java
+++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java
@@ -36,6 +36,7 @@ import org.openmetadata.catalog.type.FeatureSourceDataType;
import org.openmetadata.catalog.type.FieldChange;
import org.openmetadata.catalog.type.MlFeature;
import org.openmetadata.catalog.type.MlFeatureDataType;
+import org.openmetadata.catalog.type.MlStore;
import org.openmetadata.catalog.type.MlFeatureSource;
import org.openmetadata.catalog.type.MlHyperParameter;
import org.openmetadata.catalog.entity.services.DashboardService;
@@ -52,6 +53,7 @@ import org.openmetadata.catalog.util.JsonUtils;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response.Status;
import java.io.IOException;
+import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.Collections;
@@ -83,6 +85,12 @@ public class MlModelResourceTest extends EntityResourceTest {
public static String ALGORITHM = "regression";
public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE;
+
+ public static URI SERVER = URI.create("http://localhost.com/mlModel");
+ public static MlStore ML_STORE = new MlStore()
+ .withStorage(URI.create("s3://my-bucket.com/mlModel"))
+ .withImageRepository(URI.create("https://12345.dkr.ecr.region.amazonaws.com"));
+
public static List ML_FEATURES = Arrays.asList(
new MlFeature()
.withName("age")
@@ -196,6 +204,18 @@ public class MlModelResourceTest extends EntityResourceTest {
createAndCheckEntity(create, adminAuthHeaders());
}
+ @Test
+ public void post_MlModelWitMlStore_200_ok(TestInfo test) throws IOException {
+ CreateMlModel create = create(test).withMlStore(ML_STORE);
+ createAndCheckEntity(create, adminAuthHeaders());
+ }
+
+ @Test
+ public void post_MlModelWitServer_200_ok(TestInfo test) throws IOException {
+ CreateMlModel create = create(test).withServer(SERVER);
+ createAndCheckEntity(create, adminAuthHeaders());
+ }
+
@Test
public void post_MlModel_as_non_admin_401(TestInfo test) {
CreateMlModel create = create(test);
@@ -257,6 +277,30 @@ public class MlModelResourceTest extends EntityResourceTest {
);
}
+ @Test
+ public void put_MlModelAddServer_200(TestInfo test) throws IOException {
+ CreateMlModel request = create(test);
+ MlModel model = createAndCheckEntity(request, adminAuthHeaders());
+ ChangeDescription change = getChangeDescription(model.getVersion());
+ change.getFieldsAdded().add(new FieldChange().withName("server").withNewValue(SERVER));
+
+ updateAndCheckEntity(
+ request.withServer(SERVER), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change
+ );
+ }
+
+ @Test
+ public void put_MlModelAddMlStore_200(TestInfo test) throws IOException {
+ CreateMlModel request = create(test);
+ MlModel model = createAndCheckEntity(request, adminAuthHeaders());
+ ChangeDescription change = getChangeDescription(model.getVersion());
+ change.getFieldsAdded().add(new FieldChange().withName("mlStore").withNewValue(ML_STORE));
+
+ updateAndCheckEntity(
+ request.withMlStore(ML_STORE), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change
+ );
+ }
+
@Test
public void get_nonExistentMlModel_404_notFound() {
HttpResponseException exception = assertThrows(HttpResponseException.class, () ->
@@ -469,23 +513,30 @@ public class MlModelResourceTest extends EntityResourceTest {
if (expected == actual) {
return;
}
- if (fieldName.contains("mlFeatures") && !fieldName.endsWith("tags") && !fieldName.endsWith("description")) {
+ if (fieldName.contains("mlFeatures")) {
List expectedFeatures = (List) expected;
List actualFeatures = JsonUtils.readObjects(actual.toString(), MlFeature.class);
assertEquals(expectedFeatures, actualFeatures);
- } else if (fieldName.contains("mlHyperParameters") && !fieldName.endsWith("tags")
- && !fieldName.endsWith("description")) {
+ } else if (fieldName.contains("mlHyperParameters")) {
List expectedConstraints = (List) expected;
List actualConstraints = JsonUtils.readObjects(actual.toString(), MlHyperParameter.class);
assertEquals(expectedConstraints, actualConstraints);
- } else if (fieldName.endsWith("algorithm")) {
+ } else if (fieldName.contains("algorithm")) {
String expectedAlgorithm = (String) expected;
String actualAlgorithm = actual.toString();
assertEquals(expectedAlgorithm, actualAlgorithm);
- } else if (fieldName.endsWith("dashboard")) {
+ } else if (fieldName.contains("dashboard")) {
EntityReference expectedDashboard = (EntityReference) expected;
EntityReference actualDashboard = JsonUtils.readValue(actual.toString(), EntityReference.class);
assertEquals(expectedDashboard, actualDashboard);
+ } else if (fieldName.contains("server")) {
+ URI expectedServer = (URI) expected;
+ URI actualServer = URI.create(actual.toString());
+ assertEquals(expectedServer, actualServer);
+ } else if (fieldName.contains("mlStore")) {
+ MlStore expectedMlStore = (MlStore) expected;
+ MlStore actualMlStore = JsonUtils.readValue(actual.toString(), MlStore.class);
+ assertEquals(expectedMlStore, actualMlStore);
} else {
assertCommonFieldChange(fieldName, expected, actual);
}