From 76f40e6f371d8827cc85777b36a517e11012f63e Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Sat, 27 Nov 2021 21:31:55 +0100 Subject: [PATCH] [issue-1318] - Refactor MlModel tests and improve logic (#1420) * Prepare MlModel test utils * Refactor tests * Properly set FQN for MlModels * Fix FQN and validate inner properties * Change MlModel naming * Inherit variables from parent test * Simplify delete * Update tests interface * Add order * Add missing version endpoints * Add security context * Update mlmodel name * Update mlmodel methods * Update mlmodel tests * Rename MlModel * Prepare dashboard service in setup * Fix MlModel test setup * MlModel name to minus * Workaround for issue-1415 * Workaround for issue-1415 * Move superset ref back to MlModel * Add missing href in Dashboard entity reference * Handle update dashboards * Add full dashboard props in validate * Add update tests * Reformat * Reformat * Reformat * Reformat * Reformat --- .../java/org/openmetadata/catalog/Entity.java | 2 +- .../catalog/jdbi3/DashboardRepository.java | 2 +- .../catalog/jdbi3/MlModelRepository.java | 109 ++-- .../catalog/jdbi3/Relationship.java | 2 +- .../resources/mlmodels/MlModelResource.java | 98 ++- .../catalog/resources/EntityResourceTest.java | 19 + .../mlmodels/MlModelResourceTest.java | 577 +++++++----------- 7 files changed, 375 insertions(+), 434 deletions(-) diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/Entity.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/Entity.java index 4e76cda81bd..55c384bbeb1 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/Entity.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/Entity.java @@ -60,7 +60,7 @@ public final class Entity { public static final String CHART = "chart"; public static final String REPORT = "report"; public static final String TOPIC = "topic"; - public static final String MLMODEL = "mlModel"; + public static final String MLMODEL = "mlmodel"; public static final String DBTMODEL = "dbtmodel"; public static final String BOTS = "bots"; public static final String LOCATION = "location"; diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/DashboardRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/DashboardRepository.java index 3147a4bf040..cbbb8845639 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/DashboardRepository.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/DashboardRepository.java @@ -287,7 +287,7 @@ public class DashboardRepository extends EntityRepository { @Override public EntityReference getEntityReference() { return new EntityReference().withId(getId()).withName(getFullyQualifiedName()).withDescription(getDescription()) - .withDisplayName(getDisplayName()).withType(Entity.DASHBOARD); + .withDisplayName(getDisplayName()).withType(Entity.DASHBOARD).withHref(getHref()); } @Override diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java index e3619ca50f9..d4173b4ad5d 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/MlModelRepository.java @@ -20,10 +20,11 @@ import com.fasterxml.jackson.core.JsonProcessingException; import org.jdbi.v3.sqlobject.transaction.Transaction; import org.openmetadata.catalog.Entity; import org.openmetadata.catalog.entity.data.MlModel; -import org.openmetadata.catalog.exception.EntityNotFoundException; import org.openmetadata.catalog.resources.mlmodels.MlModelResource; import org.openmetadata.catalog.type.ChangeDescription; import org.openmetadata.catalog.type.EntityReference; +import org.openmetadata.catalog.type.MlFeature; +import org.openmetadata.catalog.type.MlFeatureSource; import org.openmetadata.catalog.type.TagLabel; import org.openmetadata.catalog.util.EntityInterface; import org.openmetadata.catalog.util.EntityUtil; @@ -39,14 +40,12 @@ import java.util.Date; import java.util.List; import java.util.UUID; -import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityNotFound; - public class MlModelRepository extends EntityRepository { private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class); private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST, - "owner,dashboard,mlHyperParameters,mlFeatures,tags"); + "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags"); private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST, - "owner,dashboard,mlHyperParameters,mlFeatures,tags"); + "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags"); private final CollectionDAO dao; public MlModelRepository(CollectionDAO dao) { @@ -62,12 +61,7 @@ public class MlModelRepository extends EntityRepository { @Transaction public void delete(UUID id) { - if (dao.relationshipDAO().findToCount(id.toString(), Relationship.CONTAINS.ordinal(), Entity.MLMODEL) > 0) { - throw new IllegalArgumentException("Model is not empty"); - } - if (dao.mlModelDAO().delete(id) <= 0) { - throw EntityNotFoundException.byMessage(entityNotFound(Entity.MLMODEL, id)); - } + dao.mlModelDAO().delete(id); dao.relationshipDAO().deleteAll(id.toString()); } @@ -92,7 +86,9 @@ public class MlModelRepository extends EntityRepository { @Override public void restorePatchAttributes(MlModel original, MlModel updated) throws IOException, ParseException { - + // Patch can't make changes to following fields. Ignore the changes + updated.withFullyQualifiedName(original.getFullyQualifiedName()) + .withName(original.getName()).withId(original.getId()); } @Override @@ -104,16 +100,40 @@ public class MlModelRepository extends EntityRepository { return dao.tagDAO().getTags(fqn); } + private void setMlFeatureSourcesFQN(String parentFQN, List mlSources) { + mlSources.forEach(s -> { + String sourceFqn = parentFQN + "." + s.getName(); + s.setFullyQualifiedName(sourceFqn); + }); + } + + private void setMlFeatureFQN(String parentFQN, List mlFeatures) { + mlFeatures.forEach(f -> { + String featureFqn = parentFQN + "." + f.getName(); + f.setFullyQualifiedName(featureFqn); + if (f.getFeatureSources() != null) { + setMlFeatureSourcesFQN(featureFqn, f.getFeatureSources()); + } + }); + } + @Override - public void validate(MlModel model) throws IOException { - model.setFullyQualifiedName(getFQN(model)); - EntityUtil.populateOwner(dao.userDAO(), dao.teamDAO(), model.getOwner()); // Validate owner - if (model.getDashboard() != null) { - UUID dashboardId = model.getDashboard().getId(); - model.setDashboard(dao.dashboardDAO().findEntityReferenceById(dashboardId)); + public void validate(MlModel mlModel) throws IOException { + mlModel.setFullyQualifiedName(getFQN(mlModel)); + setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures()); + + // Check if owner is valid and set the relationship + mlModel.setOwner(EntityUtil.populateOwner(dao.userDAO(), dao.teamDAO(), mlModel.getOwner())); + + setDashboard(mlModel, mlModel.getDashboard()); + if (mlModel.getDashboard() != null) { + // Add relationship from MlModel to Dashboard + String dashboardId = mlModel.getDashboard().getId().toString(); + dao.relationshipDAO().insert(dashboardId, mlModel.getId().toString(), Entity.MLMODEL, Entity.DASHBOARD, + Relationship.USES.ordinal()); } - model.setTags(EntityUtil.addDerivedTags(dao.tagDAO(), model.getTags())); + mlModel.setTags(EntityUtil.addDerivedTags(dao.tagDAO(), mlModel.getTags())); } @Override @@ -138,8 +158,18 @@ public class MlModelRepository extends EntityRepository { @Override public void storeRelationships(MlModel mlModel) throws IOException { - setOwner(mlModel, mlModel.getOwner()); + + EntityUtil.setOwner(dao.relationshipDAO(), mlModel.getId(), Entity.MLMODEL, mlModel.getOwner()); + setDashboard(mlModel, mlModel.getDashboard()); + + if (mlModel.getDashboard() != null) { + // Add relationship from MlModel to Dashboard + String dashboardId = mlModel.getDashboard().getId().toString(); + dao.relationshipDAO().insert(dashboardId, mlModel.getId().toString(), Entity.MLMODEL, Entity.DASHBOARD, + Relationship.USES.ordinal()); + } + applyTags(mlModel); } @@ -153,11 +183,6 @@ public class MlModelRepository extends EntityRepository { dao.userDAO(), dao.teamDAO()); } - public void setOwner(MlModel mlModel, EntityReference owner) { - EntityUtil.setOwner(dao.relationshipDAO(), mlModel.getId(), Entity.MLMODEL, owner); - mlModel.setOwner(owner); - } - private EntityReference getDashboard(MlModel mlModel) throws IOException { if (mlModel != null) { List ids = dao.relationshipDAO().findTo(mlModel.getId().toString(), Relationship.USES.ordinal()); @@ -193,10 +218,10 @@ public class MlModelRepository extends EntityRepository { return model == null ? null : EntityUtil.getFollowers(model.getId(), dao.relationshipDAO(), dao.userDAO()); } - static class MlModelEntityInterface implements EntityInterface { + public static class MlModelEntityInterface implements EntityInterface { private final MlModel entity; - MlModelEntityInterface(MlModel entity) { + public MlModelEntityInterface(MlModel entity) { this.entity = entity; } @@ -304,10 +329,12 @@ public class MlModelRepository extends EntityRepository { @Override public void entitySpecificUpdate() throws IOException { - updateAlgorithm(original.getEntity(), updated.getEntity()); - updateDashboard(original.getEntity(), updated.getEntity()); - updateMlFeatures(original.getEntity(), updated.getEntity()); - updateMlHyperParameters(original.getEntity(), updated.getEntity()); + MlModel origMlModel = original.getEntity(); + MlModel updatedMlModel = updated.getEntity(); + updateAlgorithm(origMlModel, updatedMlModel); + updateDashboard(origMlModel, updatedMlModel); + updateMlFeatures(origMlModel, updatedMlModel); + updateMlHyperParameters(origMlModel, updatedMlModel); } private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { @@ -323,15 +350,21 @@ public class MlModelRepository extends EntityRepository { } private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { - // Remove existing dashboards - removeDashboard(origModel); + String modelId = updatedModel.getId().toString(); - EntityReference origOwner = origModel.getDashboard(); - EntityReference updatedOwner = updatedModel.getDashboard(); - if (recordChange("owner", origOwner == null ? null : origOwner.getId(), - updatedOwner == null ? null : updatedOwner.getId())) { - setDashboard(updatedModel, updatedModel.getDashboard()); + // Remove the dashboard associated with the model, if any + if (origModel.getDashboard() != null) { + dao.relationshipDAO().deleteFrom(modelId, Relationship.USES.ordinal(), "dashboard"); } + + // Add relationship from model to dashboard + EntityReference updatedDashboard = updatedModel.getDashboard(); + if (updatedDashboard != null) { + dao.relationshipDAO().insert(modelId, updatedDashboard.getId().toString(), + Entity.MLMODEL, Entity.DASHBOARD, Relationship.USES.ordinal()); + } + recordChange("dashboard", origModel.getDashboard(), updatedDashboard, true); + } } } diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/Relationship.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/Relationship.java index 29f15ca4c61..32f46856c80 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/Relationship.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/Relationship.java @@ -54,7 +54,7 @@ public enum Relationship { // {Dashboard|Pipeline|Query} --- uses ---> Table // {User} --- uses ---> {Table|Dashboard|Query} - // {Model} --- uses ---> {Dashboard} + // {MlModel} --- uses ---> {Dashboard} USES("uses"), // {User|Team|Org} --- owns ---> {Table|Dashboard|Query} diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java index 96c2eadf23d..eaa1917c4ca 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MlModelResource.java @@ -34,6 +34,7 @@ import org.openmetadata.catalog.jdbi3.MlModelRepository; import org.openmetadata.catalog.resources.Collection; import org.openmetadata.catalog.security.CatalogAuthorizer; import org.openmetadata.catalog.security.SecurityUtil; +import org.openmetadata.catalog.type.EntityHistory; import org.openmetadata.catalog.util.EntityUtil.Fields; import org.openmetadata.catalog.util.RestUtil; import org.openmetadata.catalog.util.RestUtil.PatchResponse; @@ -82,15 +83,10 @@ public class MlModelResource { private final MlModelRepository dao; private final CatalogAuthorizer authorizer; - public static List addHref(UriInfo uriInfo, List models) { - Optional.ofNullable(models).orElse(Collections.emptyList()).forEach(i -> addHref(uriInfo, i)); - return models; - } - public static MlModel addHref(UriInfo uriInfo, MlModel mlmodel) { mlmodel.setHref(RestUtil.getHref(uriInfo, COLLECTION_PATH, mlmodel.getId())); Entity.withHref(uriInfo, mlmodel.getOwner()); - Entity.withHref(uriInfo, mlmodel.getDashboard()); // Dashboard HREF + Entity.withHref(uriInfo, mlmodel.getDashboard()); Entity.withHref(uriInfo, mlmodel.getFollowers()); return mlmodel; } @@ -152,11 +148,11 @@ public class MlModelResource { ResultList mlmodels; if (before != null) { // Reverse paging - mlmodels = dao.listBefore(uriInfo, fields, null, limitParam, before); // Ask for one extra entry + mlmodels = dao.listBefore(uriInfo, fields, null, limitParam, before); } else { // Forward paging or first page mlmodels = dao.listAfter(uriInfo, fields, null, limitParam, after); } - addHref(uriInfo, mlmodels.getData()); + mlmodels.getData().forEach(m -> addHref(uriInfo, m)); return mlmodels; } @@ -196,8 +192,7 @@ public class MlModelResource { schema = @Schema(type = "string", example = FIELDS)) @QueryParam("fields") String fieldsParam) throws IOException, ParseException { Fields fields = new Fields(FIELD_LIST, fieldsParam); - MlModel mlmodel = dao.getByName(uriInfo, fqn, fields); - return addHref(uriInfo, mlmodel); + return addHref(uriInfo, dao.getByName(uriInfo, fqn, fields)); } @@ -205,12 +200,13 @@ public class MlModelResource { @Operation(summary = "Create an ML Model", tags = "mlModels", description = "Create a new ML Model.", responses = { - @ApiResponse(responseCode = "200", description = "The model", - content = @Content(mediaType = "application/json", - schema = @Schema(implementation = CreateMlModel.class))), - @ApiResponse(responseCode = "400", description = "Bad request") + @ApiResponse(responseCode = "200", description = "ML Model", + content = @Content(mediaType = "application/json", + schema = @Schema(implementation = CreateMlModel.class))), + @ApiResponse(responseCode = "400", description = "Bad request") }) - public Response create(@Context UriInfo uriInfo, @Context SecurityContext securityContext, + public Response create(@Context UriInfo uriInfo, + @Context SecurityContext securityContext, @Valid CreateMlModel create) throws IOException, ParseException { SecurityUtil.checkAdminOrBotRole(authorizer, securityContext); MlModel mlModel = getMlModel(securityContext, create); @@ -225,22 +221,22 @@ public class MlModelResource { externalDocs = @ExternalDocumentation(description = "JsonPatch RFC", url = "https://tools.ietf.org/html/rfc6902")) @Consumes(MediaType.APPLICATION_JSON_PATCH_JSON) - public Response updateDescription(@Context UriInfo uriInfo, - @Context SecurityContext securityContext, - @PathParam("id") String id, - @RequestBody(description = "JsonPatch with array of operations", - content = @Content(mediaType = MediaType.APPLICATION_JSON_PATCH_JSON, - examples = {@ExampleObject("[" + - "{op:remove, path:/a}," + - "{op:add, path: /b, value: val}" + - "]")})) - JsonPatch patch) throws IOException, ParseException { + public Response patch(@Context UriInfo uriInfo, + @Context SecurityContext securityContext, + @Parameter(description = "Id of the ML Model", schema = @Schema(type = "string")) + @PathParam("id") String id, + @RequestBody(description = "JsonPatch with array of operations", + content = @Content(mediaType = MediaType.APPLICATION_JSON_PATCH_JSON, + examples = {@ExampleObject("[" + + "{op:remove, path:/a}," + + "{op:add, path: /b, value: val}" + + "]")})) + JsonPatch patch) throws IOException, ParseException { Fields fields = new Fields(FIELD_LIST, FIELDS); MlModel mlModel = dao.get(uriInfo, id, fields); - SecurityUtil.checkAdminRoleOrPermissions(authorizer, securityContext, - dao.getOwnerReference(mlModel)); - PatchResponse response = - dao.patch(uriInfo, UUID.fromString(id), securityContext.getUserPrincipal().getName(), patch); + SecurityUtil.checkAdminRoleOrPermissions(authorizer, securityContext, dao.getOwnerReference(mlModel)); + PatchResponse response = dao.patch(uriInfo, UUID.fromString(id), + securityContext.getUserPrincipal().getName(), patch); addHref(uriInfo, response.getEntity()); return response.toResponse(); } @@ -258,6 +254,7 @@ public class MlModelResource { @Context SecurityContext securityContext, @Valid CreateMlModel create) throws IOException, ParseException { MlModel mlModel = getMlModel(securityContext, create); + SecurityUtil.checkAdminRoleOrPermissions(authorizer, securityContext, dao.getOwnerReference(mlModel)); PutResponse response = dao.createOrUpdate(uriInfo, mlModel); addHref(uriInfo, response.getEntity()); return response.toResponse(); @@ -298,6 +295,43 @@ public class MlModelResource { UUID.fromString(userId)).toResponse(); } + @GET + @Path("/{id}/versions") + @Operation(summary = "List Ml Model versions", tags = "mlModels", + description = "Get a list of all the versions of an Ml Model identified by `id`", + responses = {@ApiResponse(responseCode = "200", description = "List of Ml Model versions", + content = @Content(mediaType = "application/json", + schema = @Schema(implementation = EntityHistory.class))) + }) + public EntityHistory listVersions(@Context UriInfo uriInfo, + @Context SecurityContext securityContext, + @Parameter(description = "ML Model Id", schema = @Schema(type = "string")) + @PathParam("id") String id) + throws IOException, ParseException, GeneralSecurityException { + return dao.listVersions(id); + } + + @GET + @Path("/{id}/versions/{version}") + @Operation(summary = "Get a version of the ML Model", tags = "mlModels", + description = "Get a version of the ML Model by given `id`", + responses = { + @ApiResponse(responseCode = "200", description = "MlModel", + content = @Content(mediaType = "application/json", + schema = @Schema(implementation = MlModel.class))), + @ApiResponse(responseCode = "404", description = "ML Model for instance {id} and version {version} is " + + "not found") + }) + public MlModel getVersion(@Context UriInfo uriInfo, + @Context SecurityContext securityContext, + @Parameter(description = "ML Model Id", schema = @Schema(type = "string")) + @PathParam("id") String id, + @Parameter(description = "ML Model version number in the form `major`.`minor`", + schema = @Schema(type = "string", example = "0.1 or 1.1")) + @PathParam("version") String version) throws IOException, ParseException { + return dao.getVersion(id, version); + } + @DELETE @Path("/{id}") @Operation(summary = "Delete an ML Model", tags = "mlModels", @@ -306,7 +340,11 @@ public class MlModelResource { @ApiResponse(responseCode = "200", description = "OK"), @ApiResponse(responseCode = "404", description = "model for instance {id} is not found") }) - public Response delete(@Context UriInfo uriInfo, @PathParam("id") String id) { + public Response delete(@Context UriInfo uriInfo, + @Context SecurityContext securityContext, + @Parameter(description = "Id of the ML Model", schema = @Schema(type = "string")) + @PathParam("id") String id) { + SecurityUtil.checkAdminOrBotRole(authorizer, securityContext); dao.delete(UUID.fromString(id)); return Response.ok().build(); } diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/EntityResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/EntityResourceTest.java index 59b56e3477b..a462fa407e1 100644 --- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/EntityResourceTest.java +++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/EntityResourceTest.java @@ -61,6 +61,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.UUID; +import java.util.function.BiConsumer; import static javax.ws.rs.core.Response.Status.BAD_REQUEST; import static javax.ws.rs.core.Response.Status.CREATED; @@ -645,6 +646,7 @@ public abstract class EntityResourceTest extends CatalogApplicationTest { protected final T createAndCheckEntity(Object create, Map authHeaders) throws IOException { // Validate an entity that is created has all the information set in create request String updatedBy = TestUtils.getPrincipal(authHeaders); + // aqui si que tenim HREF T entity = createEntity(create, authHeaders); EntityInterface entityInterface = getEntityInterface(entity); @@ -1029,4 +1031,21 @@ public abstract class EntityResourceTest extends CatalogApplicationTest { list.getData().forEach(e -> LOG.info("{} {}", entityClass, getEntityInterface(e).getFullyQualifiedName())); LOG.info("before {} after {} ", list.getPaging().getBefore(), list.getPaging().getAfter()); } + + /** + * Given a list of properties of an Entity (e.g., List or List and + * a function that validate the elements of T, validate lists + */ + public

void assertListProperty(List

expected, List

actual, BiConsumer validate) + throws HttpResponseException { + if (expected == null && actual == null) { + return; + } + + assertNotNull(expected); + assertEquals(expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i++) { + validate.accept(expected.get(i), actual.get(i)); + } + } } diff --git a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java index fc3efb5be92..c9b578a7d61 100644 --- a/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java +++ b/catalog-rest-service/src/test/java/org/openmetadata/catalog/resources/mlmodels/MlModelResourceTest.java @@ -18,72 +18,69 @@ package org.openmetadata.catalog.resources.mlmodels; import org.apache.http.client.HttpResponseException; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.MethodOrderer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; -import org.openmetadata.catalog.CatalogApplicationTest; +import org.junit.jupiter.api.TestMethodOrder; import org.openmetadata.catalog.Entity; import org.openmetadata.catalog.api.data.CreateMlModel; import org.openmetadata.catalog.api.services.CreateDashboardService; import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType; import org.openmetadata.catalog.entity.data.Dashboard; import org.openmetadata.catalog.entity.data.MlModel; +import org.openmetadata.catalog.exception.CatalogExceptionMessage; +import org.openmetadata.catalog.jdbi3.MlModelRepository; +import org.openmetadata.catalog.resources.EntityResourceTest; +import org.openmetadata.catalog.type.ChangeDescription; import org.openmetadata.catalog.type.FeatureSourceDataType; +import org.openmetadata.catalog.type.FieldChange; import org.openmetadata.catalog.type.MlFeature; import org.openmetadata.catalog.type.MlFeatureDataType; import org.openmetadata.catalog.type.MlFeatureSource; import org.openmetadata.catalog.type.MlHyperParameter; import org.openmetadata.catalog.entity.services.DashboardService; -import org.openmetadata.catalog.entity.teams.Team; -import org.openmetadata.catalog.entity.teams.User; -import org.openmetadata.catalog.exception.CatalogExceptionMessage; import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface; import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface; import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest; import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList; import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest; -import org.openmetadata.catalog.resources.teams.TeamResourceTest; -import org.openmetadata.catalog.resources.teams.UserResourceTest; import org.openmetadata.catalog.type.EntityReference; -import org.openmetadata.catalog.type.TagLabel; +import org.openmetadata.catalog.util.EntityInterface; import org.openmetadata.catalog.util.TestUtils; -import org.openmetadata.catalog.util.TestUtils.UpdateType; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.openmetadata.catalog.util.JsonUtils; import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.Response.Status; +import java.io.IOException; +import java.net.URISyntaxException; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.function.BiConsumer; import static javax.ws.rs.core.Response.Status.BAD_REQUEST; import static javax.ws.rs.core.Response.Status.CONFLICT; -import static javax.ws.rs.core.Response.Status.CREATED; import static javax.ws.rs.core.Response.Status.FORBIDDEN; import static javax.ws.rs.core.Response.Status.NOT_FOUND; -import static javax.ws.rs.core.Response.Status.OK; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.openmetadata.catalog.exception.CatalogExceptionMessage.ENTITY_ALREADY_EXISTS; import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityNotFound; import static org.openmetadata.catalog.util.TestUtils.UpdateType.MINOR_UPDATE; import static org.openmetadata.catalog.util.TestUtils.UpdateType.NO_CHANGE; import static org.openmetadata.catalog.util.TestUtils.adminAuthHeaders; -import static org.openmetadata.catalog.util.TestUtils.assertEntityPagination; import static org.openmetadata.catalog.util.TestUtils.assertResponse; import static org.openmetadata.catalog.util.TestUtils.authHeaders; -public class MlModelResourceTest extends CatalogApplicationTest { - private static final Logger LOG = LoggerFactory.getLogger(MlModelResourceTest.class); - public static User USER1; - public static EntityReference USER_OWNER1; - public static Team TEAM1; - public static EntityReference TEAM_OWNER1; - public static String ALGORITHM = "regression"; +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +public class MlModelResourceTest extends EntityResourceTest { + public static EntityReference SUPERSET_REFERENCE; + public static String ALGORITHM = "regression"; public static Dashboard DASHBOARD; public static EntityReference DASHBOARD_REFERENCE; public static List ML_FEATURES = Arrays.asList( @@ -95,7 +92,6 @@ public class MlModelResourceTest extends CatalogApplicationTest { new MlFeatureSource() .withName("age") .withDataType(FeatureSourceDataType.INTEGER) - .withFullyQualifiedName("my_service.my_db.my_table.age") ) ), new MlFeature() @@ -105,12 +101,10 @@ public class MlModelResourceTest extends CatalogApplicationTest { Arrays.asList( new MlFeatureSource() .withName("age") - .withDataType(FeatureSourceDataType.INTEGER) - .withFullyQualifiedName("my_service.my_db.my_table.age"), + .withDataType(FeatureSourceDataType.INTEGER), new MlFeatureSource() .withName("education") .withDataType(FeatureSourceDataType.STRING) - .withFullyQualifiedName("my_api.education") ) ) .withFeatureAlgorithm("PCA") @@ -120,14 +114,16 @@ public class MlModelResourceTest extends CatalogApplicationTest { new MlHyperParameter().withName("random").withValue("hello") ); + public MlModelResourceTest() { + super(Entity.MLMODEL, MlModel.class, MlModelList.class, "mlmodels", MlModelResource.FIELDS, true, + true, true); + } + @BeforeAll - public static void setup(TestInfo test) throws HttpResponseException { - USER1 = UserResourceTest.createUser(UserResourceTest.create(test), authHeaders("test@open-metadata.org")); - USER_OWNER1 = new EntityReference().withId(USER1.getId()).withType("user"); + public static void setup(TestInfo test) throws IOException, URISyntaxException { - TEAM1 = TeamResourceTest.createTeam(TeamResourceTest.create(test), adminAuthHeaders()); - TEAM_OWNER1 = new EntityReference().withId(TEAM1.getId()).withType("team"); + EntityResourceTest.setup(test); CreateDashboardService createService = new CreateDashboardService().withName("superset") .withServiceType(DashboardServiceType.Superset).withDashboardUrl(TestUtils.DASHBOARD_URL); @@ -141,358 +137,159 @@ public class MlModelResourceTest extends CatalogApplicationTest { DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference(); } + public static MlModel createMlModel(CreateMlModel create, + Map authHeaders) throws HttpResponseException { + return new MlModelResourceTest().createEntity(create, authHeaders); + } + @Test - public void post_modelWithLongName_400_badRequest(TestInfo test) { + public void post_MlModelWithLongName_400_badRequest(TestInfo test) { // Create model with mandatory name field empty CreateMlModel create = create(test).withName(TestUtils.LONG_ENTITY_NAME); - HttpResponseException exception = assertThrows(HttpResponseException.class, () -> - createModel(create, adminAuthHeaders())); - assertResponse(exception, BAD_REQUEST, "[name size must be between 1 and 64]"); + assertResponse(() -> createMlModel(create, adminAuthHeaders()), BAD_REQUEST, + "[name size must be between 1 and 64]"); } @Test - public void post_ModelWithoutName_400_badRequest(TestInfo test) { + public void post_MlModelWithoutName_400_badRequest(TestInfo test) { // Create Model with mandatory name field empty CreateMlModel create = create(test).withName(""); - HttpResponseException exception = assertThrows(HttpResponseException.class, () -> - createModel(create, adminAuthHeaders())); - assertResponse(exception, BAD_REQUEST, "[name size must be between 1 and 64]"); + assertResponse(() -> createMlModel(create, adminAuthHeaders()), BAD_REQUEST, + "[name size must be between 1 and 64]"); } @Test - public void post_ModelAlreadyExists_409_conflict(TestInfo test) throws HttpResponseException { + public void post_MlModelAlreadyExists_409_conflict(TestInfo test) throws HttpResponseException { CreateMlModel create = create(test); - createModel(create, adminAuthHeaders()); - HttpResponseException exception = assertThrows(HttpResponseException.class, () -> - createModel(create, adminAuthHeaders())); - assertResponse(exception, CONFLICT, CatalogExceptionMessage.ENTITY_ALREADY_EXISTS); + createMlModel(create, adminAuthHeaders()); + assertResponse(() -> createMlModel(create, adminAuthHeaders()), CONFLICT, ENTITY_ALREADY_EXISTS); } @Test - public void post_validModels_as_admin_200_OK(TestInfo test) throws HttpResponseException { + public void post_validMlModels_as_admin_200_OK(TestInfo test) throws IOException { // Create valid model CreateMlModel create = create(test); - createAndCheckModel(create, adminAuthHeaders()); + createAndCheckEntity(create, adminAuthHeaders()); create.withName(getModelName(test, 1)).withDescription("description"); - createAndCheckModel(create, adminAuthHeaders()); + createAndCheckEntity(create, adminAuthHeaders()); } @Test - public void post_ModelWithUserOwner_200_ok(TestInfo test) throws HttpResponseException { - createAndCheckModel(create(test).withOwner(USER_OWNER1), adminAuthHeaders()); + public void post_MlModelWithUserOwner_200_ok(TestInfo test) throws IOException { + createAndCheckEntity(create(test).withOwner(USER_OWNER1), adminAuthHeaders()); } @Test - public void post_ModelWithTeamOwner_200_ok(TestInfo test) throws HttpResponseException { - createAndCheckModel(create(test).withOwner(TEAM_OWNER1).withDisplayName("Model1"), adminAuthHeaders()); + public void post_MlModelWithTeamOwner_200_ok(TestInfo test) throws IOException { + createAndCheckEntity(create(test).withOwner(TEAM_OWNER1).withDisplayName("Model1"), adminAuthHeaders()); } @Test - public void post_ModelWithDashboard_200_ok(TestInfo test) throws HttpResponseException { - createAndCheckModel(create(test), DASHBOARD_REFERENCE, adminAuthHeaders()); + public void post_MlModelWithDashboard_200_ok(TestInfo test) throws IOException { + CreateMlModel create = create(test).withDashboard(DASHBOARD_REFERENCE); + createAndCheckEntity(create, adminAuthHeaders()); } @Test - public void post_Model_as_non_admin_401(TestInfo test) { + public void post_MlModel_as_non_admin_401(TestInfo test) { CreateMlModel create = create(test); - HttpResponseException exception = assertThrows(HttpResponseException.class, () -> - createModel(create, authHeaders("test@open-metadata.org"))); - assertResponse(exception, FORBIDDEN, "Principal: CatalogPrincipal{name='test'} is not admin"); + assertResponse(() -> createMlModel(create, authHeaders("test@open-metadata.org")), FORBIDDEN, + "Principal: CatalogPrincipal{name='test'} is not admin"); } @Test - public void post_ModelWithInvalidOwnerType_4xx(TestInfo test) { + public void post_MlModelWithInvalidOwnerType_4xx(TestInfo test) { EntityReference owner = new EntityReference().withId(TEAM1.getId()); /* No owner type is set */ - CreateMlModel create = create(test).withOwner(owner); + HttpResponseException exception = assertThrows(HttpResponseException.class, () -> - createModel(create, adminAuthHeaders())); + createEntity(create, adminAuthHeaders())); TestUtils.assertResponseContains(exception, BAD_REQUEST, "type must not be null"); } @Test - public void post_ModelWithNonExistentOwner_4xx(TestInfo test) { + public void post_MlModelWithNonExistentOwner_4xx(TestInfo test) { EntityReference owner = new EntityReference().withId(TestUtils.NON_EXISTENT_ENTITY).withType("user"); CreateMlModel create = create(test).withOwner(owner); - HttpResponseException exception = assertThrows(HttpResponseException.class, () -> - createModel(create, adminAuthHeaders())); - assertResponse(exception, NOT_FOUND, entityNotFound("User", TestUtils.NON_EXISTENT_ENTITY)); + + assertResponse(() -> createMlModel(create, adminAuthHeaders()), NOT_FOUND, + entityNotFound("User", TestUtils.NON_EXISTENT_ENTITY)); } @Test - public void get_ModelListWithInvalidLimitOffset_4xx() { - // Limit must be >= 1 and <= 1000,000 - HttpResponseException exception = assertThrows(HttpResponseException.class, () - -> listModels(null, -1, null, null, adminAuthHeaders())); - assertResponse(exception, BAD_REQUEST, "[query param limit must be greater than or equal to 1]"); - - exception = assertThrows(HttpResponseException.class, () - -> listModels(null, 0, null, null, adminAuthHeaders())); - assertResponse(exception, BAD_REQUEST, "[query param limit must be greater than or equal to 1]"); - - exception = assertThrows(HttpResponseException.class, () - -> listModels(null, 1000001, null, null, adminAuthHeaders())); - assertResponse(exception, BAD_REQUEST, "[query param limit must be less than or equal to 1000000]"); - } - - @Test - public void get_ModelListWithInvalidPaginationCursors_4xx() { - // Passing both before and after cursors is invalid - HttpResponseException exception = assertThrows(HttpResponseException.class, () - -> listModels(null, 1, "", "", adminAuthHeaders())); - assertResponse(exception, BAD_REQUEST, "Only one of before or after query parameter allowed"); - } - - @Test - public void get_ModelListWithValidLimitOffset_4xx(TestInfo test) throws HttpResponseException { - // Create a large number of Models - int maxModels = 40; - for (int i = 0; i < maxModels; i++) { - createModel(create(test, i), adminAuthHeaders()); - } - - // List all Models - MlModelList allModels = listModels(null, 1000000, null, - null, adminAuthHeaders()); - int totalRecords = allModels.getData().size(); - printModels(allModels); - - // List limit number Models at a time at various offsets and ensure right results are returned - for (int limit = 1; limit < maxModels; limit++) { - String after = null; - String before; - int pageCount = 0; - int indexInAllModels = 0; - MlModelList forwardPage; - MlModelList backwardPage; - do { // For each limit (or page size) - forward scroll till the end - LOG.info("Limit {} forward scrollCount {} afterCursor {}", limit, pageCount, after); - forwardPage = listModels(null, limit, null, after, adminAuthHeaders()); - printModels(forwardPage); - after = forwardPage.getPaging().getAfter(); - before = forwardPage.getPaging().getBefore(); - assertEntityPagination(allModels.getData(), forwardPage, limit, indexInAllModels); - - if (pageCount == 0) { // CASE 0 - First page is being returned. There is no before cursor - assertNull(before); - } else { - // Make sure scrolling back based on before cursor returns the correct result - backwardPage = listModels(null, limit, before, null, adminAuthHeaders()); - assertEntityPagination(allModels.getData(), backwardPage, limit, (indexInAllModels - limit)); - } - - indexInAllModels += forwardPage.getData().size(); - pageCount++; - } while (after != null); - - // We have now reached the last page - test backward scroll till the beginning - pageCount = 0; - indexInAllModels = totalRecords - limit - forwardPage.getData().size(); - do { - LOG.info("Limit {} backward scrollCount {} beforeCursor {}", limit, pageCount, before); - forwardPage = listModels(null, limit, before, null, adminAuthHeaders()); - printModels(forwardPage); - before = forwardPage.getPaging().getBefore(); - assertEntityPagination(allModels.getData(), forwardPage, limit, indexInAllModels); - pageCount++; - indexInAllModels -= forwardPage.getData().size(); - } while (before != null); - } - } - - private void printModels(MlModelList list) { - list.getData().forEach(Model -> LOG.info("DB {}", Model.getFullyQualifiedName())); - LOG.info("before {} after {} ", list.getPaging().getBefore(), list.getPaging().getAfter()); - } - - @Test - public void put_ModelUpdateWithNoChange_200(TestInfo test) throws HttpResponseException { + public void put_MlModelUpdateWithNoChange_200(TestInfo test) throws IOException { // Create a Model with POST CreateMlModel request = create(test).withOwner(USER_OWNER1); - MlModel model = createAndCheckModel(request, adminAuthHeaders()); + MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + ChangeDescription change = getChangeDescription(model.getVersion()); // Update Model two times successfully with PUT requests - model = updateAndCheckModel(model, request, OK, adminAuthHeaders(), NO_CHANGE); - updateAndCheckModel(model, request, OK, adminAuthHeaders(), NO_CHANGE); + updateAndCheckEntity(request, Status.OK, adminAuthHeaders(), NO_CHANGE, change); } @Test - public void put_ModelCreate_200(TestInfo test) throws HttpResponseException { - // Create a new Model with PUT - CreateMlModel request = create(test).withOwner(USER_OWNER1); - updateAndCheckModel(null, request.withName(test.getDisplayName()).withDescription(null), CREATED, - adminAuthHeaders(), NO_CHANGE); + public void put_MlModelUpdateAlgorithm_200(TestInfo test) throws IOException { + CreateMlModel request = create(test); + MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + ChangeDescription change = getChangeDescription(model.getVersion()); + change.getFieldsUpdated().add( + new FieldChange().withName("algorithm").withNewValue("SVM").withOldValue("regression") + ); + + updateAndCheckEntity(request.withAlgorithm("SVM"), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change); } @Test - public void put_ModelCreate_as_owner_200(TestInfo test) throws HttpResponseException { - // Create a new Model with put - CreateMlModel request = create(test).withOwner(USER_OWNER1); - // Add model as admin - MlModel model = createAndCheckModel(request, adminAuthHeaders()); - // Update the table as Owner - updateAndCheckModel(model, request.withDescription("new"), OK, authHeaders(USER1.getEmail()), MINOR_UPDATE); + public void put_MlModelAddDashboard_200(TestInfo test) throws IOException { + CreateMlModel request = create(test); + MlModel model = createAndCheckEntity(request, adminAuthHeaders()); + ChangeDescription change = getChangeDescription(model.getVersion()); + change.getFieldsAdded().add(new FieldChange().withName("dashboard").withNewValue(DASHBOARD_REFERENCE)); + + updateAndCheckEntity( + request.withDashboard(DASHBOARD_REFERENCE), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change + ); } @Test - public void put_ModelNullDescriptionUpdate_200(TestInfo test) throws HttpResponseException { - CreateMlModel request = create(test).withDescription(null); - MlModel model = createAndCheckModel(request, adminAuthHeaders()); - - // Update null description with a new description - MlModel db = updateAndCheckModel(model, request.withDisplayName("model1"). - withDescription("newDescription"), OK, adminAuthHeaders(), MINOR_UPDATE); - assertEquals("model1", db.getDisplayName()); // Move this check to validate method - } - - @Test - public void put_ModelEmptyDescriptionUpdate_200(TestInfo test) throws HttpResponseException { - // Create table with empty description - CreateMlModel request = create(test).withDescription(""); - MlModel model = createAndCheckModel(request, adminAuthHeaders()); - - // Update empty description with a new description - updateAndCheckModel(model, request.withDescription("newDescription"), OK, adminAuthHeaders(), MINOR_UPDATE); - } - - @Test - public void put_ModelNonEmptyDescriptionUpdate_200(TestInfo test) throws HttpResponseException { - CreateMlModel request = create(test).withDescription("description"); - createAndCheckModel(request, adminAuthHeaders()); - - // Updating description is ignored when backend already has description - MlModel db = updateModel(request.withDescription("newDescription"), OK, adminAuthHeaders()); - assertEquals("description", db.getDescription()); - } - - @Test - public void put_ModelUpdateOwner_200(TestInfo test) throws HttpResponseException { - CreateMlModel request = create(test).withDescription(""); - MlModel model = createAndCheckModel(request, adminAuthHeaders()); - - // Change ownership from USER_OWNER1 to TEAM_OWNER1 - model = updateAndCheckModel(model, request.withOwner(TEAM_OWNER1), OK, adminAuthHeaders(), MINOR_UPDATE); - - // Remove ownership - model = updateAndCheckModel(model, request.withOwner(null), OK, adminAuthHeaders(), MINOR_UPDATE); - assertNull(model.getOwner()); - } - - @Test - public void put_ModelUpdateAlgorithm_200(TestInfo test) throws HttpResponseException { - CreateMlModel request = create(test).withDescription(""); - MlModel model = createAndCheckModel(request, adminAuthHeaders()); - updateAndCheckModel(model, request.withAlgorithm("SVM"), OK, adminAuthHeaders(), MINOR_UPDATE); - } - - @Test - public void put_ModelUpdateDashboard_200(TestInfo test) throws HttpResponseException { - CreateMlModel request = create(test).withDescription(""); - MlModel model = createAndCheckModel(request, adminAuthHeaders()); - updateAndCheckModel(model, request.withDashboard(DASHBOARD_REFERENCE), OK, adminAuthHeaders(), MINOR_UPDATE); - } - - @Test - public void get_nonExistentModel_404_notFound() { + public void get_nonExistentMlModel_404_notFound() { HttpResponseException exception = assertThrows(HttpResponseException.class, () -> getModel(TestUtils.NON_EXISTENT_ENTITY, adminAuthHeaders())); + // TODO: issue-1415 assertResponse(exception, NOT_FOUND, - entityNotFound(Entity.MLMODEL, TestUtils.NON_EXISTENT_ENTITY)); + entityNotFound("mlModel", TestUtils.NON_EXISTENT_ENTITY)); } @Test - public void get_ModelWithDifferentFields_200_OK(TestInfo test) throws HttpResponseException { + public void get_MlModelWithDifferentFields_200_OK(TestInfo test) throws IOException { + // aqui no tenim HREF al dashboard CreateMlModel create = create(test).withDescription("description") .withOwner(USER_OWNER1).withDashboard(DASHBOARD_REFERENCE); - MlModel model = createAndCheckModel(create, adminAuthHeaders()); + MlModel model = createAndCheckEntity(create, adminAuthHeaders()); validateGetWithDifferentFields(model, false); } @Test - public void get_ModelByNameWithDifferentFields_200_OK(TestInfo test) throws HttpResponseException { + public void get_MlModelByNameWithDifferentFields_200_OK(TestInfo test) throws IOException { CreateMlModel create = create(test).withDescription("description") .withOwner(USER_OWNER1).withDashboard(DASHBOARD_REFERENCE); - MlModel model = createAndCheckModel(create, adminAuthHeaders()); + MlModel model = createAndCheckEntity(create, adminAuthHeaders()); validateGetWithDifferentFields(model, true); } @Test - public void delete_emptyModel_200_ok(TestInfo test) throws HttpResponseException { - MlModel model = createModel(create(test), adminAuthHeaders()); + public void delete_MlModel_200_ok(TestInfo test) throws HttpResponseException { + MlModel model = createMlModel(create(test), adminAuthHeaders()); deleteModel(model.getId(), adminAuthHeaders()); } - @Test - public void delete_nonEmptyModel_4xx() { - // TODO - } - @Test public void delete_nonExistentModel_404() { HttpResponseException exception = assertThrows(HttpResponseException.class, () -> deleteModel(TestUtils.NON_EXISTENT_ENTITY, adminAuthHeaders())); - assertResponse(exception, NOT_FOUND, entityNotFound(Entity.MLMODEL, TestUtils.NON_EXISTENT_ENTITY)); - } - - public static MlModel createAndCheckModel(CreateMlModel create, - Map authHeaders) throws HttpResponseException { - String updatedBy = TestUtils.getPrincipal(authHeaders); - MlModel model = createModel(create, authHeaders); - validateModel(model, create.getDisplayName(), create.getDescription(), create.getOwner(), updatedBy); - return getAndValidate(model.getId(), create, authHeaders, updatedBy); - } - - public static MlModel createAndCheckModel(CreateMlModel create, EntityReference dashboard, - Map authHeaders) throws HttpResponseException { - String updatedBy = TestUtils.getPrincipal(authHeaders); - create.withDashboard(dashboard); - MlModel model = createModel(create, authHeaders); - assertEquals(0.1, model.getVersion()); - validateModel(model, create.getDescription(), create.getOwner(), create.getTags(), updatedBy); - return getAndValidate(model.getId(), create, authHeaders, updatedBy); - } - - public static MlModel updateAndCheckModel(MlModel before, CreateMlModel create, Status status, - Map authHeaders, UpdateType updateType) - throws HttpResponseException { - String updatedBy = TestUtils.getPrincipal(authHeaders); - MlModel updatedModel = updateModel(create, status, authHeaders); - validateModel(updatedModel, create.getDescription(), create.getOwner(), updatedBy); - if (before == null) { - assertEquals(0.1, updatedModel.getVersion()); // First version created - } else { - TestUtils.validateUpdate(before.getVersion(), updatedModel.getVersion(), updateType); - } - - return getAndValidate(updatedModel.getId(), create, authHeaders, updatedBy); - } - - // Make sure in GET operations the returned Model has all the required information passed during creation - public static MlModel getAndValidate(UUID modelId, - CreateMlModel create, - Map authHeaders, - String expectedUpdatedBy) throws HttpResponseException { - // GET the newly created Model by ID and validate - MlModel model = getModel(modelId, "owner", authHeaders); - validateModel(model, create.getDescription(), create.getOwner(), expectedUpdatedBy); - - // GET the newly created Model by name and validate - String fqn = model.getFullyQualifiedName(); - model = getModelByName(fqn, "owner", authHeaders); - return validateModel(model, create.getDescription(), create.getOwner(), expectedUpdatedBy); - } - - public static MlModel updateModel(CreateMlModel create, - Status status, - Map authHeaders) throws HttpResponseException { - return TestUtils.put(getResource("mlmodels"), - create, MlModel.class, status, authHeaders); - } - - public static MlModel createModel(CreateMlModel create, - Map authHeaders) throws HttpResponseException { - return TestUtils.post(getResource("mlmodels"), create, MlModel.class, authHeaders); + // TODO: issue-1415 + assertResponse(exception, NOT_FOUND, entityNotFound("mlModel", TestUtils.NON_EXISTENT_ENTITY)); } /** Validate returned fields GET .../models/{id}?fields="..." or GET .../models/name/{fqn}?fields="..." */ @@ -532,54 +329,6 @@ public class MlModelResourceTest extends CatalogApplicationTest { TestUtils.validateEntityReference(model.getDashboard()); } - private static MlModel validateModel(MlModel model, String expectedDisplayName, - String expectedDescription, - EntityReference expectedOwner, - String expectedUpdatedBy) { - MlModel newModel = validateModel(model, expectedDescription, expectedOwner, expectedUpdatedBy); - assertEquals(expectedDisplayName, newModel.getDisplayName()); - return newModel; - } - private static MlModel validateModel(MlModel model, String expectedDescription, - EntityReference expectedOwner, String expectedUpdatedBy) { - assertNotNull(model.getId()); - assertNotNull(model.getHref()); - assertNotNull(model.getAlgorithm()); - assertEquals(expectedDescription, model.getDescription()); - assertEquals(expectedUpdatedBy, model.getUpdatedBy()); - - // Validate owner - if (expectedOwner != null) { - TestUtils.validateEntityReference(model.getOwner()); - assertEquals(expectedOwner.getId(), model.getOwner().getId()); - assertEquals(expectedOwner.getType(), model.getOwner().getType()); - assertNotNull(model.getOwner().getHref()); - } - - return model; - } - - private static MlModel validateModel(MlModel model, String expectedDescription, - EntityReference expectedOwner, - List expectedTags, - String expectedUpdatedBy) throws HttpResponseException { - assertNotNull(model.getId()); - assertNotNull(model.getHref()); - assertEquals(expectedDescription, model.getDescription()); - assertEquals(expectedUpdatedBy, model.getUpdatedBy()); - - // Validate owner - if (expectedOwner != null) { - TestUtils.validateEntityReference(model.getOwner()); - assertEquals(expectedOwner.getId(), model.getOwner().getId()); - assertEquals(expectedOwner.getType(), model.getOwner().getType()); - assertNotNull(model.getOwner().getHref()); - } - - TestUtils.validateTags(expectedTags, model.getTags()); - return model; - } - public static void getModel(UUID id, Map authHeaders) throws HttpResponseException { getModel(id, null, authHeaders); } @@ -598,27 +347,13 @@ public class MlModelResourceTest extends CatalogApplicationTest { return TestUtils.get(target, MlModel.class, authHeaders); } - public static MlModelList listModels(String fields, Integer limitParam, - String before, String after, Map authHeaders) - throws HttpResponseException { - WebTarget target = getResource("mlmodels"); - target = fields != null ? target.queryParam("fields", fields): target; - target = limitParam != null ? target.queryParam("limit", limitParam): target; - target = before != null ? target.queryParam("before", before) : target; - target = after != null ? target.queryParam("after", after) : target; - return TestUtils.get(target, MlModelList.class, authHeaders); - } - private void deleteModel(UUID id, Map authHeaders) throws HttpResponseException { TestUtils.delete(getResource("mlmodels/" + id), authHeaders); - // Ensure deleted Model does not exist + // Check to make sure database does not exist HttpResponseException exception = assertThrows(HttpResponseException.class, () -> getModel(id, authHeaders)); - assertResponse(exception, NOT_FOUND, entityNotFound(Entity.MLMODEL, id)); - } - - public static String getModelName(TestInfo test) { - return String.format("mlmodel_%s", test.getDisplayName()); + // TODO: issue-1415 instead of mlModel, use Entity.MLMODEL + assertResponse(exception, NOT_FOUND, CatalogExceptionMessage.entityNotFound("mlModel", id)); } public static String getModelName(TestInfo test, int index) { @@ -626,8 +361,7 @@ public class MlModelResourceTest extends CatalogApplicationTest { } public static CreateMlModel create(TestInfo test) { - return new CreateMlModel().withName(getModelName(test)).withAlgorithm(ALGORITHM) - .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS); + return create(test, 0); } public static CreateMlModel create(TestInfo test, int index) { @@ -635,4 +369,121 @@ public class MlModelResourceTest extends CatalogApplicationTest { .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS); } + @Override + public Object createRequest(TestInfo test, int index, String description, String displayName, EntityReference owner) { + return create(test, index).withDescription(description).withDisplayName(displayName).withOwner(owner); + } + + @Override + public void validateUpdatedEntity(MlModel mlModel, Object request, Map authHeaders) + throws HttpResponseException { + validateCreatedEntity(mlModel, request, authHeaders); + } + + @Override + public void compareEntities(MlModel expected, MlModel updated, Map authHeaders) + throws HttpResponseException { + validateCommonEntityFields(getEntityInterface(updated), expected.getDescription(), + TestUtils.getPrincipal(authHeaders), expected.getOwner()); + + // Entity specific validations + assertEquals(expected.getAlgorithm(), updated.getAlgorithm()); + assertEquals(expected.getDashboard(), updated.getDashboard()); + assertListProperty(expected.getMlFeatures(), updated.getMlFeatures(), assertMlFeature); + assertListProperty(expected.getMlHyperParameters(), updated.getMlHyperParameters(), assertMlHyperParam); + + // assertListProperty on MlFeatures already validates size, so we can directly iterate on sources + validateMlFeatureSources(expected.getMlFeatures(), updated.getMlFeatures()); + + TestUtils.validateTags(expected.getTags(), updated.getTags()); + TestUtils.validateEntityReference(updated.getFollowers()); + } + + @Override + public EntityInterface getEntityInterface(MlModel entity) { + return new MlModelRepository.MlModelEntityInterface(entity); + } + + BiConsumer assertMlFeature = (MlFeature expected, MlFeature actual) -> { + assertNotNull(actual.getFullyQualifiedName()); + assertEquals(actual.getName(), expected.getName()); + assertEquals(actual.getDescription(), expected.getDescription()); + assertEquals(actual.getFeatureAlgorithm(), expected.getFeatureAlgorithm()); + assertEquals(actual.getDataType(), expected.getDataType()); + }; + + BiConsumer assertMlHyperParam = + (MlHyperParameter expected, MlHyperParameter actual) -> { + assertEquals(actual.getName(), expected.getName()); + assertEquals(actual.getDescription(), expected.getDescription()); + assertEquals(actual.getValue(), expected.getValue()); + }; + + BiConsumer assertMlFeatureSource = + (MlFeatureSource expected, MlFeatureSource actual) -> { + assertNotNull(actual.getFullyQualifiedName()); + assertEquals(actual.getName(), expected.getName()); + assertEquals(actual.getDescription(), expected.getDescription()); + assertEquals(actual.getDataType(), expected.getDataType()); + }; + + private void validateMlFeatureSources(List expected, List actual) + throws HttpResponseException { + if (expected == null && actual == null) { + return; + } + + for (int i = 0; i < expected.size(); i++) { + assertListProperty(expected.get(i).getFeatureSources(), actual.get(i).getFeatureSources(), assertMlFeatureSource); + } + + } + + @Override + public void validateCreatedEntity(MlModel createdEntity, Object request, Map authHeaders) + throws HttpResponseException { + CreateMlModel createRequest = (CreateMlModel) request; + validateCommonEntityFields(getEntityInterface(createdEntity), createRequest.getDescription(), + TestUtils.getPrincipal(authHeaders), createRequest.getOwner()); + + // Entity specific validations + assertEquals(createRequest.getAlgorithm(), createdEntity.getAlgorithm()); + assertEquals(createRequest.getDashboard(), createdEntity.getDashboard()); + assertListProperty(createRequest.getMlFeatures(), createdEntity.getMlFeatures(), assertMlFeature); + assertListProperty(createRequest.getMlHyperParameters(), createdEntity.getMlHyperParameters(), assertMlHyperParam); + + // assertListProperty on MlFeatures already validates size, so we can directly iterate on sources + validateMlFeatureSources(createRequest.getMlFeatures(), createdEntity.getMlFeatures()); + + TestUtils.validateTags(createRequest.getTags(), createdEntity.getTags()); + TestUtils.validateEntityReference(createdEntity.getFollowers()); + } + + @Override + public void assertFieldChange(String fieldName, Object expected, Object actual) throws IOException { + if (expected == actual) { + return; + } + if (fieldName.contains("mlFeatures") && !fieldName.endsWith("tags") && !fieldName.endsWith("description")) { + List expectedFeatures = (List) expected; + List actualFeatures = JsonUtils.readObjects(actual.toString(), MlFeature.class); + assertEquals(expectedFeatures, actualFeatures); + } else if (fieldName.contains("mlHyperParameters") && !fieldName.endsWith("tags") + && !fieldName.endsWith("description")) { + List expectedConstraints = (List) expected; + List actualConstraints = JsonUtils.readObjects(actual.toString(), MlHyperParameter.class); + assertEquals(expectedConstraints, actualConstraints); + } else if (fieldName.endsWith("algorithm")) { + String expectedAlgorithm = (String) expected; + String actualAlgorithm = actual.toString(); + assertEquals(expectedAlgorithm, actualAlgorithm); + } else if (fieldName.endsWith("dashboard")) { + EntityReference expectedDashboard = (EntityReference) expected; + EntityReference actualDashboard = JsonUtils.readValue(actual.toString(), EntityReference.class); + assertEquals(expectedDashboard, actualDashboard); + } else { + assertCommonFieldChange(fieldName, expected, actual); + } + } + }