From d303be847ea3f15b6455e0744e8b8d60181b65d1 Mon Sep 17 00:00:00 2001
From: Pere Miquel Brull
Date: Wed, 22 Dec 2021 02:44:29 +0100
Subject: [PATCH] Prepare MlModel lineage (#1879)
---
.../ingestion/ometa/mixins/mlModelMixin.py | 60 ++++++++++++++++
.../src/metadata/ingestion/ometa/ometa_api.py | 3 +-
.../integration/ometa/test_ometa_model_api.py | 72 +++++++++++++++++--
3 files changed, 129 insertions(+), 6 deletions(-)
create mode 100644 ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py
diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py
new file mode 100644
index 00000000000..ab18779153f
--- /dev/null
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py
@@ -0,0 +1,60 @@
+"""
+Mixin class containing MlModel specific methods
+
+To be used by OpenMetadata class
+"""
+import logging
+from typing import Any, Dict
+
+from metadata.generated.schema.api.lineage.addLineage import AddLineage
+from metadata.generated.schema.entity.data.mlmodel import MlModel
+from metadata.generated.schema.type.entityLineage import EntitiesEdge
+from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.ingestion.ometa.client import REST
+from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
+
+logger = logging.getLogger(__name__)
+
+
+class OMetaMlModelMixin(OMetaLineageMixin):
+    """
+    OpenMetadata API methods related to MlModel.
+
+    To be inherited by OpenMetadata
+    """
+
+    client: REST
+
+    def add_mlmodel_lineage(self, model: MlModel) -> Dict[str, Any]:
+        """
+        Add one lineage edge per informed dataSource in the MlModel's
+        Feature Sources, then fetch the model's lineage.
+        :param model: MlModel containing EntityReferences
+        :return: Lineage of the model, as returned by get_lineage_by_id
+        """
+
+        # `or []`: a guard placed after a `for` clause runs too late to
+        # protect it, so None mlFeatures/featureSources raised TypeError.
+        refs = [
+            source.dataSource
+            for feature in model.mlFeatures or []
+            for source in feature.featureSources or []
+            if source.dataSource
+        ]
+
+        # NOTE(review): edges go from the model to each data source;
+        # confirm this is the intended lineage direction.
+        for entity_ref in refs:
+            self.add_lineage(
+                AddLineage(
+                    description="MlModel uses FeatureSource",
+                    edge=EntitiesEdge(
+                        fromEntity=EntityReference(id=model.id, type="mlmodel"),
+                        toEntity=entity_ref,
+                    ),
+                )
+            )
+
+        mlmodel_lineage = self.get_lineage_by_id(MlModel, str(model.id.__root__))
+
+        return mlmodel_lineage
diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py
index 85caeb10df0..6dc72dc076a 100644
--- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py
+++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py
@@ -39,6 +39,7 @@ from metadata.generated.schema.type.entityHistory import EntityVersionHistory
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
from metadata.ingestion.ometa.client import REST, APIError, ClientConfig
from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
+from metadata.ingestion.ometa.mixins.mlModelMixin import OMetaMlModelMixin
from metadata.ingestion.ometa.mixins.tableMixin import OMetaTableMixin
from metadata.ingestion.ometa.openmetadata_rest import (
Auth0AuthenticationProvider,
@@ -75,7 +76,7 @@ class EntityList(Generic[T], BaseModel):
after: str = None
-class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]):
+class OpenMetadata(OMetaMlModelMixin, OMetaTableMixin, Generic[T, C]):
"""
Generic interface to the OpenMetadata API
diff --git a/ingestion/tests/integration/ometa/test_ometa_model_api.py b/ingestion/tests/integration/ometa/test_ometa_model_api.py
index 9d327c3dabb..cf559f38b5b 100644
--- a/ingestion/tests/integration/ometa/test_ometa_model_api.py
+++ b/ingestion/tests/integration/ometa/test_ometa_model_api.py
@@ -15,8 +15,16 @@ OpenMetadata high-level API Model test
import uuid
from unittest import TestCase
+from metadata.generated.schema.api.data.createDatabase import (
+ CreateDatabaseEntityRequest,
+)
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
+from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest
+from metadata.generated.schema.api.services.createDatabaseService import (
+ CreateDatabaseServiceEntityRequest,
+)
from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
+from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.mlmodel import (
FeatureSource,
FeatureSourceDataType,
@@ -25,7 +33,13 @@ from metadata.generated.schema.entity.data.mlmodel import (
MlHyperParameter,
MlModel,
)
+from metadata.generated.schema.entity.data.table import Column, DataType, Table
+from metadata.generated.schema.entity.services.databaseService import (
+ DatabaseService,
+ DatabaseServiceType,
+)
from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.generated.schema.type.jdbcConnection import JdbcInfo
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
@@ -171,13 +185,42 @@ class OMetaModelTest(TestCase):
None,
)
- def test_model_properties(self):
+ def test_mlmodel_properties(self):
"""
Check that we can create models with MLFeatures and MLHyperParams
+
+ We can add lineage information
"""
+ service = CreateDatabaseServiceEntityRequest(
+ name="test-service-table",
+ serviceType=DatabaseServiceType.MySQL,
+ jdbc=JdbcInfo(driverClass="jdbc", connectionUrl="jdbc://localhost"),
+ )
+ service_entity = self.metadata.create_or_update(data=service)
+
+ create_db = CreateDatabaseEntityRequest(
+ name="test-db",
+ service=EntityReference(id=service_entity.id, type="databaseService"),
+ )
+ create_db_entity = self.metadata.create_or_update(data=create_db)
+
+ create_table1 = CreateTableEntityRequest(
+ name="test",
+ database=create_db_entity.id,
+ columns=[Column(name="education", dataType=DataType.STRING)],
+ )
+ table1_entity = self.metadata.create_or_update(data=create_table1)
+
+ create_table2 = CreateTableEntityRequest(
+ name="another_test",
+ database=create_db_entity.id,
+ columns=[Column(name="age", dataType=DataType.INT)],
+ )
+ table2_entity = self.metadata.create_or_update(data=create_table2)
+
model = CreateMlModelEntityRequest(
- name="test-model-properties",
+ name="test-model-lineage",
algorithm="algo",
mlFeatures=[
MlFeature(
@@ -187,7 +230,9 @@ class OMetaModelTest(TestCase):
FeatureSource(
name="age",
dataType=FeatureSourceDataType.integer,
- fullyQualifiedName="my_service.my_db.my_table.age",
+ dataSource=EntityReference(
+ id=table2_entity.id, type="table"
+ ),
)
],
),
@@ -198,12 +243,19 @@ class OMetaModelTest(TestCase):
FeatureSource(
name="age",
dataType=FeatureSourceDataType.integer,
- fullyQualifiedName="my_service.my_db.my_table.age",
+ dataSource=EntityReference(
+ id=table2_entity.id, type="table"
+ ),
),
FeatureSource(
name="education",
dataType=FeatureSourceDataType.string,
- fullyQualifiedName="my_api.education",
+ dataSource=EntityReference(
+ id=table1_entity.id, type="table"
+ ),
+ ),
+ FeatureSource(
+ name="city", dataType=FeatureSourceDataType.string
),
],
featureAlgorithm="PCA",
@@ -219,3 +271,13 @@ class OMetaModelTest(TestCase):
self.assertIsNotNone(res.mlFeatures)
self.assertIsNotNone(res.mlHyperParameters)
+
+ lineage = self.metadata.add_mlmodel_lineage(model=res)
+
+ nodes = {node["id"] for node in lineage["nodes"]}
+ assert nodes == {str(table1_entity.id.__root__), str(table2_entity.id.__root__)}
+
+ self.metadata.delete(entity=Table, entity_id=table1_entity.id)
+ self.metadata.delete(entity=Table, entity_id=table2_entity.id)
+ self.metadata.delete(entity=Database, entity_id=create_db_entity.id)
+ self.metadata.delete(entity=DatabaseService, entity_id=service_entity.id)