[issue-1318] - Refactor MlModel tests and improve logic (#1420)

* Prepare MlModel test utils

* Refactor tests

* Properly set FQN for MlModels

* Fix FQN and validate inner properties

* Change MlModel naming

* Inherit variables from parent test

* Simplify delete

* Update tests interface

* Add order

* Add missing version endpoints

* Add security context

* Update mlmodel name

* Update mlmodel methods

* Update mlmodel tests

* Rename MlModel

* Prepare dashboard service in setup

* Fix MlModel test setup

* MlModel name to minus

* Workaround for issue-1415

* Workaround for issue-1415

* Move superset ref back to MlModel

* Add missing href in Dashboard entity reference

* Handle update dashboards

* Add full dashboard props in validate

* Add update tests

* Reformat

* Reformat

* Reformat

* Reformat

* Reformat
This commit is contained in:
Pere Miquel Brull 2021-11-27 21:31:55 +01:00 committed by GitHub
parent edcbb04e3a
commit 76f40e6f37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 375 additions and 434 deletions

View File

@ -60,7 +60,7 @@ public final class Entity {
public static final String CHART = "chart"; public static final String CHART = "chart";
public static final String REPORT = "report"; public static final String REPORT = "report";
public static final String TOPIC = "topic"; public static final String TOPIC = "topic";
public static final String MLMODEL = "mlModel"; public static final String MLMODEL = "mlmodel";
public static final String DBTMODEL = "dbtmodel"; public static final String DBTMODEL = "dbtmodel";
public static final String BOTS = "bots"; public static final String BOTS = "bots";
public static final String LOCATION = "location"; public static final String LOCATION = "location";

View File

@ -287,7 +287,7 @@ public class DashboardRepository extends EntityRepository<Dashboard> {
@Override @Override
public EntityReference getEntityReference() { public EntityReference getEntityReference() {
return new EntityReference().withId(getId()).withName(getFullyQualifiedName()).withDescription(getDescription()) return new EntityReference().withId(getId()).withName(getFullyQualifiedName()).withDescription(getDescription())
.withDisplayName(getDisplayName()).withType(Entity.DASHBOARD); .withDisplayName(getDisplayName()).withType(Entity.DASHBOARD).withHref(getHref());
} }
@Override @Override

View File

@ -20,10 +20,11 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import org.jdbi.v3.sqlobject.transaction.Transaction; import org.jdbi.v3.sqlobject.transaction.Transaction;
import org.openmetadata.catalog.Entity; import org.openmetadata.catalog.Entity;
import org.openmetadata.catalog.entity.data.MlModel; import org.openmetadata.catalog.entity.data.MlModel;
import org.openmetadata.catalog.exception.EntityNotFoundException;
import org.openmetadata.catalog.resources.mlmodels.MlModelResource; import org.openmetadata.catalog.resources.mlmodels.MlModelResource;
import org.openmetadata.catalog.type.ChangeDescription; import org.openmetadata.catalog.type.ChangeDescription;
import org.openmetadata.catalog.type.EntityReference; import org.openmetadata.catalog.type.EntityReference;
import org.openmetadata.catalog.type.MlFeature;
import org.openmetadata.catalog.type.MlFeatureSource;
import org.openmetadata.catalog.type.TagLabel; import org.openmetadata.catalog.type.TagLabel;
import org.openmetadata.catalog.util.EntityInterface; import org.openmetadata.catalog.util.EntityInterface;
import org.openmetadata.catalog.util.EntityUtil; import org.openmetadata.catalog.util.EntityUtil;
@ -39,14 +40,12 @@ import java.util.Date;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityNotFound;
public class MlModelRepository extends EntityRepository<MlModel> { public class MlModelRepository extends EntityRepository<MlModel> {
private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class); private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class);
private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST, private static final Fields MODEL_UPDATE_FIELDS = new Fields(MlModelResource.FIELD_LIST,
"owner,dashboard,mlHyperParameters,mlFeatures,tags"); "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags");
private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST, private static final Fields MODEL_PATCH_FIELDS = new Fields(MlModelResource.FIELD_LIST,
"owner,dashboard,mlHyperParameters,mlFeatures,tags"); "owner,algorithm,dashboard,mlHyperParameters,mlFeatures,tags");
private final CollectionDAO dao; private final CollectionDAO dao;
public MlModelRepository(CollectionDAO dao) { public MlModelRepository(CollectionDAO dao) {
@ -62,12 +61,7 @@ public class MlModelRepository extends EntityRepository<MlModel> {
@Transaction @Transaction
public void delete(UUID id) { public void delete(UUID id) {
if (dao.relationshipDAO().findToCount(id.toString(), Relationship.CONTAINS.ordinal(), Entity.MLMODEL) > 0) { dao.mlModelDAO().delete(id);
throw new IllegalArgumentException("Model is not empty");
}
if (dao.mlModelDAO().delete(id) <= 0) {
throw EntityNotFoundException.byMessage(entityNotFound(Entity.MLMODEL, id));
}
dao.relationshipDAO().deleteAll(id.toString()); dao.relationshipDAO().deleteAll(id.toString());
} }
@ -92,7 +86,9 @@ public class MlModelRepository extends EntityRepository<MlModel> {
@Override @Override
public void restorePatchAttributes(MlModel original, MlModel updated) throws IOException, ParseException { public void restorePatchAttributes(MlModel original, MlModel updated) throws IOException, ParseException {
// Patch can't make changes to following fields. Ignore the changes
updated.withFullyQualifiedName(original.getFullyQualifiedName())
.withName(original.getName()).withId(original.getId());
} }
@Override @Override
@ -104,16 +100,40 @@ public class MlModelRepository extends EntityRepository<MlModel> {
return dao.tagDAO().getTags(fqn); return dao.tagDAO().getTags(fqn);
} }
private void setMlFeatureSourcesFQN(String parentFQN, List<MlFeatureSource> mlSources) {
mlSources.forEach(s -> {
String sourceFqn = parentFQN + "." + s.getName();
s.setFullyQualifiedName(sourceFqn);
});
}
private void setMlFeatureFQN(String parentFQN, List<MlFeature> mlFeatures) {
mlFeatures.forEach(f -> {
String featureFqn = parentFQN + "." + f.getName();
f.setFullyQualifiedName(featureFqn);
if (f.getFeatureSources() != null) {
setMlFeatureSourcesFQN(featureFqn, f.getFeatureSources());
}
});
}
@Override @Override
public void validate(MlModel model) throws IOException { public void validate(MlModel mlModel) throws IOException {
model.setFullyQualifiedName(getFQN(model)); mlModel.setFullyQualifiedName(getFQN(mlModel));
EntityUtil.populateOwner(dao.userDAO(), dao.teamDAO(), model.getOwner()); // Validate owner setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures());
if (model.getDashboard() != null) {
UUID dashboardId = model.getDashboard().getId(); // Check if owner is valid and set the relationship
model.setDashboard(dao.dashboardDAO().findEntityReferenceById(dashboardId)); mlModel.setOwner(EntityUtil.populateOwner(dao.userDAO(), dao.teamDAO(), mlModel.getOwner()));
setDashboard(mlModel, mlModel.getDashboard());
if (mlModel.getDashboard() != null) {
// Add relationship from MlModel to Dashboard
String dashboardId = mlModel.getDashboard().getId().toString();
dao.relationshipDAO().insert(dashboardId, mlModel.getId().toString(), Entity.MLMODEL, Entity.DASHBOARD,
Relationship.USES.ordinal());
} }
model.setTags(EntityUtil.addDerivedTags(dao.tagDAO(), model.getTags())); mlModel.setTags(EntityUtil.addDerivedTags(dao.tagDAO(), mlModel.getTags()));
} }
@Override @Override
@ -138,8 +158,18 @@ public class MlModelRepository extends EntityRepository<MlModel> {
@Override @Override
public void storeRelationships(MlModel mlModel) throws IOException { public void storeRelationships(MlModel mlModel) throws IOException {
setOwner(mlModel, mlModel.getOwner());
EntityUtil.setOwner(dao.relationshipDAO(), mlModel.getId(), Entity.MLMODEL, mlModel.getOwner());
setDashboard(mlModel, mlModel.getDashboard()); setDashboard(mlModel, mlModel.getDashboard());
if (mlModel.getDashboard() != null) {
// Add relationship from MlModel to Dashboard
String dashboardId = mlModel.getDashboard().getId().toString();
dao.relationshipDAO().insert(dashboardId, mlModel.getId().toString(), Entity.MLMODEL, Entity.DASHBOARD,
Relationship.USES.ordinal());
}
applyTags(mlModel); applyTags(mlModel);
} }
@ -153,11 +183,6 @@ public class MlModelRepository extends EntityRepository<MlModel> {
dao.userDAO(), dao.teamDAO()); dao.userDAO(), dao.teamDAO());
} }
public void setOwner(MlModel mlModel, EntityReference owner) {
EntityUtil.setOwner(dao.relationshipDAO(), mlModel.getId(), Entity.MLMODEL, owner);
mlModel.setOwner(owner);
}
private EntityReference getDashboard(MlModel mlModel) throws IOException { private EntityReference getDashboard(MlModel mlModel) throws IOException {
if (mlModel != null) { if (mlModel != null) {
List<EntityReference> ids = dao.relationshipDAO().findTo(mlModel.getId().toString(), Relationship.USES.ordinal()); List<EntityReference> ids = dao.relationshipDAO().findTo(mlModel.getId().toString(), Relationship.USES.ordinal());
@ -193,10 +218,10 @@ public class MlModelRepository extends EntityRepository<MlModel> {
return model == null ? null : EntityUtil.getFollowers(model.getId(), dao.relationshipDAO(), dao.userDAO()); return model == null ? null : EntityUtil.getFollowers(model.getId(), dao.relationshipDAO(), dao.userDAO());
} }
static class MlModelEntityInterface implements EntityInterface<MlModel> { public static class MlModelEntityInterface implements EntityInterface<MlModel> {
private final MlModel entity; private final MlModel entity;
MlModelEntityInterface(MlModel entity) { public MlModelEntityInterface(MlModel entity) {
this.entity = entity; this.entity = entity;
} }
@ -304,10 +329,12 @@ public class MlModelRepository extends EntityRepository<MlModel> {
@Override @Override
public void entitySpecificUpdate() throws IOException { public void entitySpecificUpdate() throws IOException {
updateAlgorithm(original.getEntity(), updated.getEntity()); MlModel origMlModel = original.getEntity();
updateDashboard(original.getEntity(), updated.getEntity()); MlModel updatedMlModel = updated.getEntity();
updateMlFeatures(original.getEntity(), updated.getEntity()); updateAlgorithm(origMlModel, updatedMlModel);
updateMlHyperParameters(original.getEntity(), updated.getEntity()); updateDashboard(origMlModel, updatedMlModel);
updateMlFeatures(origMlModel, updatedMlModel);
updateMlHyperParameters(origMlModel, updatedMlModel);
} }
private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { private void updateAlgorithm(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
@ -323,15 +350,21 @@ public class MlModelRepository extends EntityRepository<MlModel> {
} }
private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException { private void updateDashboard(MlModel origModel, MlModel updatedModel) throws JsonProcessingException {
// Remove existing dashboards String modelId = updatedModel.getId().toString();
removeDashboard(origModel);
// Remove the dashboard associated with the model, if any
if (origModel.getDashboard() != null) {
dao.relationshipDAO().deleteFrom(modelId, Relationship.USES.ordinal(), "dashboard");
}
// Add relationship from model to dashboard
EntityReference updatedDashboard = updatedModel.getDashboard();
if (updatedDashboard != null) {
dao.relationshipDAO().insert(modelId, updatedDashboard.getId().toString(),
Entity.MLMODEL, Entity.DASHBOARD, Relationship.USES.ordinal());
}
recordChange("dashboard", origModel.getDashboard(), updatedDashboard, true);
EntityReference origOwner = origModel.getDashboard();
EntityReference updatedOwner = updatedModel.getDashboard();
if (recordChange("owner", origOwner == null ? null : origOwner.getId(),
updatedOwner == null ? null : updatedOwner.getId())) {
setDashboard(updatedModel, updatedModel.getDashboard());
}
} }
} }
} }

View File

@ -54,7 +54,7 @@ public enum Relationship {
// {Dashboard|Pipeline|Query} --- uses ---> Table // {Dashboard|Pipeline|Query} --- uses ---> Table
// {User} --- uses ---> {Table|Dashboard|Query} // {User} --- uses ---> {Table|Dashboard|Query}
// {Model} --- uses ---> {Dashboard} // {MlModel} --- uses ---> {Dashboard}
USES("uses"), USES("uses"),
// {User|Team|Org} --- owns ---> {Table|Dashboard|Query} // {User|Team|Org} --- owns ---> {Table|Dashboard|Query}

View File

@ -34,6 +34,7 @@ import org.openmetadata.catalog.jdbi3.MlModelRepository;
import org.openmetadata.catalog.resources.Collection; import org.openmetadata.catalog.resources.Collection;
import org.openmetadata.catalog.security.CatalogAuthorizer; import org.openmetadata.catalog.security.CatalogAuthorizer;
import org.openmetadata.catalog.security.SecurityUtil; import org.openmetadata.catalog.security.SecurityUtil;
import org.openmetadata.catalog.type.EntityHistory;
import org.openmetadata.catalog.util.EntityUtil.Fields; import org.openmetadata.catalog.util.EntityUtil.Fields;
import org.openmetadata.catalog.util.RestUtil; import org.openmetadata.catalog.util.RestUtil;
import org.openmetadata.catalog.util.RestUtil.PatchResponse; import org.openmetadata.catalog.util.RestUtil.PatchResponse;
@ -82,15 +83,10 @@ public class MlModelResource {
private final MlModelRepository dao; private final MlModelRepository dao;
private final CatalogAuthorizer authorizer; private final CatalogAuthorizer authorizer;
public static List<MlModel> addHref(UriInfo uriInfo, List<MlModel> models) {
Optional.ofNullable(models).orElse(Collections.emptyList()).forEach(i -> addHref(uriInfo, i));
return models;
}
public static MlModel addHref(UriInfo uriInfo, MlModel mlmodel) { public static MlModel addHref(UriInfo uriInfo, MlModel mlmodel) {
mlmodel.setHref(RestUtil.getHref(uriInfo, COLLECTION_PATH, mlmodel.getId())); mlmodel.setHref(RestUtil.getHref(uriInfo, COLLECTION_PATH, mlmodel.getId()));
Entity.withHref(uriInfo, mlmodel.getOwner()); Entity.withHref(uriInfo, mlmodel.getOwner());
Entity.withHref(uriInfo, mlmodel.getDashboard()); // Dashboard HREF Entity.withHref(uriInfo, mlmodel.getDashboard());
Entity.withHref(uriInfo, mlmodel.getFollowers()); Entity.withHref(uriInfo, mlmodel.getFollowers());
return mlmodel; return mlmodel;
} }
@ -152,11 +148,11 @@ public class MlModelResource {
ResultList<MlModel> mlmodels; ResultList<MlModel> mlmodels;
if (before != null) { // Reverse paging if (before != null) { // Reverse paging
mlmodels = dao.listBefore(uriInfo, fields, null, limitParam, before); // Ask for one extra entry mlmodels = dao.listBefore(uriInfo, fields, null, limitParam, before);
} else { // Forward paging or first page } else { // Forward paging or first page
mlmodels = dao.listAfter(uriInfo, fields, null, limitParam, after); mlmodels = dao.listAfter(uriInfo, fields, null, limitParam, after);
} }
addHref(uriInfo, mlmodels.getData()); mlmodels.getData().forEach(m -> addHref(uriInfo, m));
return mlmodels; return mlmodels;
} }
@ -196,8 +192,7 @@ public class MlModelResource {
schema = @Schema(type = "string", example = FIELDS)) schema = @Schema(type = "string", example = FIELDS))
@QueryParam("fields") String fieldsParam) throws IOException, ParseException { @QueryParam("fields") String fieldsParam) throws IOException, ParseException {
Fields fields = new Fields(FIELD_LIST, fieldsParam); Fields fields = new Fields(FIELD_LIST, fieldsParam);
MlModel mlmodel = dao.getByName(uriInfo, fqn, fields); return addHref(uriInfo, dao.getByName(uriInfo, fqn, fields));
return addHref(uriInfo, mlmodel);
} }
@ -205,12 +200,13 @@ public class MlModelResource {
@Operation(summary = "Create an ML Model", tags = "mlModels", @Operation(summary = "Create an ML Model", tags = "mlModels",
description = "Create a new ML Model.", description = "Create a new ML Model.",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "The model", @ApiResponse(responseCode = "200", description = "ML Model",
content = @Content(mediaType = "application/json", content = @Content(mediaType = "application/json",
schema = @Schema(implementation = CreateMlModel.class))), schema = @Schema(implementation = CreateMlModel.class))),
@ApiResponse(responseCode = "400", description = "Bad request") @ApiResponse(responseCode = "400", description = "Bad request")
}) })
public Response create(@Context UriInfo uriInfo, @Context SecurityContext securityContext, public Response create(@Context UriInfo uriInfo,
@Context SecurityContext securityContext,
@Valid CreateMlModel create) throws IOException, ParseException { @Valid CreateMlModel create) throws IOException, ParseException {
SecurityUtil.checkAdminOrBotRole(authorizer, securityContext); SecurityUtil.checkAdminOrBotRole(authorizer, securityContext);
MlModel mlModel = getMlModel(securityContext, create); MlModel mlModel = getMlModel(securityContext, create);
@ -225,8 +221,9 @@ public class MlModelResource {
externalDocs = @ExternalDocumentation(description = "JsonPatch RFC", externalDocs = @ExternalDocumentation(description = "JsonPatch RFC",
url = "https://tools.ietf.org/html/rfc6902")) url = "https://tools.ietf.org/html/rfc6902"))
@Consumes(MediaType.APPLICATION_JSON_PATCH_JSON) @Consumes(MediaType.APPLICATION_JSON_PATCH_JSON)
public Response updateDescription(@Context UriInfo uriInfo, public Response patch(@Context UriInfo uriInfo,
@Context SecurityContext securityContext, @Context SecurityContext securityContext,
@Parameter(description = "Id of the ML Model", schema = @Schema(type = "string"))
@PathParam("id") String id, @PathParam("id") String id,
@RequestBody(description = "JsonPatch with array of operations", @RequestBody(description = "JsonPatch with array of operations",
content = @Content(mediaType = MediaType.APPLICATION_JSON_PATCH_JSON, content = @Content(mediaType = MediaType.APPLICATION_JSON_PATCH_JSON,
@ -237,10 +234,9 @@ public class MlModelResource {
JsonPatch patch) throws IOException, ParseException { JsonPatch patch) throws IOException, ParseException {
Fields fields = new Fields(FIELD_LIST, FIELDS); Fields fields = new Fields(FIELD_LIST, FIELDS);
MlModel mlModel = dao.get(uriInfo, id, fields); MlModel mlModel = dao.get(uriInfo, id, fields);
SecurityUtil.checkAdminRoleOrPermissions(authorizer, securityContext, SecurityUtil.checkAdminRoleOrPermissions(authorizer, securityContext, dao.getOwnerReference(mlModel));
dao.getOwnerReference(mlModel)); PatchResponse<MlModel> response = dao.patch(uriInfo, UUID.fromString(id),
PatchResponse<MlModel> response = securityContext.getUserPrincipal().getName(), patch);
dao.patch(uriInfo, UUID.fromString(id), securityContext.getUserPrincipal().getName(), patch);
addHref(uriInfo, response.getEntity()); addHref(uriInfo, response.getEntity());
return response.toResponse(); return response.toResponse();
} }
@ -258,6 +254,7 @@ public class MlModelResource {
@Context SecurityContext securityContext, @Context SecurityContext securityContext,
@Valid CreateMlModel create) throws IOException, ParseException { @Valid CreateMlModel create) throws IOException, ParseException {
MlModel mlModel = getMlModel(securityContext, create); MlModel mlModel = getMlModel(securityContext, create);
SecurityUtil.checkAdminRoleOrPermissions(authorizer, securityContext, dao.getOwnerReference(mlModel));
PutResponse<MlModel> response = dao.createOrUpdate(uriInfo, mlModel); PutResponse<MlModel> response = dao.createOrUpdate(uriInfo, mlModel);
addHref(uriInfo, response.getEntity()); addHref(uriInfo, response.getEntity());
return response.toResponse(); return response.toResponse();
@ -298,6 +295,43 @@ public class MlModelResource {
UUID.fromString(userId)).toResponse(); UUID.fromString(userId)).toResponse();
} }
@GET
@Path("/{id}/versions")
@Operation(summary = "List Ml Model versions", tags = "mlModels",
description = "Get a list of all the versions of an Ml Model identified by `id`",
responses = {@ApiResponse(responseCode = "200", description = "List of Ml Model versions",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = EntityHistory.class)))
})
public EntityHistory listVersions(@Context UriInfo uriInfo,
@Context SecurityContext securityContext,
@Parameter(description = "ML Model Id", schema = @Schema(type = "string"))
@PathParam("id") String id)
throws IOException, ParseException, GeneralSecurityException {
return dao.listVersions(id);
}
@GET
@Path("/{id}/versions/{version}")
@Operation(summary = "Get a version of the ML Model", tags = "mlModels",
description = "Get a version of the ML Model by given `id`",
responses = {
@ApiResponse(responseCode = "200", description = "MlModel",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = MlModel.class))),
@ApiResponse(responseCode = "404", description = "ML Model for instance {id} and version {version} is " +
"not found")
})
public MlModel getVersion(@Context UriInfo uriInfo,
@Context SecurityContext securityContext,
@Parameter(description = "ML Model Id", schema = @Schema(type = "string"))
@PathParam("id") String id,
@Parameter(description = "ML Model version number in the form `major`.`minor`",
schema = @Schema(type = "string", example = "0.1 or 1.1"))
@PathParam("version") String version) throws IOException, ParseException {
return dao.getVersion(id, version);
}
@DELETE @DELETE
@Path("/{id}") @Path("/{id}")
@Operation(summary = "Delete an ML Model", tags = "mlModels", @Operation(summary = "Delete an ML Model", tags = "mlModels",
@ -306,7 +340,11 @@ public class MlModelResource {
@ApiResponse(responseCode = "200", description = "OK"), @ApiResponse(responseCode = "200", description = "OK"),
@ApiResponse(responseCode = "404", description = "model for instance {id} is not found") @ApiResponse(responseCode = "404", description = "model for instance {id} is not found")
}) })
public Response delete(@Context UriInfo uriInfo, @PathParam("id") String id) { public Response delete(@Context UriInfo uriInfo,
@Context SecurityContext securityContext,
@Parameter(description = "Id of the ML Model", schema = @Schema(type = "string"))
@PathParam("id") String id) {
SecurityUtil.checkAdminOrBotRole(authorizer, securityContext);
dao.delete(UUID.fromString(id)); dao.delete(UUID.fromString(id));
return Response.ok().build(); return Response.ok().build();
} }

View File

@ -61,6 +61,7 @@ import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.function.BiConsumer;
import static javax.ws.rs.core.Response.Status.BAD_REQUEST; import static javax.ws.rs.core.Response.Status.BAD_REQUEST;
import static javax.ws.rs.core.Response.Status.CREATED; import static javax.ws.rs.core.Response.Status.CREATED;
@ -645,6 +646,7 @@ public abstract class EntityResourceTest<T> extends CatalogApplicationTest {
protected final T createAndCheckEntity(Object create, Map<String, String> authHeaders) throws IOException { protected final T createAndCheckEntity(Object create, Map<String, String> authHeaders) throws IOException {
// Validate an entity that is created has all the information set in create request // Validate an entity that is created has all the information set in create request
String updatedBy = TestUtils.getPrincipal(authHeaders); String updatedBy = TestUtils.getPrincipal(authHeaders);
// aqui si que tenim HREF
T entity = createEntity(create, authHeaders); T entity = createEntity(create, authHeaders);
EntityInterface<T> entityInterface = getEntityInterface(entity); EntityInterface<T> entityInterface = getEntityInterface(entity);
@ -1029,4 +1031,21 @@ public abstract class EntityResourceTest<T> extends CatalogApplicationTest {
list.getData().forEach(e -> LOG.info("{} {}", entityClass, getEntityInterface(e).getFullyQualifiedName())); list.getData().forEach(e -> LOG.info("{} {}", entityClass, getEntityInterface(e).getFullyQualifiedName()));
LOG.info("before {} after {} ", list.getPaging().getBefore(), list.getPaging().getAfter()); LOG.info("before {} after {} ", list.getPaging().getBefore(), list.getPaging().getAfter());
} }
/**
* Given a list of properties of an Entity (e.g., List<Column> or List<MlFeature> and
* a function that validate the elements of T, validate lists
*/
public <P> void assertListProperty(List<P> expected, List<P> actual, BiConsumer<P, P> validate)
throws HttpResponseException {
if (expected == null && actual == null) {
return;
}
assertNotNull(expected);
assertEquals(expected.size(), actual.size());
for (int i = 0; i < expected.size(); i++) {
validate.accept(expected.get(i), actual.get(i));
}
}
} }

View File

@ -18,72 +18,69 @@ package org.openmetadata.catalog.resources.mlmodels;
import org.apache.http.client.HttpResponseException; import org.apache.http.client.HttpResponseException;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.TestInfo;
import org.openmetadata.catalog.CatalogApplicationTest; import org.junit.jupiter.api.TestMethodOrder;
import org.openmetadata.catalog.Entity; import org.openmetadata.catalog.Entity;
import org.openmetadata.catalog.api.data.CreateMlModel; import org.openmetadata.catalog.api.data.CreateMlModel;
import org.openmetadata.catalog.api.services.CreateDashboardService; import org.openmetadata.catalog.api.services.CreateDashboardService;
import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType; import org.openmetadata.catalog.api.services.CreateDashboardService.DashboardServiceType;
import org.openmetadata.catalog.entity.data.Dashboard; import org.openmetadata.catalog.entity.data.Dashboard;
import org.openmetadata.catalog.entity.data.MlModel; import org.openmetadata.catalog.entity.data.MlModel;
import org.openmetadata.catalog.exception.CatalogExceptionMessage;
import org.openmetadata.catalog.jdbi3.MlModelRepository;
import org.openmetadata.catalog.resources.EntityResourceTest;
import org.openmetadata.catalog.type.ChangeDescription;
import org.openmetadata.catalog.type.FeatureSourceDataType; import org.openmetadata.catalog.type.FeatureSourceDataType;
import org.openmetadata.catalog.type.FieldChange;
import org.openmetadata.catalog.type.MlFeature; import org.openmetadata.catalog.type.MlFeature;
import org.openmetadata.catalog.type.MlFeatureDataType; import org.openmetadata.catalog.type.MlFeatureDataType;
import org.openmetadata.catalog.type.MlFeatureSource; import org.openmetadata.catalog.type.MlFeatureSource;
import org.openmetadata.catalog.type.MlHyperParameter; import org.openmetadata.catalog.type.MlHyperParameter;
import org.openmetadata.catalog.entity.services.DashboardService; import org.openmetadata.catalog.entity.services.DashboardService;
import org.openmetadata.catalog.entity.teams.Team;
import org.openmetadata.catalog.entity.teams.User;
import org.openmetadata.catalog.exception.CatalogExceptionMessage;
import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface; import org.openmetadata.catalog.jdbi3.DashboardRepository.DashboardEntityInterface;
import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface; import org.openmetadata.catalog.jdbi3.DashboardServiceRepository.DashboardServiceEntityInterface;
import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest; import org.openmetadata.catalog.resources.dashboards.DashboardResourceTest;
import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList; import org.openmetadata.catalog.resources.mlmodels.MlModelResource.MlModelList;
import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest; import org.openmetadata.catalog.resources.services.DashboardServiceResourceTest;
import org.openmetadata.catalog.resources.teams.TeamResourceTest;
import org.openmetadata.catalog.resources.teams.UserResourceTest;
import org.openmetadata.catalog.type.EntityReference; import org.openmetadata.catalog.type.EntityReference;
import org.openmetadata.catalog.type.TagLabel; import org.openmetadata.catalog.util.EntityInterface;
import org.openmetadata.catalog.util.TestUtils; import org.openmetadata.catalog.util.TestUtils;
import org.openmetadata.catalog.util.TestUtils.UpdateType; import org.openmetadata.catalog.util.JsonUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.ws.rs.client.WebTarget; import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.function.BiConsumer;
import static javax.ws.rs.core.Response.Status.BAD_REQUEST; import static javax.ws.rs.core.Response.Status.BAD_REQUEST;
import static javax.ws.rs.core.Response.Status.CONFLICT; import static javax.ws.rs.core.Response.Status.CONFLICT;
import static javax.ws.rs.core.Response.Status.CREATED;
import static javax.ws.rs.core.Response.Status.FORBIDDEN; import static javax.ws.rs.core.Response.Status.FORBIDDEN;
import static javax.ws.rs.core.Response.Status.NOT_FOUND; import static javax.ws.rs.core.Response.Status.NOT_FOUND;
import static javax.ws.rs.core.Response.Status.OK;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.openmetadata.catalog.exception.CatalogExceptionMessage.ENTITY_ALREADY_EXISTS;
import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityNotFound; import static org.openmetadata.catalog.exception.CatalogExceptionMessage.entityNotFound;
import static org.openmetadata.catalog.util.TestUtils.UpdateType.MINOR_UPDATE; import static org.openmetadata.catalog.util.TestUtils.UpdateType.MINOR_UPDATE;
import static org.openmetadata.catalog.util.TestUtils.UpdateType.NO_CHANGE; import static org.openmetadata.catalog.util.TestUtils.UpdateType.NO_CHANGE;
import static org.openmetadata.catalog.util.TestUtils.adminAuthHeaders; import static org.openmetadata.catalog.util.TestUtils.adminAuthHeaders;
import static org.openmetadata.catalog.util.TestUtils.assertEntityPagination;
import static org.openmetadata.catalog.util.TestUtils.assertResponse; import static org.openmetadata.catalog.util.TestUtils.assertResponse;
import static org.openmetadata.catalog.util.TestUtils.authHeaders; import static org.openmetadata.catalog.util.TestUtils.authHeaders;
public class MlModelResourceTest extends CatalogApplicationTest { @TestMethodOrder(MethodOrderer.OrderAnnotation.class)
private static final Logger LOG = LoggerFactory.getLogger(MlModelResourceTest.class); public class MlModelResourceTest extends EntityResourceTest<MlModel> {
public static User USER1;
public static EntityReference USER_OWNER1;
public static Team TEAM1;
public static EntityReference TEAM_OWNER1;
public static String ALGORITHM = "regression";
public static EntityReference SUPERSET_REFERENCE; public static EntityReference SUPERSET_REFERENCE;
public static String ALGORITHM = "regression";
public static Dashboard DASHBOARD; public static Dashboard DASHBOARD;
public static EntityReference DASHBOARD_REFERENCE; public static EntityReference DASHBOARD_REFERENCE;
public static List<MlFeature> ML_FEATURES = Arrays.asList( public static List<MlFeature> ML_FEATURES = Arrays.asList(
@ -95,7 +92,6 @@ public class MlModelResourceTest extends CatalogApplicationTest {
new MlFeatureSource() new MlFeatureSource()
.withName("age") .withName("age")
.withDataType(FeatureSourceDataType.INTEGER) .withDataType(FeatureSourceDataType.INTEGER)
.withFullyQualifiedName("my_service.my_db.my_table.age")
) )
), ),
new MlFeature() new MlFeature()
@ -105,12 +101,10 @@ public class MlModelResourceTest extends CatalogApplicationTest {
Arrays.asList( Arrays.asList(
new MlFeatureSource() new MlFeatureSource()
.withName("age") .withName("age")
.withDataType(FeatureSourceDataType.INTEGER) .withDataType(FeatureSourceDataType.INTEGER),
.withFullyQualifiedName("my_service.my_db.my_table.age"),
new MlFeatureSource() new MlFeatureSource()
.withName("education") .withName("education")
.withDataType(FeatureSourceDataType.STRING) .withDataType(FeatureSourceDataType.STRING)
.withFullyQualifiedName("my_api.education")
) )
) )
.withFeatureAlgorithm("PCA") .withFeatureAlgorithm("PCA")
@ -120,14 +114,16 @@ public class MlModelResourceTest extends CatalogApplicationTest {
new MlHyperParameter().withName("random").withValue("hello") new MlHyperParameter().withName("random").withValue("hello")
); );
public MlModelResourceTest() {
super(Entity.MLMODEL, MlModel.class, MlModelList.class, "mlmodels", MlModelResource.FIELDS, true,
true, true);
}
@BeforeAll @BeforeAll
public static void setup(TestInfo test) throws HttpResponseException { public static void setup(TestInfo test) throws IOException, URISyntaxException {
USER1 = UserResourceTest.createUser(UserResourceTest.create(test), authHeaders("test@open-metadata.org"));
USER_OWNER1 = new EntityReference().withId(USER1.getId()).withType("user");
TEAM1 = TeamResourceTest.createTeam(TeamResourceTest.create(test), adminAuthHeaders()); EntityResourceTest.setup(test);
TEAM_OWNER1 = new EntityReference().withId(TEAM1.getId()).withType("team");
CreateDashboardService createService = new CreateDashboardService().withName("superset") CreateDashboardService createService = new CreateDashboardService().withName("superset")
.withServiceType(DashboardServiceType.Superset).withDashboardUrl(TestUtils.DASHBOARD_URL); .withServiceType(DashboardServiceType.Superset).withDashboardUrl(TestUtils.DASHBOARD_URL);
@ -141,358 +137,159 @@ public class MlModelResourceTest extends CatalogApplicationTest {
DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference(); DASHBOARD_REFERENCE = new DashboardEntityInterface(DASHBOARD).getEntityReference();
} }
public static MlModel createMlModel(CreateMlModel create,
Map<String, String> authHeaders) throws HttpResponseException {
return new MlModelResourceTest().createEntity(create, authHeaders);
}
@Test @Test
public void post_modelWithLongName_400_badRequest(TestInfo test) { public void post_MlModelWithLongName_400_badRequest(TestInfo test) {
// Create model with mandatory name field empty // Create model with mandatory name field empty
CreateMlModel create = create(test).withName(TestUtils.LONG_ENTITY_NAME); CreateMlModel create = create(test).withName(TestUtils.LONG_ENTITY_NAME);
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> assertResponse(() -> createMlModel(create, adminAuthHeaders()), BAD_REQUEST,
createModel(create, adminAuthHeaders())); "[name size must be between 1 and 64]");
assertResponse(exception, BAD_REQUEST, "[name size must be between 1 and 64]");
} }
@Test @Test
public void post_ModelWithoutName_400_badRequest(TestInfo test) { public void post_MlModelWithoutName_400_badRequest(TestInfo test) {
// Create Model with mandatory name field empty // Create Model with mandatory name field empty
CreateMlModel create = create(test).withName(""); CreateMlModel create = create(test).withName("");
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> assertResponse(() -> createMlModel(create, adminAuthHeaders()), BAD_REQUEST,
createModel(create, adminAuthHeaders())); "[name size must be between 1 and 64]");
assertResponse(exception, BAD_REQUEST, "[name size must be between 1 and 64]");
} }
@Test @Test
public void post_ModelAlreadyExists_409_conflict(TestInfo test) throws HttpResponseException { public void post_MlModelAlreadyExists_409_conflict(TestInfo test) throws HttpResponseException {
CreateMlModel create = create(test); CreateMlModel create = create(test);
createModel(create, adminAuthHeaders()); createMlModel(create, adminAuthHeaders());
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> assertResponse(() -> createMlModel(create, adminAuthHeaders()), CONFLICT, ENTITY_ALREADY_EXISTS);
createModel(create, adminAuthHeaders()));
assertResponse(exception, CONFLICT, CatalogExceptionMessage.ENTITY_ALREADY_EXISTS);
} }
@Test @Test
public void post_validModels_as_admin_200_OK(TestInfo test) throws HttpResponseException { public void post_validMlModels_as_admin_200_OK(TestInfo test) throws IOException {
// Create valid model // Create valid model
CreateMlModel create = create(test); CreateMlModel create = create(test);
createAndCheckModel(create, adminAuthHeaders()); createAndCheckEntity(create, adminAuthHeaders());
create.withName(getModelName(test, 1)).withDescription("description"); create.withName(getModelName(test, 1)).withDescription("description");
createAndCheckModel(create, adminAuthHeaders()); createAndCheckEntity(create, adminAuthHeaders());
} }
@Test @Test
public void post_ModelWithUserOwner_200_ok(TestInfo test) throws HttpResponseException { public void post_MlModelWithUserOwner_200_ok(TestInfo test) throws IOException {
createAndCheckModel(create(test).withOwner(USER_OWNER1), adminAuthHeaders()); createAndCheckEntity(create(test).withOwner(USER_OWNER1), adminAuthHeaders());
} }
@Test @Test
public void post_ModelWithTeamOwner_200_ok(TestInfo test) throws HttpResponseException { public void post_MlModelWithTeamOwner_200_ok(TestInfo test) throws IOException {
createAndCheckModel(create(test).withOwner(TEAM_OWNER1).withDisplayName("Model1"), adminAuthHeaders()); createAndCheckEntity(create(test).withOwner(TEAM_OWNER1).withDisplayName("Model1"), adminAuthHeaders());
} }
@Test @Test
public void post_ModelWithDashboard_200_ok(TestInfo test) throws HttpResponseException { public void post_MlModelWithDashboard_200_ok(TestInfo test) throws IOException {
createAndCheckModel(create(test), DASHBOARD_REFERENCE, adminAuthHeaders()); CreateMlModel create = create(test).withDashboard(DASHBOARD_REFERENCE);
createAndCheckEntity(create, adminAuthHeaders());
} }
@Test @Test
public void post_Model_as_non_admin_401(TestInfo test) { public void post_MlModel_as_non_admin_401(TestInfo test) {
CreateMlModel create = create(test); CreateMlModel create = create(test);
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> assertResponse(() -> createMlModel(create, authHeaders("test@open-metadata.org")), FORBIDDEN,
createModel(create, authHeaders("test@open-metadata.org"))); "Principal: CatalogPrincipal{name='test'} is not admin");
assertResponse(exception, FORBIDDEN, "Principal: CatalogPrincipal{name='test'} is not admin");
} }
@Test @Test
public void post_ModelWithInvalidOwnerType_4xx(TestInfo test) { public void post_MlModelWithInvalidOwnerType_4xx(TestInfo test) {
EntityReference owner = new EntityReference().withId(TEAM1.getId()); /* No owner type is set */ EntityReference owner = new EntityReference().withId(TEAM1.getId()); /* No owner type is set */
CreateMlModel create = create(test).withOwner(owner); CreateMlModel create = create(test).withOwner(owner);
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> HttpResponseException exception = assertThrows(HttpResponseException.class, () ->
createModel(create, adminAuthHeaders())); createEntity(create, adminAuthHeaders()));
TestUtils.assertResponseContains(exception, BAD_REQUEST, "type must not be null"); TestUtils.assertResponseContains(exception, BAD_REQUEST, "type must not be null");
} }
@Test @Test
public void post_ModelWithNonExistentOwner_4xx(TestInfo test) { public void post_MlModelWithNonExistentOwner_4xx(TestInfo test) {
EntityReference owner = new EntityReference().withId(TestUtils.NON_EXISTENT_ENTITY).withType("user"); EntityReference owner = new EntityReference().withId(TestUtils.NON_EXISTENT_ENTITY).withType("user");
CreateMlModel create = create(test).withOwner(owner); CreateMlModel create = create(test).withOwner(owner);
HttpResponseException exception = assertThrows(HttpResponseException.class, () ->
createModel(create, adminAuthHeaders())); assertResponse(() -> createMlModel(create, adminAuthHeaders()), NOT_FOUND,
assertResponse(exception, NOT_FOUND, entityNotFound("User", TestUtils.NON_EXISTENT_ENTITY)); entityNotFound("User", TestUtils.NON_EXISTENT_ENTITY));
} }
@Test @Test
public void get_ModelListWithInvalidLimitOffset_4xx() { public void put_MlModelUpdateWithNoChange_200(TestInfo test) throws IOException {
// Limit must be >= 1 and <= 1000,000
HttpResponseException exception = assertThrows(HttpResponseException.class, ()
-> listModels(null, -1, null, null, adminAuthHeaders()));
assertResponse(exception, BAD_REQUEST, "[query param limit must be greater than or equal to 1]");
exception = assertThrows(HttpResponseException.class, ()
-> listModels(null, 0, null, null, adminAuthHeaders()));
assertResponse(exception, BAD_REQUEST, "[query param limit must be greater than or equal to 1]");
exception = assertThrows(HttpResponseException.class, ()
-> listModels(null, 1000001, null, null, adminAuthHeaders()));
assertResponse(exception, BAD_REQUEST, "[query param limit must be less than or equal to 1000000]");
}
@Test
public void get_ModelListWithInvalidPaginationCursors_4xx() {
// Passing both before and after cursors is invalid
HttpResponseException exception = assertThrows(HttpResponseException.class, ()
-> listModels(null, 1, "", "", adminAuthHeaders()));
assertResponse(exception, BAD_REQUEST, "Only one of before or after query parameter allowed");
}
@Test
public void get_ModelListWithValidLimitOffset_4xx(TestInfo test) throws HttpResponseException {
// Create a large number of Models
int maxModels = 40;
for (int i = 0; i < maxModels; i++) {
createModel(create(test, i), adminAuthHeaders());
}
// List all Models
MlModelList allModels = listModels(null, 1000000, null,
null, adminAuthHeaders());
int totalRecords = allModels.getData().size();
printModels(allModels);
// List limit number Models at a time at various offsets and ensure right results are returned
for (int limit = 1; limit < maxModels; limit++) {
String after = null;
String before;
int pageCount = 0;
int indexInAllModels = 0;
MlModelList forwardPage;
MlModelList backwardPage;
do { // For each limit (or page size) - forward scroll till the end
LOG.info("Limit {} forward scrollCount {} afterCursor {}", limit, pageCount, after);
forwardPage = listModels(null, limit, null, after, adminAuthHeaders());
printModels(forwardPage);
after = forwardPage.getPaging().getAfter();
before = forwardPage.getPaging().getBefore();
assertEntityPagination(allModels.getData(), forwardPage, limit, indexInAllModels);
if (pageCount == 0) { // CASE 0 - First page is being returned. There is no before cursor
assertNull(before);
} else {
// Make sure scrolling back based on before cursor returns the correct result
backwardPage = listModels(null, limit, before, null, adminAuthHeaders());
assertEntityPagination(allModels.getData(), backwardPage, limit, (indexInAllModels - limit));
}
indexInAllModels += forwardPage.getData().size();
pageCount++;
} while (after != null);
// We have now reached the last page - test backward scroll till the beginning
pageCount = 0;
indexInAllModels = totalRecords - limit - forwardPage.getData().size();
do {
LOG.info("Limit {} backward scrollCount {} beforeCursor {}", limit, pageCount, before);
forwardPage = listModels(null, limit, before, null, adminAuthHeaders());
printModels(forwardPage);
before = forwardPage.getPaging().getBefore();
assertEntityPagination(allModels.getData(), forwardPage, limit, indexInAllModels);
pageCount++;
indexInAllModels -= forwardPage.getData().size();
} while (before != null);
}
}
private void printModels(MlModelList list) {
list.getData().forEach(Model -> LOG.info("DB {}", Model.getFullyQualifiedName()));
LOG.info("before {} after {} ", list.getPaging().getBefore(), list.getPaging().getAfter());
}
@Test
public void put_ModelUpdateWithNoChange_200(TestInfo test) throws HttpResponseException {
// Create a Model with POST // Create a Model with POST
CreateMlModel request = create(test).withOwner(USER_OWNER1); CreateMlModel request = create(test).withOwner(USER_OWNER1);
MlModel model = createAndCheckModel(request, adminAuthHeaders()); MlModel model = createAndCheckEntity(request, adminAuthHeaders());
ChangeDescription change = getChangeDescription(model.getVersion());
// Update Model two times successfully with PUT requests // Update Model two times successfully with PUT requests
model = updateAndCheckModel(model, request, OK, adminAuthHeaders(), NO_CHANGE); updateAndCheckEntity(request, Status.OK, adminAuthHeaders(), NO_CHANGE, change);
updateAndCheckModel(model, request, OK, adminAuthHeaders(), NO_CHANGE);
} }
@Test @Test
public void put_ModelCreate_200(TestInfo test) throws HttpResponseException { public void put_MlModelUpdateAlgorithm_200(TestInfo test) throws IOException {
// Create a new Model with PUT CreateMlModel request = create(test);
CreateMlModel request = create(test).withOwner(USER_OWNER1); MlModel model = createAndCheckEntity(request, adminAuthHeaders());
updateAndCheckModel(null, request.withName(test.getDisplayName()).withDescription(null), CREATED, ChangeDescription change = getChangeDescription(model.getVersion());
adminAuthHeaders(), NO_CHANGE); change.getFieldsUpdated().add(
new FieldChange().withName("algorithm").withNewValue("SVM").withOldValue("regression")
);
updateAndCheckEntity(request.withAlgorithm("SVM"), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change);
} }
@Test @Test
public void put_ModelCreate_as_owner_200(TestInfo test) throws HttpResponseException { public void put_MlModelAddDashboard_200(TestInfo test) throws IOException {
// Create a new Model with put CreateMlModel request = create(test);
CreateMlModel request = create(test).withOwner(USER_OWNER1); MlModel model = createAndCheckEntity(request, adminAuthHeaders());
// Add model as admin ChangeDescription change = getChangeDescription(model.getVersion());
MlModel model = createAndCheckModel(request, adminAuthHeaders()); change.getFieldsAdded().add(new FieldChange().withName("dashboard").withNewValue(DASHBOARD_REFERENCE));
// Update the table as Owner
updateAndCheckModel(model, request.withDescription("new"), OK, authHeaders(USER1.getEmail()), MINOR_UPDATE); updateAndCheckEntity(
request.withDashboard(DASHBOARD_REFERENCE), Status.OK, adminAuthHeaders(), MINOR_UPDATE, change
);
} }
@Test @Test
public void put_ModelNullDescriptionUpdate_200(TestInfo test) throws HttpResponseException { public void get_nonExistentMlModel_404_notFound() {
CreateMlModel request = create(test).withDescription(null);
MlModel model = createAndCheckModel(request, adminAuthHeaders());
// Update null description with a new description
MlModel db = updateAndCheckModel(model, request.withDisplayName("model1").
withDescription("newDescription"), OK, adminAuthHeaders(), MINOR_UPDATE);
assertEquals("model1", db.getDisplayName()); // Move this check to validate method
}
@Test
public void put_ModelEmptyDescriptionUpdate_200(TestInfo test) throws HttpResponseException {
// Create table with empty description
CreateMlModel request = create(test).withDescription("");
MlModel model = createAndCheckModel(request, adminAuthHeaders());
// Update empty description with a new description
updateAndCheckModel(model, request.withDescription("newDescription"), OK, adminAuthHeaders(), MINOR_UPDATE);
}
@Test
public void put_ModelNonEmptyDescriptionUpdate_200(TestInfo test) throws HttpResponseException {
CreateMlModel request = create(test).withDescription("description");
createAndCheckModel(request, adminAuthHeaders());
// Updating description is ignored when backend already has description
MlModel db = updateModel(request.withDescription("newDescription"), OK, adminAuthHeaders());
assertEquals("description", db.getDescription());
}
@Test
public void put_ModelUpdateOwner_200(TestInfo test) throws HttpResponseException {
CreateMlModel request = create(test).withDescription("");
MlModel model = createAndCheckModel(request, adminAuthHeaders());
// Change ownership from USER_OWNER1 to TEAM_OWNER1
model = updateAndCheckModel(model, request.withOwner(TEAM_OWNER1), OK, adminAuthHeaders(), MINOR_UPDATE);
// Remove ownership
model = updateAndCheckModel(model, request.withOwner(null), OK, adminAuthHeaders(), MINOR_UPDATE);
assertNull(model.getOwner());
}
@Test
public void put_ModelUpdateAlgorithm_200(TestInfo test) throws HttpResponseException {
CreateMlModel request = create(test).withDescription("");
MlModel model = createAndCheckModel(request, adminAuthHeaders());
updateAndCheckModel(model, request.withAlgorithm("SVM"), OK, adminAuthHeaders(), MINOR_UPDATE);
}
@Test
public void put_ModelUpdateDashboard_200(TestInfo test) throws HttpResponseException {
CreateMlModel request = create(test).withDescription("");
MlModel model = createAndCheckModel(request, adminAuthHeaders());
updateAndCheckModel(model, request.withDashboard(DASHBOARD_REFERENCE), OK, adminAuthHeaders(), MINOR_UPDATE);
}
@Test
public void get_nonExistentModel_404_notFound() {
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> HttpResponseException exception = assertThrows(HttpResponseException.class, () ->
getModel(TestUtils.NON_EXISTENT_ENTITY, adminAuthHeaders())); getModel(TestUtils.NON_EXISTENT_ENTITY, adminAuthHeaders()));
// TODO: issue-1415
assertResponse(exception, NOT_FOUND, assertResponse(exception, NOT_FOUND,
entityNotFound(Entity.MLMODEL, TestUtils.NON_EXISTENT_ENTITY)); entityNotFound("mlModel", TestUtils.NON_EXISTENT_ENTITY));
} }
@Test @Test
public void get_ModelWithDifferentFields_200_OK(TestInfo test) throws HttpResponseException { public void get_MlModelWithDifferentFields_200_OK(TestInfo test) throws IOException {
// aqui no tenim HREF al dashboard
CreateMlModel create = create(test).withDescription("description") CreateMlModel create = create(test).withDescription("description")
.withOwner(USER_OWNER1).withDashboard(DASHBOARD_REFERENCE); .withOwner(USER_OWNER1).withDashboard(DASHBOARD_REFERENCE);
MlModel model = createAndCheckModel(create, adminAuthHeaders()); MlModel model = createAndCheckEntity(create, adminAuthHeaders());
validateGetWithDifferentFields(model, false); validateGetWithDifferentFields(model, false);
} }
@Test @Test
public void get_ModelByNameWithDifferentFields_200_OK(TestInfo test) throws HttpResponseException { public void get_MlModelByNameWithDifferentFields_200_OK(TestInfo test) throws IOException {
CreateMlModel create = create(test).withDescription("description") CreateMlModel create = create(test).withDescription("description")
.withOwner(USER_OWNER1).withDashboard(DASHBOARD_REFERENCE); .withOwner(USER_OWNER1).withDashboard(DASHBOARD_REFERENCE);
MlModel model = createAndCheckModel(create, adminAuthHeaders()); MlModel model = createAndCheckEntity(create, adminAuthHeaders());
validateGetWithDifferentFields(model, true); validateGetWithDifferentFields(model, true);
} }
@Test @Test
public void delete_emptyModel_200_ok(TestInfo test) throws HttpResponseException { public void delete_MlModel_200_ok(TestInfo test) throws HttpResponseException {
MlModel model = createModel(create(test), adminAuthHeaders()); MlModel model = createMlModel(create(test), adminAuthHeaders());
deleteModel(model.getId(), adminAuthHeaders()); deleteModel(model.getId(), adminAuthHeaders());
} }
@Test
public void delete_nonEmptyModel_4xx() {
// TODO
}
@Test @Test
public void delete_nonExistentModel_404() { public void delete_nonExistentModel_404() {
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> HttpResponseException exception = assertThrows(HttpResponseException.class, () ->
deleteModel(TestUtils.NON_EXISTENT_ENTITY, adminAuthHeaders())); deleteModel(TestUtils.NON_EXISTENT_ENTITY, adminAuthHeaders()));
assertResponse(exception, NOT_FOUND, entityNotFound(Entity.MLMODEL, TestUtils.NON_EXISTENT_ENTITY)); // TODO: issue-1415
} assertResponse(exception, NOT_FOUND, entityNotFound("mlModel", TestUtils.NON_EXISTENT_ENTITY));
public static MlModel createAndCheckModel(CreateMlModel create,
Map<String, String> authHeaders) throws HttpResponseException {
String updatedBy = TestUtils.getPrincipal(authHeaders);
MlModel model = createModel(create, authHeaders);
validateModel(model, create.getDisplayName(), create.getDescription(), create.getOwner(), updatedBy);
return getAndValidate(model.getId(), create, authHeaders, updatedBy);
}
public static MlModel createAndCheckModel(CreateMlModel create, EntityReference dashboard,
Map<String, String> authHeaders) throws HttpResponseException {
String updatedBy = TestUtils.getPrincipal(authHeaders);
create.withDashboard(dashboard);
MlModel model = createModel(create, authHeaders);
assertEquals(0.1, model.getVersion());
validateModel(model, create.getDescription(), create.getOwner(), create.getTags(), updatedBy);
return getAndValidate(model.getId(), create, authHeaders, updatedBy);
}
public static MlModel updateAndCheckModel(MlModel before, CreateMlModel create, Status status,
Map<String, String> authHeaders, UpdateType updateType)
throws HttpResponseException {
String updatedBy = TestUtils.getPrincipal(authHeaders);
MlModel updatedModel = updateModel(create, status, authHeaders);
validateModel(updatedModel, create.getDescription(), create.getOwner(), updatedBy);
if (before == null) {
assertEquals(0.1, updatedModel.getVersion()); // First version created
} else {
TestUtils.validateUpdate(before.getVersion(), updatedModel.getVersion(), updateType);
}
return getAndValidate(updatedModel.getId(), create, authHeaders, updatedBy);
}
// Make sure in GET operations the returned Model has all the required information passed during creation
public static MlModel getAndValidate(UUID modelId,
CreateMlModel create,
Map<String, String> authHeaders,
String expectedUpdatedBy) throws HttpResponseException {
// GET the newly created Model by ID and validate
MlModel model = getModel(modelId, "owner", authHeaders);
validateModel(model, create.getDescription(), create.getOwner(), expectedUpdatedBy);
// GET the newly created Model by name and validate
String fqn = model.getFullyQualifiedName();
model = getModelByName(fqn, "owner", authHeaders);
return validateModel(model, create.getDescription(), create.getOwner(), expectedUpdatedBy);
}
public static MlModel updateModel(CreateMlModel create,
Status status,
Map<String, String> authHeaders) throws HttpResponseException {
return TestUtils.put(getResource("mlmodels"),
create, MlModel.class, status, authHeaders);
}
public static MlModel createModel(CreateMlModel create,
Map<String, String> authHeaders) throws HttpResponseException {
return TestUtils.post(getResource("mlmodels"), create, MlModel.class, authHeaders);
} }
/** Validate returned fields GET .../models/{id}?fields="..." or GET .../models/name/{fqn}?fields="..." */ /** Validate returned fields GET .../models/{id}?fields="..." or GET .../models/name/{fqn}?fields="..." */
@ -532,54 +329,6 @@ public class MlModelResourceTest extends CatalogApplicationTest {
TestUtils.validateEntityReference(model.getDashboard()); TestUtils.validateEntityReference(model.getDashboard());
} }
private static MlModel validateModel(MlModel model, String expectedDisplayName,
String expectedDescription,
EntityReference expectedOwner,
String expectedUpdatedBy) {
MlModel newModel = validateModel(model, expectedDescription, expectedOwner, expectedUpdatedBy);
assertEquals(expectedDisplayName, newModel.getDisplayName());
return newModel;
}
private static MlModel validateModel(MlModel model, String expectedDescription,
EntityReference expectedOwner, String expectedUpdatedBy) {
assertNotNull(model.getId());
assertNotNull(model.getHref());
assertNotNull(model.getAlgorithm());
assertEquals(expectedDescription, model.getDescription());
assertEquals(expectedUpdatedBy, model.getUpdatedBy());
// Validate owner
if (expectedOwner != null) {
TestUtils.validateEntityReference(model.getOwner());
assertEquals(expectedOwner.getId(), model.getOwner().getId());
assertEquals(expectedOwner.getType(), model.getOwner().getType());
assertNotNull(model.getOwner().getHref());
}
return model;
}
private static MlModel validateModel(MlModel model, String expectedDescription,
EntityReference expectedOwner,
List<TagLabel> expectedTags,
String expectedUpdatedBy) throws HttpResponseException {
assertNotNull(model.getId());
assertNotNull(model.getHref());
assertEquals(expectedDescription, model.getDescription());
assertEquals(expectedUpdatedBy, model.getUpdatedBy());
// Validate owner
if (expectedOwner != null) {
TestUtils.validateEntityReference(model.getOwner());
assertEquals(expectedOwner.getId(), model.getOwner().getId());
assertEquals(expectedOwner.getType(), model.getOwner().getType());
assertNotNull(model.getOwner().getHref());
}
TestUtils.validateTags(expectedTags, model.getTags());
return model;
}
public static void getModel(UUID id, Map<String, String> authHeaders) throws HttpResponseException { public static void getModel(UUID id, Map<String, String> authHeaders) throws HttpResponseException {
getModel(id, null, authHeaders); getModel(id, null, authHeaders);
} }
@ -598,27 +347,13 @@ public class MlModelResourceTest extends CatalogApplicationTest {
return TestUtils.get(target, MlModel.class, authHeaders); return TestUtils.get(target, MlModel.class, authHeaders);
} }
public static MlModelList listModels(String fields, Integer limitParam,
String before, String after, Map<String, String> authHeaders)
throws HttpResponseException {
WebTarget target = getResource("mlmodels");
target = fields != null ? target.queryParam("fields", fields): target;
target = limitParam != null ? target.queryParam("limit", limitParam): target;
target = before != null ? target.queryParam("before", before) : target;
target = after != null ? target.queryParam("after", after) : target;
return TestUtils.get(target, MlModelList.class, authHeaders);
}
private void deleteModel(UUID id, Map<String, String> authHeaders) throws HttpResponseException { private void deleteModel(UUID id, Map<String, String> authHeaders) throws HttpResponseException {
TestUtils.delete(getResource("mlmodels/" + id), authHeaders); TestUtils.delete(getResource("mlmodels/" + id), authHeaders);
// Ensure deleted Model does not exist // Check to make sure database does not exist
HttpResponseException exception = assertThrows(HttpResponseException.class, () -> getModel(id, authHeaders)); HttpResponseException exception = assertThrows(HttpResponseException.class, () -> getModel(id, authHeaders));
assertResponse(exception, NOT_FOUND, entityNotFound(Entity.MLMODEL, id)); // TODO: issue-1415 instead of mlModel, use Entity.MLMODEL
} assertResponse(exception, NOT_FOUND, CatalogExceptionMessage.entityNotFound("mlModel", id));
public static String getModelName(TestInfo test) {
return String.format("mlmodel_%s", test.getDisplayName());
} }
public static String getModelName(TestInfo test, int index) { public static String getModelName(TestInfo test, int index) {
@ -626,8 +361,7 @@ public class MlModelResourceTest extends CatalogApplicationTest {
} }
public static CreateMlModel create(TestInfo test) { public static CreateMlModel create(TestInfo test) {
return new CreateMlModel().withName(getModelName(test)).withAlgorithm(ALGORITHM) return create(test, 0);
.withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
} }
public static CreateMlModel create(TestInfo test, int index) { public static CreateMlModel create(TestInfo test, int index) {
@ -635,4 +369,121 @@ public class MlModelResourceTest extends CatalogApplicationTest {
.withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS); .withMlFeatures(ML_FEATURES).withMlHyperParameters(ML_HYPERPARAMS);
} }
@Override
public Object createRequest(TestInfo test, int index, String description, String displayName, EntityReference owner) {
return create(test, index).withDescription(description).withDisplayName(displayName).withOwner(owner);
}
@Override
public void validateUpdatedEntity(MlModel mlModel, Object request, Map<String, String> authHeaders)
throws HttpResponseException {
validateCreatedEntity(mlModel, request, authHeaders);
}
@Override
public void compareEntities(MlModel expected, MlModel updated, Map<String, String> authHeaders)
throws HttpResponseException {
validateCommonEntityFields(getEntityInterface(updated), expected.getDescription(),
TestUtils.getPrincipal(authHeaders), expected.getOwner());
// Entity specific validations
assertEquals(expected.getAlgorithm(), updated.getAlgorithm());
assertEquals(expected.getDashboard(), updated.getDashboard());
assertListProperty(expected.getMlFeatures(), updated.getMlFeatures(), assertMlFeature);
assertListProperty(expected.getMlHyperParameters(), updated.getMlHyperParameters(), assertMlHyperParam);
// assertListProperty on MlFeatures already validates size, so we can directly iterate on sources
validateMlFeatureSources(expected.getMlFeatures(), updated.getMlFeatures());
TestUtils.validateTags(expected.getTags(), updated.getTags());
TestUtils.validateEntityReference(updated.getFollowers());
}
@Override
public EntityInterface<MlModel> getEntityInterface(MlModel entity) {
return new MlModelRepository.MlModelEntityInterface(entity);
}
BiConsumer<MlFeature, MlFeature> assertMlFeature = (MlFeature expected, MlFeature actual) -> {
assertNotNull(actual.getFullyQualifiedName());
assertEquals(actual.getName(), expected.getName());
assertEquals(actual.getDescription(), expected.getDescription());
assertEquals(actual.getFeatureAlgorithm(), expected.getFeatureAlgorithm());
assertEquals(actual.getDataType(), expected.getDataType());
};
BiConsumer<MlHyperParameter, MlHyperParameter> assertMlHyperParam =
(MlHyperParameter expected, MlHyperParameter actual) -> {
assertEquals(actual.getName(), expected.getName());
assertEquals(actual.getDescription(), expected.getDescription());
assertEquals(actual.getValue(), expected.getValue());
};
BiConsumer<MlFeatureSource, MlFeatureSource> assertMlFeatureSource =
(MlFeatureSource expected, MlFeatureSource actual) -> {
assertNotNull(actual.getFullyQualifiedName());
assertEquals(actual.getName(), expected.getName());
assertEquals(actual.getDescription(), expected.getDescription());
assertEquals(actual.getDataType(), expected.getDataType());
};
private void validateMlFeatureSources(List<MlFeature> expected, List<MlFeature> actual)
throws HttpResponseException {
if (expected == null && actual == null) {
return;
}
for (int i = 0; i < expected.size(); i++) {
assertListProperty(expected.get(i).getFeatureSources(), actual.get(i).getFeatureSources(), assertMlFeatureSource);
}
}
@Override
public void validateCreatedEntity(MlModel createdEntity, Object request, Map<String, String> authHeaders)
throws HttpResponseException {
CreateMlModel createRequest = (CreateMlModel) request;
validateCommonEntityFields(getEntityInterface(createdEntity), createRequest.getDescription(),
TestUtils.getPrincipal(authHeaders), createRequest.getOwner());
// Entity specific validations
assertEquals(createRequest.getAlgorithm(), createdEntity.getAlgorithm());
assertEquals(createRequest.getDashboard(), createdEntity.getDashboard());
assertListProperty(createRequest.getMlFeatures(), createdEntity.getMlFeatures(), assertMlFeature);
assertListProperty(createRequest.getMlHyperParameters(), createdEntity.getMlHyperParameters(), assertMlHyperParam);
// assertListProperty on MlFeatures already validates size, so we can directly iterate on sources
validateMlFeatureSources(createRequest.getMlFeatures(), createdEntity.getMlFeatures());
TestUtils.validateTags(createRequest.getTags(), createdEntity.getTags());
TestUtils.validateEntityReference(createdEntity.getFollowers());
}
@Override
public void assertFieldChange(String fieldName, Object expected, Object actual) throws IOException {
if (expected == actual) {
return;
}
if (fieldName.contains("mlFeatures") && !fieldName.endsWith("tags") && !fieldName.endsWith("description")) {
List<MlFeature> expectedFeatures = (List<MlFeature>) expected;
List<MlFeature> actualFeatures = JsonUtils.readObjects(actual.toString(), MlFeature.class);
assertEquals(expectedFeatures, actualFeatures);
} else if (fieldName.contains("mlHyperParameters") && !fieldName.endsWith("tags")
&& !fieldName.endsWith("description")) {
List<MlHyperParameter> expectedConstraints = (List<MlHyperParameter>) expected;
List<MlHyperParameter> actualConstraints = JsonUtils.readObjects(actual.toString(), MlHyperParameter.class);
assertEquals(expectedConstraints, actualConstraints);
} else if (fieldName.endsWith("algorithm")) {
String expectedAlgorithm = (String) expected;
String actualAlgorithm = actual.toString();
assertEquals(expectedAlgorithm, actualAlgorithm);
} else if (fieldName.endsWith("dashboard")) {
EntityReference expectedDashboard = (EntityReference) expected;
EntityReference actualDashboard = JsonUtils.readValue(actual.toString(), EntityReference.class);
assertEquals(expectedDashboard, actualDashboard);
} else {
assertCommonFieldChange(fieldName, expected, actual);
}
}
} }