[issue-1284] - Add MlStore and Server to MlModel (#1439)

* Add mlStore and server to MlModel entity

* Add mlStore and server to Create MlModel entity

* Add MlStore and Server to MlModel properties
This commit is contained in:
Pere Miquel Brull 2021-11-29 19:08:00 +01:00 committed by GitHub
parent a1e9f986fb
commit c2efcb1107
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 105 additions and 7 deletions

View File

@ -43,7 +43,7 @@ import java.util.UUID;
public class MlModelRepository extends EntityRepository<MlModel> { public class MlModelRepository extends EntityRepository<MlModel> {
private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class); private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class);
private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST, private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST,
"owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags"); "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,mlStore,server,tags");
private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST, private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST,
"owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags"); "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags");
private final CollectionDAO dao; private final CollectionDAO dao;
@ -77,6 +77,8 @@ public class MlModelRepository extends EntityRepository<MlModel> {
mlModel.setDashboard(fields.contains("dashboard") ? getDashboard(mlModel) : null); mlModel.setDashboard(fields.contains("dashboard") ? getDashboard(mlModel) : null);
mlModel.setMlFeatures(fields.contains("mlFeatures") ? mlModel.getMlFeatures(): null); mlModel.setMlFeatures(fields.contains("mlFeatures") ? mlModel.getMlFeatures(): null);
mlModel.setMlHyperParameters(fields.contains("mlHyperParameters") ? mlModel.getMlHyperParameters(): null); mlModel.setMlHyperParameters(fields.contains("mlHyperParameters") ? mlModel.getMlHyperParameters(): null);
mlModel.setMlStore(fields.contains("mlStore") ? mlModel.getMlStore(): null);
mlModel.setServer(fields.contains("server") ? mlModel.getServer(): null);
mlModel.setFollowers(fields.contains("followers") ? getFollowers(mlModel) : null); mlModel.setFollowers(fields.contains("followers") ? getFollowers(mlModel) : null);
mlModel.setTags(fields.contains("tags") ? getTags(mlModel.getFullyQualifiedName()) : null); mlModel.setTags(fields.contains("tags") ? getTags(mlModel.getFullyQualifiedName()) : null);
mlModel.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(), mlModel.setUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(dao.usageDAO(),
@ -338,6 +340,8 @@ public class MlModelRepository extends EntityRepository<MlModel> {
updateDashboard(origMlModel, updatedMlModel); updateDashboard(origMlModel, updatedMlModel);
updateMlFeatures(origMlModel, updatedMlModel); updateMlFeatures(origMlModel, updatedMlModel);
updateMlHyperParameters(origMlModel, updatedMlModel); updateMlHyperParameters(origMlModel, updatedMlModel);
updateMlStore(origMlModel, updatedMlModel);
updateServer(origMlModel, updatedMlModel);
} }
private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
@ -352,6 +356,14 @@ public class MlModelRepository extends EntityRepository<MlModel> {
recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters()); recordChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters());
} }
private void updateMlStore(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
recordChange("mlStore", origModel.getMlStore(), updatedModel.getMlStore(), true);
}
private void updateServer(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
recordChange("server", origModel.getServer(), updatedModel.getServer());
}
private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
String modelId = updatedModel.getId().toString(); String modelId = updatedModel.getId().toString();

View File

@ -110,7 +110,8 @@ public class MlModelResource {
} }
} }
static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,followers,tags,usageSummary"; static final String FIELDS = "owner,dashboard,algorithm,mlFeatures,mlHyperParameters,mlStore,server," +
"followers,tags,usageSummary";
public static final List<String> FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "") public static final List<String> FIELD_LIST = Arrays.asList(FIELDS.replaceAll(" ", "")
.split(",")); .split(","));
@ -357,6 +358,8 @@ public class MlModelResource {
.withAlgorithm(create.getAlgorithm()) .withAlgorithm(create.getAlgorithm())
.withMlFeatures(create.getMlFeatures()) .withMlFeatures(create.getMlFeatures())
.withMlHyperParameters(create.getMlHyperParameters()) .withMlHyperParameters(create.getMlHyperParameters())
.withMlStore(create.getMlStore())
.withServer(create.getServer())
.withTags(create.getTags()) .withTags(create.getTags())
.withOwner(create.getOwner()) .withOwner(create.getOwner())
.withUpdatedBy(securityContext.getUserPrincipal().getName()) .withUpdatedBy(securityContext.getUserPrincipal().getName())

View File

@ -43,6 +43,14 @@
"description": "Performance Dashboard URL to track metric evolution", "description": "Performance Dashboard URL to track metric evolution",
"$ref" : "../../type/entityReference.json" "$ref" : "../../type/entityReference.json"
}, },
"mlStore": {
"description": "Location containing the ML Model. It can be a storage layer and/or a container repository.",
"$ref": "../../entity/data/mlmodel.json#/definitions/mlStore"
},
"server": {
"description": "Endpoint that makes the ML Model available, e.g,. a REST API serving the data or computing predictions.",
"$ref": "../../type/basic.json#/definitions/href"
},
"tags": { "tags": {
"description": "Tags for this ML Model", "description": "Tags for this ML Model",
"type": "array", "type": "array",

View File

@ -155,6 +155,22 @@
"type": "string" "type": "string"
} }
} }
},
"mlStore": {
"type": "object",
"javaType": "org.openmetadata.catalog.type.MlStore",
"description": "Location containing the ML Model. It can be a storage layer and/or a container repository.",
"additionalProperties": false,
"properties": {
"storage": {
"description": "Storage Layer containing the ML Model data.",
"$ref": "../../type/basic.json#/definitions/href"
},
"imageRepository": {
"description": "Container Repository with the ML Model image.",
"$ref": "../../type/basic.json#/definitions/href"
}
}
} }
}, },
"properties" : { "properties" : {
@ -206,6 +222,14 @@
"description": "Performance Dashboard URL to track metric evolution.", "description": "Performance Dashboard URL to track metric evolution.",
"$ref" : "../../type/entityReference.json" "$ref" : "../../type/entityReference.json"
}, },
"mlStore": {
"description": "Location containing the ML Model. It can be a storage layer and/or a container repository.",
"$ref": "#/definitions/mlStore"
},
"server": {
"description": "Endpoint that makes the ML Model available, e.g,. a REST API serving the data or computing predictions.",
"$ref": "../../type/basic.json#/definitions/href"
},
"href": { "href": {
"description": "Link to the resource corresponding to this entity.", "description": "Link to the resource corresponding to this entity.",
"$ref": "../../type/basic.json#/definitions/href" "$ref": "../../type/basic.json#/definitions/href"

View File

@ -36,6 +36,7 @@ import org.openmetadata.catalog.type.FeatureSourceDataType;
import org.openmetadata.catalog.type.FieldChange; import org.openmetadata.catalog.type.FieldChange;
import org.openmetadata.catalog.type.MlFeature; import org.openmetadata.catalog.type.MlFeature;
import org.openmetadata.catalog.type.MlFeatureDataType; import org.openmetadata.catalog.type.MlFeatureDataType;
import org.openmetadata.catalog.type.MlStore;
import org.openmetadata.catalog.type.MlFeatureSource; import org.openmetadata.catalog.type.MlFeatureSource;
import org.openmetadata.catalog.type.MlHyperParameter; import org.openmetadata.catalog.type.MlHyperParameter;
import org.openmetadata.catalog.entity.services.DashboardService; import org.openmetadata.catalog.entity.services.DashboardService;
@ -52,6 +53,7 @@ import org.openmetadata.catalog.util.JsonUtils;
import javax.ws.rs.client.WebTarget; import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import java.io.IOException; import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -83,6 +85,12 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
public static String ALGORITHM = "regression"; public static String ALGORITHM = "regression";
public static Dashboard DASHBOARD; public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE; public static EntityReference DASHBOARD_REFERENCE;
public static URI SERVER = URI.create("http://localhost.com/mlModel");
public static MlStore ML_STORE = new MlStore()
.withStorage(URI.create("s3://my-bucket.com/mlModel"))
.withImageRepository(URI.create("https://12345.dkr.ecr.region.amazonaws.com"));
public static List<MlFeature> ML_FEATURES = Arrays.asList( public static List<MlFeature> ML_FEATURES = Arrays.asList(
new MlFeature() new MlFeature()
.withName("age") .withName("age")
@ -196,6 +204,18 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
createAndCheckEntity(create, adminAuthHeaders()); createAndCheckEntity(create, adminAuthHeaders());
} }
@Test
public void post_MlModelWitMlStore_200_ok(TestInfo test) throws IOException {
CreateMlModel create = create(test).withMlStore(ML_STORE);
createAndCheckEntity(create, adminAuthHeaders());
}
@Test
public void post_MlModelWitServer_200_ok(TestInfo test) throws IOException {
CreateMlModel create = create(test).withServer(SERVER);
createAndCheckEntity(create, adminAuthHeaders());
}
@Test @Test
public void post_MlModel_as_non_admin_401(TestInfo test) { public void post_MlModel_as_non_admin_401(TestInfo test) {
CreateMlModel create = create(test); CreateMlModel create = create(test);
@ -257,6 +277,30 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
); );
} }
@Test
public void put_MlModelAddServer_200(TestInfo test) throws IOException {
CreateMlModel request = create(test);
MlModel model = createAndCheckEntity(request, adminAuthHeaders());
ChangeDescription change = getChangeDescription(model.getVersion());
change.getFieldsAdded().add(new FieldChange().withName("server").withNewValue(SERVER));
updateAndCheckEntity(
request.withServer(SERVER), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change
);
}
@Test
public void put_MlModelAddMlStore_200(TestInfo test) throws IOException {
CreateMlModel request = create(test);
MlModel model = createAndCheckEntity(request, adminAuthHeaders());
ChangeDescription change = getChangeDescription(model.getVersion());
change.getFieldsAdded().add(new FieldChange().withName("mlStore").withNewValue(ML_STORE));
updateAndCheckEntity(
request.withMlStore(ML_STORE), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change
);
}
@Test @Test
public void get_nonExistentMlModel_404_notFound() { public void get_nonExistentMlModel_404_notFound() {
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> HttpResponseException exception = assertThrows(HttpResponseException.class, () ->
@ -469,23 +513,30 @@ public class MlModelResourceTest extends EntityResourceTest<MlModel> {
if (expected == actual) { if (expected == actual) {
return; return;
} }
if (fieldName.contains("mlFeatures") && !fieldName.endsWith("tags") && !fieldName.endsWith("description")) { if (fieldName.contains("mlFeatures")) {
List<MlFeature> expectedFeatures = (List<MlFeature>) expected; List<MlFeature> expectedFeatures = (List<MlFeature>) expected;
List<MlFeature> actualFeatures = JsonUtils.readObjects(actual.toString(), MlFeature.class); List<MlFeature> actualFeatures = JsonUtils.readObjects(actual.toString(), MlFeature.class);
assertEquals(expectedFeatures, actualFeatures); assertEquals(expectedFeatures, actualFeatures);
} else if (fieldName.contains("mlHyperParameters") && !fieldName.endsWith("tags") } else if (fieldName.contains("mlHyperParameters")) {
&& !fieldName.endsWith("description")) {
List<MlHyperParameter> expectedConstraints = (List<MlHyperParameter>) expected; List<MlHyperParameter> expectedConstraints = (List<MlHyperParameter>) expected;
List<MlHyperParameter> actualConstraints = JsonUtils.readObjects(actual.toString(), MlHyperParameter.class); List<MlHyperParameter> actualConstraints = JsonUtils.readObjects(actual.toString(), MlHyperParameter.class);
assertEquals(expectedConstraints, actualConstraints); assertEquals(expectedConstraints, actualConstraints);
} else if (fieldName.endsWith("algorithm")) { } else if (fieldName.contains("algorithm")) {
String expectedAlgorithm = (String) expected; String expectedAlgorithm = (String) expected;
String actualAlgorithm = actual.toString(); String actualAlgorithm = actual.toString();
assertEquals(expectedAlgorithm, actualAlgorithm); assertEquals(expectedAlgorithm, actualAlgorithm);
} else if (fieldName.endsWith("dashboard")) { } else if (fieldName.contains("dashboard")) {
EntityReference expectedDashboard = (EntityReference) expected; EntityReference expectedDashboard = (EntityReference) expected;
EntityReference actualDashboard = JsonUtils.readValue(actual.toString(), EntityReference.class); EntityReference actualDashboard = JsonUtils.readValue(actual.toString(), EntityReference.class);
assertEquals(expectedDashboard, actualDashboard); assertEquals(expectedDashboard, actualDashboard);
} else if (fieldName.contains("server")) {
URI expectedServer = (URI) expected;
URI actualServer = URI.create(actual.toString());
assertEquals(expectedServer, actualServer);
} else if (fieldName.contains("mlStore")) {
MlStore expectedMlStore = (MlStore) expected;
MlStore actualMlStore = JsonUtils.readValue(actual.toString(), MlStore.class);
assertEquals(expectedMlStore, actualMlStore);
} else { } else {
assertCommonFieldChange(fieldName, expected, actual); assertCommonFieldChange(fieldName, expected, actual);
} }