Prepare MlModel lineage (#1879)

This commit is contained in:
Pere Miquel Brull 2021-12-22 02:44:29 +01:00 committed by GitHub
parent 9f48490f2e
commit d303be847e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 129 additions and 6 deletions

View File

@ -0,0 +1,60 @@
"""
Mixin class containing Lineage specific methods
To be used by OpenMetadata class
"""
import logging
from typing import Any, Dict
from metadata.generated.schema.api.lineage.addLineage import AddLineage
from metadata.generated.schema.entity.data.mlmodel import MlModel
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
logger = logging.getLogger(__name__)
class OMetaMlModelMixin(OMetaLineageMixin):
"""
OpenMetadata API methods related to MlModel.
To be inherited by OpenMetadata
"""
client: REST
def add_mlmodel_lineage(self, model: MlModel) -> Dict[str, Any]:
"""
Iterates over MlModel's Feature Sources and
add the lineage information.
:param model: MlModel containing EntityReferences
:return: List of added lineage information
"""
# Fetch all informed dataSource values
refs = [
source.dataSource
for feature in model.mlFeatures
if model.mlFeatures
for source in feature.featureSources
if feature.featureSources
if source.dataSource
]
# Iterate on the references to add lineage
for entity_ref in refs:
self.add_lineage(
AddLineage(
description="MlModel uses FeatureSource",
edge=EntitiesEdge(
fromEntity=EntityReference(id=model.id, type="mlmodel"),
toEntity=entity_ref,
),
)
)
mlmodel_lineage = self.get_lineage_by_id(MlModel, str(model.id.__root__))
return mlmodel_lineage

View File

@ -39,6 +39,7 @@ from metadata.generated.schema.type.entityHistory import EntityVersionHistory
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
from metadata.ingestion.ometa.client import REST, APIError, ClientConfig
from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
from metadata.ingestion.ometa.mixins.mlModelMixin import OMetaMlModelMixin
from metadata.ingestion.ometa.mixins.tableMixin import OMetaTableMixin
from metadata.ingestion.ometa.openmetadata_rest import (
Auth0AuthenticationProvider,
@ -75,7 +76,7 @@ class EntityList(Generic[T], BaseModel):
after: str = None
class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]):
class OpenMetadata(OMetaMlModelMixin, OMetaTableMixin, Generic[T, C]):
"""
Generic interface to the OpenMetadata API

View File

@ -15,8 +15,16 @@ OpenMetadata high-level API Model test
import uuid
from unittest import TestCase
from metadata.generated.schema.api.data.createDatabase import (
CreateDatabaseEntityRequest,
)
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest
from metadata.generated.schema.api.services.createDatabaseService import (
CreateDatabaseServiceEntityRequest,
)
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.mlmodel import (
FeatureSource,
FeatureSourceDataType,
@ -25,7 +33,13 @@ from metadata.generated.schema.entity.data.mlmodel import (
MlHyperParameter,
MlModel,
)
from metadata.generated.schema.entity.data.table import Column, DataType, Table
from metadata.generated.schema.entity.services.databaseService import (
DatabaseService,
DatabaseServiceType,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.generated.schema.type.jdbcConnection import JdbcInfo
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
@ -171,13 +185,42 @@ class OMetaModelTest(TestCase):
None,
)
def test_model_properties(self):
def test_mlmodel_properties(self):
"""
Check that we can create models with MLFeatures and MLHyperParams
We can add lineage information
"""
service = CreateDatabaseServiceEntityRequest(
name="test-service-table",
serviceType=DatabaseServiceType.MySQL,
jdbc=JdbcInfo(driverClass="jdbc", connectionUrl="jdbc://localhost"),
)
service_entity = self.metadata.create_or_update(data=service)
create_db = CreateDatabaseEntityRequest(
name="test-db",
service=EntityReference(id=service_entity.id, type="databaseService"),
)
create_db_entity = self.metadata.create_or_update(data=create_db)
create_table1 = CreateTableEntityRequest(
name="test",
database=create_db_entity.id,
columns=[Column(name="education", dataType=DataType.STRING)],
)
table1_entity = self.metadata.create_or_update(data=create_table1)
create_table2 = CreateTableEntityRequest(
name="another_test",
database=create_db_entity.id,
columns=[Column(name="age", dataType=DataType.INT)],
)
table2_entity = self.metadata.create_or_update(data=create_table2)
model = CreateMlModelEntityRequest(
name="test-model-properties",
name="test-model-lineage",
algorithm="algo",
mlFeatures=[
MlFeature(
@ -187,7 +230,9 @@ class OMetaModelTest(TestCase):
FeatureSource(
name="age",
dataType=FeatureSourceDataType.integer,
fullyQualifiedName="my_service.my_db.my_table.age",
dataSource=EntityReference(
id=table2_entity.id, type="table"
),
)
],
),
@ -198,12 +243,19 @@ class OMetaModelTest(TestCase):
FeatureSource(
name="age",
dataType=FeatureSourceDataType.integer,
fullyQualifiedName="my_service.my_db.my_table.age",
dataSource=EntityReference(
id=table2_entity.id, type="table"
),
),
FeatureSource(
name="education",
dataType=FeatureSourceDataType.string,
fullyQualifiedName="my_api.education",
dataSource=EntityReference(
id=table1_entity.id, type="table"
),
),
FeatureSource(
name="city", dataType=FeatureSourceDataType.string
),
],
featureAlgorithm="PCA",
@ -219,3 +271,13 @@ class OMetaModelTest(TestCase):
self.assertIsNotNone(res.mlFeatures)
self.assertIsNotNone(res.mlHyperParameters)
lineage = self.metadata.add_mlmodel_lineage(model=res)
nodes = {node["id"] for node in lineage["nodes"]}
assert nodes == {str(table1_entity.id.__root__), str(table2_entity.id.__root__)}
self.metadata.delete(entity=Table, entity_id=table1_entity.id)
self.metadata.delete(entity=Table, entity_id=table2_entity.id)
self.metadata.delete(entity=Database, entity_id=create_db_entity.id)
self.metadata.delete(entity=DatabaseService, entity_id=service_entity.id)