mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-15 12:37:18 +00:00
[Issue-1099] - Update MLModel Ingestion (#1179)
* Refactor MLModel * Prepare MLModel properties example * Update API tags to MLModels * Update API descriptions for MLModel * Rename tags to mlModels
This commit is contained in:
parent
69f9eeb718
commit
ffd7818978
@ -120,8 +120,8 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@GET
|
@GET
|
||||||
@Valid
|
@Valid
|
||||||
@Operation(summary = "List Models", tags = "models",
|
@Operation(summary = "List ML Models", tags = "mlModels",
|
||||||
description = "Get a list of models. Use `fields` parameter to get only necessary fields. " +
|
description = "Get a list of ML Models. Use `fields` parameter to get only necessary fields. " +
|
||||||
" Use cursor-based pagination to limit the number " +
|
" Use cursor-based pagination to limit the number " +
|
||||||
"entries in the list using `limit` and `before` or `after` query params.",
|
"entries in the list using `limit` and `before` or `after` query params.",
|
||||||
responses = {
|
responses = {
|
||||||
@ -162,8 +162,8 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@GET
|
@GET
|
||||||
@Path("/{id}")
|
@Path("/{id}")
|
||||||
@Operation(summary = "Get a model", tags = "models",
|
@Operation(summary = "Get an ML Model", tags = "mlModels",
|
||||||
description = "Get a model by `id`.",
|
description = "Get an ML Model by `id`.",
|
||||||
responses = {
|
responses = {
|
||||||
@ApiResponse(responseCode = "200", description = "The model",
|
@ApiResponse(responseCode = "200", description = "The model",
|
||||||
content = @Content(mediaType = "application/json",
|
content = @Content(mediaType = "application/json",
|
||||||
@ -182,8 +182,8 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@GET
|
@GET
|
||||||
@Path("/name/{fqn}")
|
@Path("/name/{fqn}")
|
||||||
@Operation(summary = "Get a model by name", tags = "models",
|
@Operation(summary = "Get an ML Model by name", tags = "mlModels",
|
||||||
description = "Get a model by fully qualified name.",
|
description = "Get an ML Model by fully qualified name.",
|
||||||
responses = {
|
responses = {
|
||||||
@ApiResponse(responseCode = "200", description = "The model",
|
@ApiResponse(responseCode = "200", description = "The model",
|
||||||
content = @Content(mediaType = "application/json",
|
content = @Content(mediaType = "application/json",
|
||||||
@ -202,8 +202,8 @@ public class MLModelResource {
|
|||||||
|
|
||||||
|
|
||||||
@POST
|
@POST
|
||||||
@Operation(summary = "Create a model", tags = "models",
|
@Operation(summary = "Create an ML Model", tags = "mlModels",
|
||||||
description = "Create a new model.",
|
description = "Create a new ML Model.",
|
||||||
responses = {
|
responses = {
|
||||||
@ApiResponse(responseCode = "200", description = "The model",
|
@ApiResponse(responseCode = "200", description = "The model",
|
||||||
content = @Content(mediaType = "application/json",
|
content = @Content(mediaType = "application/json",
|
||||||
@ -220,8 +220,8 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@PATCH
|
@PATCH
|
||||||
@Path("/{id}")
|
@Path("/{id}")
|
||||||
@Operation(summary = "Update a model", tags = "models",
|
@Operation(summary = "Update an ML Model", tags = "mlModels",
|
||||||
description = "Update an existing model using JsonPatch.",
|
description = "Update an existing ML Model using JsonPatch.",
|
||||||
externalDocs = @ExternalDocumentation(description = "JsonPatch RFC",
|
externalDocs = @ExternalDocumentation(description = "JsonPatch RFC",
|
||||||
url = "https://tools.ietf.org/html/rfc6902"))
|
url = "https://tools.ietf.org/html/rfc6902"))
|
||||||
@Consumes(MediaType.APPLICATION_JSON_PATCH_JSON)
|
@Consumes(MediaType.APPLICATION_JSON_PATCH_JSON)
|
||||||
@ -246,8 +246,8 @@ public class MLModelResource {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@PUT
|
@PUT
|
||||||
@Operation(summary = "Create or update a model", tags = "models",
|
@Operation(summary = "Create or update an ML Model", tags = "mlModels",
|
||||||
description = "Create a new model, if it does not exist or update an existing model.",
|
description = "Create a new ML Model, if it does not exist or update an existing model.",
|
||||||
responses = {
|
responses = {
|
||||||
@ApiResponse(responseCode = "200", description = "The model",
|
@ApiResponse(responseCode = "200", description = "The model",
|
||||||
content = @Content(mediaType = "application/json",
|
content = @Content(mediaType = "application/json",
|
||||||
@ -265,7 +265,7 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@PUT
|
@PUT
|
||||||
@Path("/{id}/followers")
|
@Path("/{id}/followers")
|
||||||
@Operation(summary = "Add a follower", tags = "models",
|
@Operation(summary = "Add a follower", tags = "mlModels",
|
||||||
description = "Add a user identified by `userId` as follower of this model",
|
description = "Add a user identified by `userId` as follower of this model",
|
||||||
responses = {
|
responses = {
|
||||||
@ApiResponse(responseCode = "200", description = "OK"),
|
@ApiResponse(responseCode = "200", description = "OK"),
|
||||||
@ -284,7 +284,7 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@DELETE
|
@DELETE
|
||||||
@Path("/{id}/followers/{userId}")
|
@Path("/{id}/followers/{userId}")
|
||||||
@Operation(summary = "Remove a follower", tags = "models",
|
@Operation(summary = "Remove a follower", tags = "mlModels",
|
||||||
description = "Remove the user identified `userId` as a follower of the model.")
|
description = "Remove the user identified `userId` as a follower of the model.")
|
||||||
public Response deleteFollower(@Context UriInfo uriInfo,
|
public Response deleteFollower(@Context UriInfo uriInfo,
|
||||||
@Context SecurityContext securityContext,
|
@Context SecurityContext securityContext,
|
||||||
@ -300,8 +300,8 @@ public class MLModelResource {
|
|||||||
|
|
||||||
@DELETE
|
@DELETE
|
||||||
@Path("/{id}")
|
@Path("/{id}")
|
||||||
@Operation(summary = "Delete a Model", tags = "models",
|
@Operation(summary = "Delete an ML Model", tags = "mlModels",
|
||||||
description = "Delete a model by `id`.",
|
description = "Delete an ML Model by `id`.",
|
||||||
responses = {
|
responses = {
|
||||||
@ApiResponse(responseCode = "200", description = "OK"),
|
@ApiResponse(responseCode = "200", description = "OK"),
|
||||||
@ApiResponse(responseCode = "404", description = "model for instance {id} is not found")
|
@ApiResponse(responseCode = "404", description = "model for instance {id} is not found")
|
||||||
|
@ -4,9 +4,16 @@ OpenMetadata high-level API Model test
|
|||||||
import uuid
|
import uuid
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from metadata.generated.schema.api.data.createModel import CreateModelEntityRequest
|
from metadata.generated.schema.api.data.createMLModel import CreateMLModelEntityRequest
|
||||||
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
|
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
|
||||||
from metadata.generated.schema.entity.data.model import Model
|
from metadata.generated.schema.entity.data.mlmodel import (
|
||||||
|
FeatureSource,
|
||||||
|
FeatureSourceDataType,
|
||||||
|
FeatureType,
|
||||||
|
MlFeature,
|
||||||
|
MlHyperParameter,
|
||||||
|
MLModel,
|
||||||
|
)
|
||||||
from metadata.generated.schema.type.entityReference import EntityReference
|
from metadata.generated.schema.type.entityReference import EntityReference
|
||||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||||
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
|
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
|
||||||
@ -28,13 +35,13 @@ class OMetaModelTest(TestCase):
|
|||||||
)
|
)
|
||||||
owner = EntityReference(id=user.id, type="user")
|
owner = EntityReference(id=user.id, type="user")
|
||||||
|
|
||||||
entity = Model(
|
entity = MLModel(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
name="test-model",
|
name="test-model",
|
||||||
algorithm="algo",
|
algorithm="algo",
|
||||||
fullyQualifiedName="test-model",
|
fullyQualifiedName="test-model",
|
||||||
)
|
)
|
||||||
create = CreateModelEntityRequest(name="test-model", algorithm="algo")
|
create = CreateMLModelEntityRequest(name="test-model", algorithm="algo")
|
||||||
|
|
||||||
def test_create(self):
|
def test_create(self):
|
||||||
"""
|
"""
|
||||||
@ -56,7 +63,7 @@ class OMetaModelTest(TestCase):
|
|||||||
|
|
||||||
updated = self.create.dict(exclude_unset=True)
|
updated = self.create.dict(exclude_unset=True)
|
||||||
updated["owner"] = self.owner
|
updated["owner"] = self.owner
|
||||||
updated_entity = CreateModelEntityRequest(**updated)
|
updated_entity = CreateMLModelEntityRequest(**updated)
|
||||||
|
|
||||||
res = self.metadata.create_or_update(data=updated_entity)
|
res = self.metadata.create_or_update(data=updated_entity)
|
||||||
|
|
||||||
@ -67,13 +74,13 @@ class OMetaModelTest(TestCase):
|
|||||||
|
|
||||||
# Getting without owner field does not return it by default
|
# Getting without owner field does not return it by default
|
||||||
res_none = self.metadata.get_by_name(
|
res_none = self.metadata.get_by_name(
|
||||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||||
)
|
)
|
||||||
self.assertIsNone(res_none.owner)
|
self.assertIsNone(res_none.owner)
|
||||||
|
|
||||||
# We can request specific fields to be added
|
# We can request specific fields to be added
|
||||||
res_owner = self.metadata.get_by_name(
|
res_owner = self.metadata.get_by_name(
|
||||||
entity=Model,
|
entity=MLModel,
|
||||||
fqdn=self.entity.fullyQualifiedName,
|
fqdn=self.entity.fullyQualifiedName,
|
||||||
fields=["owner", "followers"],
|
fields=["owner", "followers"],
|
||||||
)
|
)
|
||||||
@ -87,7 +94,7 @@ class OMetaModelTest(TestCase):
|
|||||||
self.metadata.create_or_update(data=self.create)
|
self.metadata.create_or_update(data=self.create)
|
||||||
|
|
||||||
res = self.metadata.get_by_name(
|
res = self.metadata.get_by_name(
|
||||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||||
)
|
)
|
||||||
self.assertEqual(res.name, self.entity.name)
|
self.assertEqual(res.name, self.entity.name)
|
||||||
|
|
||||||
@ -100,10 +107,12 @@ class OMetaModelTest(TestCase):
|
|||||||
|
|
||||||
# First pick up by name
|
# First pick up by name
|
||||||
res_name = self.metadata.get_by_name(
|
res_name = self.metadata.get_by_name(
|
||||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||||
)
|
)
|
||||||
# Then fetch by ID
|
# Then fetch by ID
|
||||||
res = self.metadata.get_by_id(entity=Model, entity_id=str(res_name.id.__root__))
|
res = self.metadata.get_by_id(
|
||||||
|
entity=MLModel, entity_id=str(res_name.id.__root__)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(res_name.id, res.id)
|
self.assertEqual(res_name.id, res.id)
|
||||||
|
|
||||||
@ -114,7 +123,7 @@ class OMetaModelTest(TestCase):
|
|||||||
|
|
||||||
self.metadata.create_or_update(data=self.create)
|
self.metadata.create_or_update(data=self.create)
|
||||||
|
|
||||||
res = self.metadata.list_entities(entity=Model)
|
res = self.metadata.list_entities(entity=MLModel)
|
||||||
|
|
||||||
# Fetch our test model. We have already inserted it, so we should find it
|
# Fetch our test model. We have already inserted it, so we should find it
|
||||||
data = next(
|
data = next(
|
||||||
@ -131,18 +140,18 @@ class OMetaModelTest(TestCase):
|
|||||||
|
|
||||||
# Find by name
|
# Find by name
|
||||||
res_name = self.metadata.get_by_name(
|
res_name = self.metadata.get_by_name(
|
||||||
entity=Model, fqdn=self.entity.fullyQualifiedName
|
entity=MLModel, fqdn=self.entity.fullyQualifiedName
|
||||||
)
|
)
|
||||||
# Then fetch by ID
|
# Then fetch by ID
|
||||||
res_id = self.metadata.get_by_id(
|
res_id = self.metadata.get_by_id(
|
||||||
entity=Model, entity_id=str(res_name.id.__root__)
|
entity=MLModel, entity_id=str(res_name.id.__root__)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete
|
# Delete
|
||||||
self.metadata.delete(entity=Model, entity_id=str(res_id.id.__root__))
|
self.metadata.delete(entity=MLModel, entity_id=str(res_id.id.__root__))
|
||||||
|
|
||||||
# Then we should not find it
|
# Then we should not find it
|
||||||
res = self.metadata.list_entities(entity=Model)
|
res = self.metadata.list_entities(entity=MLModel)
|
||||||
|
|
||||||
assert not next(
|
assert not next(
|
||||||
iter(
|
iter(
|
||||||
@ -152,3 +161,52 @@ class OMetaModelTest(TestCase):
|
|||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_model_properties(self):
|
||||||
|
"""
|
||||||
|
Check that we can create models with MLFeatures and MLHyperParams
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = CreateMLModelEntityRequest(
|
||||||
|
name="test-model-properties",
|
||||||
|
algorithm="algo",
|
||||||
|
mlFeatures=[
|
||||||
|
MlFeature(
|
||||||
|
name="age",
|
||||||
|
dataType=FeatureType.numerical,
|
||||||
|
featureSources=[
|
||||||
|
FeatureSource(
|
||||||
|
name="age",
|
||||||
|
dataType=FeatureSourceDataType.integer,
|
||||||
|
fullyQualifiedName="my_service.my_db.my_table.age",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
MlFeature(
|
||||||
|
name="persona",
|
||||||
|
dataType=FeatureType.categorical,
|
||||||
|
featureSources=[
|
||||||
|
FeatureSource(
|
||||||
|
name="age",
|
||||||
|
dataType=FeatureSourceDataType.integer,
|
||||||
|
fullyQualifiedName="my_service.my_db.my_table.age",
|
||||||
|
),
|
||||||
|
FeatureSource(
|
||||||
|
name="education",
|
||||||
|
dataType=FeatureSourceDataType.string,
|
||||||
|
fullyQualifiedName="my_api.education",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
featureAlgorithm="PCA",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
mlHyperParameters=[
|
||||||
|
MlHyperParameter(name="regularisation", value="0.5"),
|
||||||
|
MlHyperParameter(name="random", value="hello"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
res = self.metadata.create_or_update(data=model)
|
||||||
|
|
||||||
|
self.assertIsNotNone(res.mlFeatures)
|
||||||
|
self.assertIsNotNone(res.mlHyperParameters)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user