mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-15 12:37:18 +00:00
Prepare MlModel lineage (#1879)
This commit is contained in:
parent
9f48490f2e
commit
d303be847e
@ -0,0 +1,60 @@
|
||||
"""
|
||||
Mixin class containing MlModel specific methods
|
||||
|
||||
To be used by OpenMetadata class
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from metadata.generated.schema.api.lineage.addLineage import AddLineage
|
||||
from metadata.generated.schema.entity.data.mlmodel import MlModel
|
||||
from metadata.generated.schema.type.entityLineage import EntitiesEdge
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.ometa.client import REST
|
||||
from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OMetaMlModelMixin(OMetaLineageMixin):
    """
    OpenMetadata API methods related to MlModel.

    To be inherited by OpenMetadata
    """

    client: REST

    def add_mlmodel_lineage(self, model: MlModel) -> Dict[str, Any]:
        """
        Iterate over the MlModel's feature sources and add a lineage
        edge for every feature source that has a dataSource informed.

        :param model: MlModel containing EntityReferences
        :return: lineage of the model, as returned by get_lineage_by_id
        """

        # Collect every informed dataSource reference.
        # The `or ()` fallbacks let a missing (None) `mlFeatures` or
        # `featureSources` contribute nothing. A trailing
        # `if model.mlFeatures` filter inside the comprehension cannot do
        # that: it is only evaluated after iteration over the value has
        # started, so a None value would already have raised TypeError.
        refs = [
            source.dataSource
            for feature in model.mlFeatures or ()
            for source in feature.featureSources or ()
            if source.dataSource
        ]

        # Register one lineage edge per referenced data source.
        # NOTE(review): the edge goes from the model to the data source;
        # confirm this is the intended direction for "model uses source".
        for entity_ref in refs:
            self.add_lineage(
                AddLineage(
                    description="MlModel uses FeatureSource",
                    edge=EntitiesEdge(
                        fromEntity=EntityReference(id=model.id, type="mlmodel"),
                        toEntity=entity_ref,
                    ),
                )
            )

        # Return the lineage fetched for the model after the edges are added.
        return self.get_lineage_by_id(MlModel, str(model.id.__root__))
|
@ -39,6 +39,7 @@ from metadata.generated.schema.type.entityHistory import EntityVersionHistory
|
||||
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
|
||||
from metadata.ingestion.ometa.client import REST, APIError, ClientConfig
|
||||
from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
|
||||
from metadata.ingestion.ometa.mixins.mlModelMixin import OMetaMlModelMixin
|
||||
from metadata.ingestion.ometa.mixins.tableMixin import OMetaTableMixin
|
||||
from metadata.ingestion.ometa.openmetadata_rest import (
|
||||
Auth0AuthenticationProvider,
|
||||
@ -75,7 +76,7 @@ class EntityList(Generic[T], BaseModel):
|
||||
after: str = None
|
||||
|
||||
|
||||
class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]):
|
||||
class OpenMetadata(OMetaMlModelMixin, OMetaTableMixin, Generic[T, C]):
|
||||
"""
|
||||
Generic interface to the OpenMetadata API
|
||||
|
||||
|
@ -15,8 +15,16 @@ OpenMetadata high-level API Model test
|
||||
import uuid
|
||||
from unittest import TestCase
|
||||
|
||||
from metadata.generated.schema.api.data.createDatabase import (
|
||||
CreateDatabaseEntityRequest,
|
||||
)
|
||||
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
|
||||
from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest
|
||||
from metadata.generated.schema.api.services.createDatabaseService import (
|
||||
CreateDatabaseServiceEntityRequest,
|
||||
)
|
||||
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.data.mlmodel import (
|
||||
FeatureSource,
|
||||
FeatureSourceDataType,
|
||||
@ -25,7 +33,13 @@ from metadata.generated.schema.entity.data.mlmodel import (
|
||||
MlHyperParameter,
|
||||
MlModel,
|
||||
)
|
||||
from metadata.generated.schema.entity.data.table import Column, DataType, Table
|
||||
from metadata.generated.schema.entity.services.databaseService import (
|
||||
DatabaseService,
|
||||
DatabaseServiceType,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.generated.schema.type.jdbcConnection import JdbcInfo
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
|
||||
|
||||
@ -171,13 +185,42 @@ class OMetaModelTest(TestCase):
|
||||
None,
|
||||
)
|
||||
|
||||
def test_model_properties(self):
|
||||
def test_mlmodel_properties(self):
|
||||
"""
|
||||
Check that we can create models with MLFeatures and MLHyperParams
|
||||
|
||||
We can add lineage information
|
||||
"""
|
||||
|
||||
service = CreateDatabaseServiceEntityRequest(
|
||||
name="test-service-table",
|
||||
serviceType=DatabaseServiceType.MySQL,
|
||||
jdbc=JdbcInfo(driverClass="jdbc", connectionUrl="jdbc://localhost"),
|
||||
)
|
||||
service_entity = self.metadata.create_or_update(data=service)
|
||||
|
||||
create_db = CreateDatabaseEntityRequest(
|
||||
name="test-db",
|
||||
service=EntityReference(id=service_entity.id, type="databaseService"),
|
||||
)
|
||||
create_db_entity = self.metadata.create_or_update(data=create_db)
|
||||
|
||||
create_table1 = CreateTableEntityRequest(
|
||||
name="test",
|
||||
database=create_db_entity.id,
|
||||
columns=[Column(name="education", dataType=DataType.STRING)],
|
||||
)
|
||||
table1_entity = self.metadata.create_or_update(data=create_table1)
|
||||
|
||||
create_table2 = CreateTableEntityRequest(
|
||||
name="another_test",
|
||||
database=create_db_entity.id,
|
||||
columns=[Column(name="age", dataType=DataType.INT)],
|
||||
)
|
||||
table2_entity = self.metadata.create_or_update(data=create_table2)
|
||||
|
||||
model = CreateMlModelEntityRequest(
|
||||
name="test-model-properties",
|
||||
name="test-model-lineage",
|
||||
algorithm="algo",
|
||||
mlFeatures=[
|
||||
MlFeature(
|
||||
@ -187,7 +230,9 @@ class OMetaModelTest(TestCase):
|
||||
FeatureSource(
|
||||
name="age",
|
||||
dataType=FeatureSourceDataType.integer,
|
||||
fullyQualifiedName="my_service.my_db.my_table.age",
|
||||
dataSource=EntityReference(
|
||||
id=table2_entity.id, type="table"
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
@ -198,12 +243,19 @@ class OMetaModelTest(TestCase):
|
||||
FeatureSource(
|
||||
name="age",
|
||||
dataType=FeatureSourceDataType.integer,
|
||||
fullyQualifiedName="my_service.my_db.my_table.age",
|
||||
dataSource=EntityReference(
|
||||
id=table2_entity.id, type="table"
|
||||
),
|
||||
),
|
||||
FeatureSource(
|
||||
name="education",
|
||||
dataType=FeatureSourceDataType.string,
|
||||
fullyQualifiedName="my_api.education",
|
||||
dataSource=EntityReference(
|
||||
id=table1_entity.id, type="table"
|
||||
),
|
||||
),
|
||||
FeatureSource(
|
||||
name="city", dataType=FeatureSourceDataType.string
|
||||
),
|
||||
],
|
||||
featureAlgorithm="PCA",
|
||||
@ -219,3 +271,13 @@ class OMetaModelTest(TestCase):
|
||||
|
||||
self.assertIsNotNone(res.mlFeatures)
|
||||
self.assertIsNotNone(res.mlHyperParameters)
|
||||
|
||||
lineage = self.metadata.add_mlmodel_lineage(model=res)
|
||||
|
||||
nodes = {node["id"] for node in lineage["nodes"]}
|
||||
assert nodes == {str(table1_entity.id.__root__), str(table2_entity.id.__root__)}
|
||||
|
||||
self.metadata.delete(entity=Table, entity_id=table1_entity.id)
|
||||
self.metadata.delete(entity=Table, entity_id=table2_entity.id)
|
||||
self.metadata.delete(entity=Database, entity_id=create_db_entity.id)
|
||||
self.metadata.delete(entity=DatabaseService, entity_id=service_entity.id)
|
||||
|
Loading…
x
Reference in New Issue
Block a user