diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MLModelResource.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MLModelResource.java index 86c663fabdd..a534afa27ab 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MLModelResource.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/resources/mlmodels/MLModelResource.java @@ -120,8 +120,8 @@ public class MLModelResource { @GET @Valid - @Operation(summary = "List Models", tags = "models", - description = "Get a list of models. Use `fields` parameter to get only necessary fields. " + + @Operation(summary = "List ML Models", tags = "mlModels", + description = "Get a list of ML Models. Use `fields` parameter to get only necessary fields. " + " Use cursor-based pagination to limit the number " + "entries in the list using `limit` and `before` or `after` query params.", responses = { @@ -162,8 +162,8 @@ public class MLModelResource { @GET @Path("/{id}") - @Operation(summary = "Get a model", tags = "models", - description = "Get a model by `id`.", + @Operation(summary = "Get an ML Model", tags = "mlModels", + description = "Get an ML Model by `id`.", responses = { @ApiResponse(responseCode = "200", description = "The model", content = @Content(mediaType = "application/json", @@ -182,8 +182,8 @@ public class MLModelResource { @GET @Path("/name/{fqn}") - @Operation(summary = "Get a model by name", tags = "models", - description = "Get a model by fully qualified name.", + @Operation(summary = "Get an ML Model by name", tags = "mlModels", + description = "Get an ML Model by fully qualified name.", responses = { @ApiResponse(responseCode = "200", description = "The model", content = @Content(mediaType = "application/json", @@ -202,8 +202,8 @@ public class MLModelResource { @POST - @Operation(summary = "Create a model", tags = "models", - description = "Create a new model.", + @Operation(summary = "Create an ML Model", tags = "mlModels", + description = "Create a new ML Model.", responses = { @ApiResponse(responseCode = "200", description = "The model", content = @Content(mediaType = "application/json", @@ -220,8 +220,8 @@ public class MLModelResource { @PATCH @Path("/{id}") - @Operation(summary = "Update a model", tags = "models", - description = "Update an existing model using JsonPatch.", + @Operation(summary = "Update an ML Model", tags = "mlModels", + description = "Update an existing ML Model using JsonPatch.", externalDocs = @ExternalDocumentation(description = "JsonPatch RFC", url = "https://tools.ietf.org/html/rfc6902")) @Consumes(MediaType.APPLICATION_JSON_PATCH_JSON) @@ -246,8 +246,8 @@ public class MLModelResource { } @PUT - @Operation(summary = "Create or update a model", tags = "models", - description = "Create a new model, if it does not exist or update an existing model.", + @Operation(summary = "Create or update an ML Model", tags = "mlModels", + description = "Create a new ML Model, if it does not exist or update an existing model.", responses = { @ApiResponse(responseCode = "200", description = "The model", content = @Content(mediaType = "application/json", @@ -265,7 +265,7 @@ public class MLModelResource { @PUT @Path("/{id}/followers") - @Operation(summary = "Add a follower", tags = "models", + @Operation(summary = "Add a follower", tags = "mlModels", description = "Add a user identified by `userId` as follower of this model", responses = { @ApiResponse(responseCode = "200", description = "OK"), @@ -284,7 +284,7 @@ public class MLModelResource { @DELETE @Path("/{id}/followers/{userId}") - @Operation(summary = "Remove a follower", tags = "models", + @Operation(summary = "Remove a follower", tags = "mlModels", description = "Remove the user identified `userId` as a follower of the model.") public Response deleteFollower(@Context UriInfo uriInfo, @Context SecurityContext securityContext, @@ -300,8 +300,8 @@ public class MLModelResource { @DELETE @Path("/{id}") - @Operation(summary = "Delete a Model", tags = "models", - description = "Delete a model by `id`.", + @Operation(summary = "Delete an ML Model", tags = "mlModels", + description = "Delete an ML Model by `id`.", responses = { @ApiResponse(responseCode = "200", description = "OK"), @ApiResponse(responseCode = "404", description = "model for instance {id} is not found") diff --git a/ingestion/tests/integration/ometa/test_ometa_model_api.py b/ingestion/tests/integration/ometa/test_ometa_model_api.py index acfa30d3a31..9cd12136919 100644 --- a/ingestion/tests/integration/ometa/test_ometa_model_api.py +++ b/ingestion/tests/integration/ometa/test_ometa_model_api.py @@ -4,9 +4,16 @@ OpenMetadata high-level API Model test import uuid from unittest import TestCase -from metadata.generated.schema.api.data.createModel import CreateModelEntityRequest +from metadata.generated.schema.api.data.createMLModel import CreateMLModelEntityRequest from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest -from metadata.generated.schema.entity.data.model import Model +from metadata.generated.schema.entity.data.mlmodel import ( + FeatureSource, + FeatureSourceDataType, + FeatureType, + MlFeature, + MlHyperParameter, + MLModel, +) from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig @@ -28,13 +35,13 @@ class OMetaModelTest(TestCase): ) owner = EntityReference(id=user.id, type="user") - entity = Model( + entity = MLModel( id=uuid.uuid4(), name="test-model", algorithm="algo", fullyQualifiedName="test-model", ) - create = CreateModelEntityRequest(name="test-model", algorithm="algo") + create = CreateMLModelEntityRequest(name="test-model", algorithm="algo") def test_create(self): """ @@ -56,7 +63,7 @@ class OMetaModelTest(TestCase): updated = self.create.dict(exclude_unset=True) updated["owner"] = self.owner - updated_entity = CreateModelEntityRequest(**updated) + updated_entity = CreateMLModelEntityRequest(**updated) res = self.metadata.create_or_update(data=updated_entity) @@ -67,13 +74,13 @@ class OMetaModelTest(TestCase): # Getting without owner field does not return it by default res_none = self.metadata.get_by_name( - entity=Model, fqdn=self.entity.fullyQualifiedName + entity=MLModel, fqdn=self.entity.fullyQualifiedName ) self.assertIsNone(res_none.owner) # We can request specific fields to be added res_owner = self.metadata.get_by_name( - entity=Model, + entity=MLModel, fqdn=self.entity.fullyQualifiedName, fields=["owner", "followers"], ) @@ -87,7 +94,7 @@ class OMetaModelTest(TestCase): self.metadata.create_or_update(data=self.create) res = self.metadata.get_by_name( - entity=Model, fqdn=self.entity.fullyQualifiedName + entity=MLModel, fqdn=self.entity.fullyQualifiedName ) self.assertEqual(res.name, self.entity.name) @@ -100,10 +107,12 @@ class OMetaModelTest(TestCase): # First pick up by name res_name = self.metadata.get_by_name( - entity=Model, fqdn=self.entity.fullyQualifiedName + entity=MLModel, fqdn=self.entity.fullyQualifiedName ) # Then fetch by ID - res = self.metadata.get_by_id(entity=Model, entity_id=str(res_name.id.__root__)) + res = self.metadata.get_by_id( + entity=MLModel, entity_id=str(res_name.id.__root__) + ) self.assertEqual(res_name.id, res.id) @@ -114,7 +123,7 @@ class OMetaModelTest(TestCase): self.metadata.create_or_update(data=self.create) - res = self.metadata.list_entities(entity=Model) + res = self.metadata.list_entities(entity=MLModel) # Fetch our test model. We have already inserted it, so we should find it data = next( @@ -131,18 +140,18 @@ class OMetaModelTest(TestCase): # Find by name res_name = self.metadata.get_by_name( - entity=Model, fqdn=self.entity.fullyQualifiedName + entity=MLModel, fqdn=self.entity.fullyQualifiedName ) # Then fetch by ID res_id = self.metadata.get_by_id( - entity=Model, entity_id=str(res_name.id.__root__) + entity=MLModel, entity_id=str(res_name.id.__root__) ) # Delete - self.metadata.delete(entity=Model, entity_id=str(res_id.id.__root__)) + self.metadata.delete(entity=MLModel, entity_id=str(res_id.id.__root__)) # Then we should not find it - res = self.metadata.list_entities(entity=Model) + res = self.metadata.list_entities(entity=MLModel) assert not next( iter( @@ -152,3 +161,52 @@ class OMetaModelTest(TestCase): ), None, ) + + def test_model_properties(self): + """ + Check that we can create models with MLFeatures and MLHyperParams + """ + + model = CreateMLModelEntityRequest( + name="test-model-properties", + algorithm="algo", + mlFeatures=[ + MlFeature( + name="age", + dataType=FeatureType.numerical, + featureSources=[ + FeatureSource( + name="age", + dataType=FeatureSourceDataType.integer, + fullyQualifiedName="my_service.my_db.my_table.age", + ) + ], + ), + MlFeature( + name="persona", + dataType=FeatureType.categorical, + featureSources=[ + FeatureSource( + name="age", + dataType=FeatureSourceDataType.integer, + fullyQualifiedName="my_service.my_db.my_table.age", + ), + FeatureSource( + name="education", + dataType=FeatureSourceDataType.string, + fullyQualifiedName="my_api.education", + ), + ], + featureAlgorithm="PCA", + ), + ], + mlHyperParameters=[ + MlHyperParameter(name="regularisation", value="0.5"), + MlHyperParameter(name="random", value="hello"), + ], + ) + + res = self.metadata.create_or_update(data=model) + + self.assertIsNotNone(res.mlFeatures) + self.assertIsNotNone(res.mlHyperParameters)