mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-15 12:37:18 +00:00
[Issue-1099] - Update MLModel Ingestion (#1179)
* Refactor MLModel * Prepare MLModel properties example * Update API tags to MLModels * Update API descriptions for MLModel * Rename tags to mlModels
This commit is contained in:
parent
69f9eeb718
commit
ffd7818978
@ -120,8 +120,8 @@ public class MLModelResource {
|
||||
|
||||
@GET
|
||||
@Valid
|
||||
@Operation(summary = "List Models", tags = "models",
|
||||
description = "Get a list of models. Use `fields` parameter to get only necessary fields. " +
|
||||
@Operation(summary = "List ML Models", tags = "mlModels",
|
||||
description = "Get a list of ML Models. Use `fields` parameter to get only necessary fields. " +
|
||||
" Use cursor-based pagination to limit the number " +
|
||||
"entries in the list using `limit` and `before` or `after` query params.",
|
||||
responses = {
|
||||
@ -162,8 +162,8 @@ public class MLModelResource {
|
||||
|
||||
@GET
|
||||
@Path("/{id}")
|
||||
@Operation(summary = "Get a model", tags = "models",
|
||||
description = "Get a model by `id`.",
|
||||
@Operation(summary = "Get an ML Model", tags = "mlModels",
|
||||
description = "Get an ML Model by `id`.",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "The model",
|
||||
content = @Content(mediaType = "application/json",
|
||||
@ -182,8 +182,8 @@ public class MLModelResource {
|
||||
|
||||
@GET
|
||||
@Path("/name/{fqn}")
|
||||
@Operation(summary = "Get a model by name", tags = "models",
|
||||
description = "Get a model by fully qualified name.",
|
||||
@Operation(summary = "Get an ML Model by name", tags = "mlModels",
|
||||
description = "Get an ML Model by fully qualified name.",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "The model",
|
||||
content = @Content(mediaType = "application/json",
|
||||
@ -202,8 +202,8 @@ public class MLModelResource {
|
||||
|
||||
|
||||
@POST
|
||||
@Operation(summary = "Create a model", tags = "models",
|
||||
description = "Create a new model.",
|
||||
@Operation(summary = "Create an ML Model", tags = "mlModels",
|
||||
description = "Create a new ML Model.",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "The model",
|
||||
content = @Content(mediaType = "application/json",
|
||||
@ -220,8 +220,8 @@ public class MLModelResource {
|
||||
|
||||
@PATCH
|
||||
@Path("/{id}")
|
||||
@Operation(summary = "Update a model", tags = "models",
|
||||
description = "Update an existing model using JsonPatch.",
|
||||
@Operation(summary = "Update an ML Model", tags = "mlModels",
|
||||
description = "Update an existing ML Model using JsonPatch.",
|
||||
externalDocs = @ExternalDocumentation(description = "JsonPatch RFC",
|
||||
url = "https://tools.ietf.org/html/rfc6902"))
|
||||
@Consumes(MediaType.APPLICATION_JSON_PATCH_JSON)
|
||||
@ -246,8 +246,8 @@ public class MLModelResource {
|
||||
}
|
||||
|
||||
@PUT
|
||||
@Operation(summary = "Create or update a model", tags = "models",
|
||||
description = "Create a new model, if it does not exist or update an existing model.",
|
||||
@Operation(summary = "Create or update an ML Model", tags = "mlModels",
|
||||
description = "Create a new ML Model, if it does not exist or update an existing model.",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "The model",
|
||||
content = @Content(mediaType = "application/json",
|
||||
@ -265,7 +265,7 @@ public class MLModelResource {
|
||||
|
||||
@PUT
|
||||
@Path("/{id}/followers")
|
||||
@Operation(summary = "Add a follower", tags = "models",
|
||||
@Operation(summary = "Add a follower", tags = "mlModels",
|
||||
description = "Add a user identified by `userId` as follower of this model",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "OK"),
|
||||
@ -284,7 +284,7 @@ public class MLModelResource {
|
||||
|
||||
@DELETE
|
||||
@Path("/{id}/followers/{userId}")
|
||||
@Operation(summary = "Remove a follower", tags = "models",
|
||||
@Operation(summary = "Remove a follower", tags = "mlModels",
|
||||
description = "Remove the user identified `userId` as a follower of the model.")
|
||||
public Response deleteFollower(@Context UriInfo uriInfo,
|
||||
@Context SecurityContext securityContext,
|
||||
@ -300,8 +300,8 @@ public class MLModelResource {
|
||||
|
||||
@DELETE
|
||||
@Path("/{id}")
|
||||
@Operation(summary = "Delete a Model", tags = "models",
|
||||
description = "Delete a model by `id`.",
|
||||
@Operation(summary = "Delete an ML Model", tags = "mlModels",
|
||||
description = "Delete an ML Model by `id`.",
|
||||
responses = {
|
||||
@ApiResponse(responseCode = "200", description = "OK"),
|
||||
@ApiResponse(responseCode = "404", description = "model for instance {id} is not found")
|
||||
|
@ -4,9 +4,16 @@ OpenMetadata high-level API Model test
|
||||
import uuid
|
||||
from unittest import TestCase
|
||||
|
||||
from metadata.generated.schema.api.data.createModel import CreateModelEntityRequest
|
||||
from metadata.generated.schema.api.data.createMLModel import CreateMLModelEntityRequest
|
||||
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
|
||||
from metadata.generated.schema.entity.data.model import Model
|
||||
from metadata.generated.schema.entity.data.mlmodel import (
|
||||
FeatureSource,
|
||||
FeatureSourceDataType,
|
||||
FeatureType,
|
||||
MlFeature,
|
||||
MlHyperParameter,
|
||||
MLModel,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
|
||||
@ -28,13 +35,13 @@ class OMetaModelTest(TestCase):
|
||||
)
|
||||
owner = EntityReference(id=user.id, type="user")
|
||||
|
||||
entity = Model(
|
||||
entity = MLModel(
|
||||
id=uuid.uuid4(),
|
||||
name="test-model",
|
||||
algorithm="algo",
|
||||
fullyQualifiedName="test-model",
|
||||
)
|
||||
create = CreateModelEntityRequest(name="test-model", algorithm="algo")
|
||||
create = CreateMLModelEntityRequest(name="test-model", algorithm="algo")
|
||||
|
||||
def test_create(self):
|
||||
"""
|
||||
@ -56,7 +63,7 @@ class OMetaModelTest(TestCase):
|
||||
|
||||
updated = self.create.dict(exclude_unset=True)
|
||||
updated["owner"] = self.owner
|
||||
updated_entity = CreateModelEntityRequest(**updated)
|
||||
updated_entity = CreateMLModelEntityRequest(**updated)
|
||||
|
||||
res = self.metadata.create_or_update(data=updated_entity)
|
||||
|
||||
@ -67,13 +74,13 @@ class OMetaModelTest(TestCase):
|
||||
|
||||
# Getting without owner field does not return it by default
|
||||
res_none = self.metadata.get_by_name(
|
||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
||||
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||
)
|
||||
self.assertIsNone(res_none.owner)
|
||||
|
||||
# We can request specific fields to be added
|
||||
res_owner = self.metadata.get_by_name(
|
||||
entity=Model,
|
||||
entity=MLModel,
|
||||
fqdn=self.entity.fullyQualifiedName,
|
||||
fields=["owner", "followers"],
|
||||
)
|
||||
@ -87,7 +94,7 @@ class OMetaModelTest(TestCase):
|
||||
self.metadata.create_or_update(data=self.create)
|
||||
|
||||
res = self.metadata.get_by_name(
|
||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
||||
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||
)
|
||||
self.assertEqual(res.name, self.entity.name)
|
||||
|
||||
@ -100,10 +107,12 @@ class OMetaModelTest(TestCase):
|
||||
|
||||
# First pick up by name
|
||||
res_name = self.metadata.get_by_name(
|
||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
||||
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||
)
|
||||
# Then fetch by ID
|
||||
res = self.metadata.get_by_id(entity=Model, entity_id=str(res_name.id.__root__))
|
||||
res = self.metadata.get_by_id(
|
||||
entity=MLModel, entity_id=str(res_name.id.__root__)
|
||||
)
|
||||
|
||||
self.assertEqual(res_name.id, res.id)
|
||||
|
||||
@ -114,7 +123,7 @@ class OMetaModelTest(TestCase):
|
||||
|
||||
self.metadata.create_or_update(data=self.create)
|
||||
|
||||
res = self.metadata.list_entities(entity=Model)
|
||||
res = self.metadata.list_entities(entity=MLModel)
|
||||
|
||||
# Fetch our test model. We have already inserted it, so we should find it
|
||||
data = next(
|
||||
@ -131,18 +140,18 @@ class OMetaModelTest(TestCase):
|
||||
|
||||
# Find by name
|
||||
res_name = self.metadata.get_by_name(
|
||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
||||
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||
)
|
||||
# Then fetch by ID
|
||||
res_id = self.metadata.get_by_id(
|
||||
entity=Model, entity_id=str(res_name.id.__root__)
|
||||
entity=MLModel, entity_id=str(res_name.id.__root__)
|
||||
)
|
||||
|
||||
# Delete
|
||||
self.metadata.delete(entity=Model, entity_id=str(res_id.id.__root__))
|
||||
self.metadata.delete(entity=MLModel, entity_id=str(res_id.id.__root__))
|
||||
|
||||
# Then we should not find it
|
||||
res = self.metadata.list_entities(entity=Model)
|
||||
res = self.metadata.list_entities(entity=MLModel)
|
||||
|
||||
assert not next(
|
||||
iter(
|
||||
@ -152,3 +161,52 @@ class OMetaModelTest(TestCase):
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def test_model_properties(self):
    """
    Build an ML Model request that carries MlFeature and
    MlHyperParameter entries, send it through create_or_update
    and check that both collections are present on the result
    """

    # Keyword arguments for the "age" source column. Both features list
    # this source, and each one gets its own FeatureSource instance.
    age_source_kwargs = dict(
        name="age",
        dataType=FeatureSourceDataType.integer,
        fullyQualifiedName="my_service.my_db.my_table.age",
    )

    # Numerical feature backed by a single source
    age_feature = MlFeature(
        name="age",
        dataType=FeatureType.numerical,
        featureSources=[FeatureSource(**age_source_kwargs)],
    )

    # Categorical feature derived from two sources through an algorithm
    persona_sources = [
        FeatureSource(**age_source_kwargs),
        FeatureSource(
            name="education",
            dataType=FeatureSourceDataType.string,
            fullyQualifiedName="my_api.education",
        ),
    ]
    persona_feature = MlFeature(
        name="persona",
        dataType=FeatureType.categorical,
        featureSources=persona_sources,
        featureAlgorithm="PCA",
    )

    hyper_parameters = [
        MlHyperParameter(name="regularisation", value="0.5"),
        MlHyperParameter(name="random", value="hello"),
    ]

    model = CreateMLModelEntityRequest(
        name="test-model-properties",
        algorithm="algo",
        mlFeatures=[age_feature, persona_feature],
        mlHyperParameters=hyper_parameters,
    )

    res = self.metadata.create_or_update(data=model)

    self.assertIsNotNone(res.mlFeatures)
    self.assertIsNotNone(res.mlHyperParameters)
|
||||
|
Loading…
x
Reference in New Issue
Block a user