[Issue-1099] - Update MLModel Ingestion (#1179)

* Refactor MLModel

* Prepare MLModel properties example

* Update API tags to MLModels

* Update API descriptions for MLModel

* Rename tags to mlModels
This commit is contained in:
Pere Miquel Brull 2021-11-15 16:58:15 +01:00 committed by GitHub
parent 69f9eeb718
commit ffd7818978
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 31 deletions

View File

@ -120,8 +120,8 @@ public class MLModelResource {
@GET @GET
@Valid @Valid
@Operation(summary = "List Models", tags = "models", @Operation(summary = "List ML Models", tags = "mlModels",
description = "Get a list of models. Use `fields` parameter to get only necessary fields. " + description = "Get a list of ML Models. Use `fields` parameter to get only necessary fields. " +
" Use cursor-based pagination to limit the number " + " Use cursor-based pagination to limit the number " +
"entries in the list using `limit` and `before` or `after` query params.", "entries in the list using `limit` and `before` or `after` query params.",
responses = { responses = {
@ -162,8 +162,8 @@ public class MLModelResource {
@GET @GET
@Path("/{id}") @Path("/{id}")
@Operation(summary = "Get a model", tags = "models", @Operation(summary = "Get an ML Model", tags = "mlModels",
description = "Get a model by `id`.", description = "Get an ML Model by `id`.",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "The model", @ApiResponse(responseCode = "200", description = "The model",
content = @Content(mediaType = "application/json", content = @Content(mediaType = "application/json",
@ -182,8 +182,8 @@ public class MLModelResource {
@GET @GET
@Path("/name/{fqn}") @Path("/name/{fqn}")
@Operation(summary = "Get a model by name", tags = "models", @Operation(summary = "Get an ML Model by name", tags = "mlModels",
description = "Get a model by fully qualified name.", description = "Get an ML Model by fully qualified name.",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "The model", @ApiResponse(responseCode = "200", description = "The model",
content = @Content(mediaType = "application/json", content = @Content(mediaType = "application/json",
@ -202,8 +202,8 @@ public class MLModelResource {
@POST @POST
@Operation(summary = "Create a model", tags = "models", @Operation(summary = "Create an ML Model", tags = "mlModels",
description = "Create a new model.", description = "Create a new ML Model.",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "The model", @ApiResponse(responseCode = "200", description = "The model",
content = @Content(mediaType = "application/json", content = @Content(mediaType = "application/json",
@ -220,8 +220,8 @@ public class MLModelResource {
@PATCH @PATCH
@Path("/{id}") @Path("/{id}")
@Operation(summary = "Update a model", tags = "models", @Operation(summary = "Update an ML Model", tags = "mlModels",
description = "Update an existing model using JsonPatch.", description = "Update an existing ML Model using JsonPatch.",
externalDocs = @ExternalDocumentation(description = "JsonPatch RFC", externalDocs = @ExternalDocumentation(description = "JsonPatch RFC",
url = "https://tools.ietf.org/html/rfc6902")) url = "https://tools.ietf.org/html/rfc6902"))
@Consumes(MediaType.APPLICATION_JSON_PATCH_JSON) @Consumes(MediaType.APPLICATION_JSON_PATCH_JSON)
@ -246,8 +246,8 @@ public class MLModelResource {
} }
@PUT @PUT
@Operation(summary = "Create or update a model", tags = "models", @Operation(summary = "Create or update an ML Model", tags = "mlModels",
description = "Create a new model, if it does not exist or update an existing model.", description = "Create a new ML Model, if it does not exist or update an existing model.",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "The model", @ApiResponse(responseCode = "200", description = "The model",
content = @Content(mediaType = "application/json", content = @Content(mediaType = "application/json",
@ -265,7 +265,7 @@ public class MLModelResource {
@PUT @PUT
@Path("/{id}/followers") @Path("/{id}/followers")
@Operation(summary = "Add a follower", tags = "models", @Operation(summary = "Add a follower", tags = "mlModels",
description = "Add a user identified by `userId` as follower of this model", description = "Add a user identified by `userId` as follower of this model",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "OK"), @ApiResponse(responseCode = "200", description = "OK"),
@ -284,7 +284,7 @@ public class MLModelResource {
@DELETE @DELETE
@Path("/{id}/followers/{userId}") @Path("/{id}/followers/{userId}")
@Operation(summary = "Remove a follower", tags = "models", @Operation(summary = "Remove a follower", tags = "mlModels",
description = "Remove the user identified `userId` as a follower of the model.") description = "Remove the user identified `userId` as a follower of the model.")
public Response deleteFollower(@Context UriInfo uriInfo, public Response deleteFollower(@Context UriInfo uriInfo,
@Context SecurityContext securityContext, @Context SecurityContext securityContext,
@ -300,8 +300,8 @@ public class MLModelResource {
@DELETE @DELETE
@Path("/{id}") @Path("/{id}")
@Operation(summary = "Delete a Model", tags = "models", @Operation(summary = "Delete an ML Model", tags = "mlModels",
description = "Delete a model by `id`.", description = "Delete an ML Model by `id`.",
responses = { responses = {
@ApiResponse(responseCode = "200", description = "OK"), @ApiResponse(responseCode = "200", description = "OK"),
@ApiResponse(responseCode = "404", description = "model for instance {id} is not found") @ApiResponse(responseCode = "404", description = "model for instance {id} is not found")

View File

@ -4,9 +4,16 @@ OpenMetadata high-level API Model test
import uuid import uuid
from unittest import TestCase from unittest import TestCase
from metadata.generated.schema.api.data.createModel import CreateModelEntityRequest from metadata.generated.schema.api.data.createMLModel import CreateMLModelEntityRequest
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
from metadata.generated.schema.entity.data.model import Model from metadata.generated.schema.entity.data.mlmodel import (
FeatureSource,
FeatureSourceDataType,
FeatureType,
MlFeature,
MlHyperParameter,
MLModel,
)
from metadata.generated.schema.type.entityReference import EntityReference from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
@ -28,13 +35,13 @@ class OMetaModelTest(TestCase):
) )
owner = EntityReference(id=user.id, type="user") owner = EntityReference(id=user.id, type="user")
entity = Model( entity = MLModel(
id=uuid.uuid4(), id=uuid.uuid4(),
name="test-model", name="test-model",
algorithm="algo", algorithm="algo",
fullyQualifiedName="test-model", fullyQualifiedName="test-model",
) )
create = CreateModelEntityRequest(name="test-model", algorithm="algo") create = CreateMLModelEntityRequest(name="test-model", algorithm="algo")
def test_create(self): def test_create(self):
""" """
@ -56,7 +63,7 @@ class OMetaModelTest(TestCase):
updated = self.create.dict(exclude_unset=True) updated = self.create.dict(exclude_unset=True)
updated["owner"] = self.owner updated["owner"] = self.owner
updated_entity = CreateModelEntityRequest(**updated) updated_entity = CreateMLModelEntityRequest(**updated)
res = self.metadata.create_or_update(data=updated_entity) res = self.metadata.create_or_update(data=updated_entity)
@ -67,13 +74,13 @@ class OMetaModelTest(TestCase):
# Getting without owner field does not return it by default # Getting without owner field does not return it by default
res_none = self.metadata.get_by_name( res_none = self.metadata.get_by_name(
entity=Model, fqdn=self.entity.fullyQualifiedName entity=MLModel, fqdn=self.entity.fullyQualifiedName
) )
self.assertIsNone(res_none.owner) self.assertIsNone(res_none.owner)
# We can request specific fields to be added # We can request specific fields to be added
res_owner = self.metadata.get_by_name( res_owner = self.metadata.get_by_name(
entity=Model, entity=MLModel,
fqdn=self.entity.fullyQualifiedName, fqdn=self.entity.fullyQualifiedName,
fields=["owner", "followers"], fields=["owner", "followers"],
) )
@ -87,7 +94,7 @@ class OMetaModelTest(TestCase):
self.metadata.create_or_update(data=self.create) self.metadata.create_or_update(data=self.create)
res = self.metadata.get_by_name( res = self.metadata.get_by_name(
entity=Model, fqdn=self.entity.fullyQualifiedName entity=MLModel, fqdn=self.entity.fullyQualifiedName
) )
self.assertEqual(res.name, self.entity.name) self.assertEqual(res.name, self.entity.name)
@ -100,10 +107,12 @@ class OMetaModelTest(TestCase):
# First pick up by name # First pick up by name
res_name = self.metadata.get_by_name( res_name = self.metadata.get_by_name(
entity=Model, fqdn=self.entity.fullyQualifiedName entity=MLModel, fqdn=self.entity.fullyQualifiedName
) )
# Then fetch by ID # Then fetch by ID
res = self.metadata.get_by_id(entity=Model, entity_id=str(res_name.id.__root__)) res = self.metadata.get_by_id(
entity=MLModel, entity_id=str(res_name.id.__root__)
)
self.assertEqual(res_name.id, res.id) self.assertEqual(res_name.id, res.id)
@ -114,7 +123,7 @@ class OMetaModelTest(TestCase):
self.metadata.create_or_update(data=self.create) self.metadata.create_or_update(data=self.create)
res = self.metadata.list_entities(entity=Model) res = self.metadata.list_entities(entity=MLModel)
# Fetch our test model. We have already inserted it, so we should find it # Fetch our test model. We have already inserted it, so we should find it
data = next( data = next(
@ -131,18 +140,18 @@ class OMetaModelTest(TestCase):
# Find by name # Find by name
res_name = self.metadata.get_by_name( res_name = self.metadata.get_by_name(
entity=Model, fqdn=self.entity.fullyQualifiedName entity=MLModel, fqdn=self.entity.fullyQualifiedName
) )
# Then fetch by ID # Then fetch by ID
res_id = self.metadata.get_by_id( res_id = self.metadata.get_by_id(
entity=Model, entity_id=str(res_name.id.__root__) entity=MLModel, entity_id=str(res_name.id.__root__)
) )
# Delete # Delete
self.metadata.delete(entity=Model, entity_id=str(res_id.id.__root__)) self.metadata.delete(entity=MLModel, entity_id=str(res_id.id.__root__))
# Then we should not find it # Then we should not find it
res = self.metadata.list_entities(entity=Model) res = self.metadata.list_entities(entity=MLModel)
assert not next( assert not next(
iter( iter(
@ -152,3 +161,52 @@ class OMetaModelTest(TestCase):
), ),
None, None,
) )
def test_model_properties(self):
"""
Check that we can create models with MLFeatures and MLHyperParams
"""
model = CreateMLModelEntityRequest(
name="test-model-properties",
algorithm="algo",
mlFeatures=[
MlFeature(
name="age",
dataType=FeatureType.numerical,
featureSources=[
FeatureSource(
name="age",
dataType=FeatureSourceDataType.integer,
fullyQualifiedName="my_service.my_db.my_table.age",
)
],
),
MlFeature(
name="persona",
dataType=FeatureType.categorical,
featureSources=[
FeatureSource(
name="age",
dataType=FeatureSourceDataType.integer,
fullyQualifiedName="my_service.my_db.my_table.age",
),
FeatureSource(
name="education",
dataType=FeatureSourceDataType.string,
fullyQualifiedName="my_api.education",
),
],
featureAlgorithm="PCA",
),
],
mlHyperParameters=[
MlHyperParameter(name="regularisation", value="0.5"),
MlHyperParameter(name="random", value="hello"),
],
)
res = self.metadata.create_or_update(data=model)
self.assertIsNotNone(res.mlFeatures)
self.assertIsNotNone(res.mlHyperParameters)