diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py
new file mode 100644
index 00000000000..ab18779153f
--- /dev/null
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/mlModelMixin.py
@@ -0,0 +1,64 @@
+"""
+Mixin class containing MlModel specific methods
+
+To be used by OpenMetadata class
+"""
+import logging
+from typing import Any, Dict
+
+from metadata.generated.schema.api.lineage.addLineage import AddLineage
+from metadata.generated.schema.entity.data.mlmodel import MlModel
+from metadata.generated.schema.type.entityLineage import EntitiesEdge
+from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.ingestion.ometa.client import REST
+from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
+
+logger = logging.getLogger(__name__)
+
+
+class OMetaMlModelMixin(OMetaLineageMixin):
+    """
+    OpenMetadata API methods related to MlModel.
+
+    To be inherited by OpenMetadata
+    """
+
+    client: REST
+
+    def add_mlmodel_lineage(self, model: MlModel) -> Dict[str, Any]:
+        """
+        Iterates over MlModel's Feature Sources and
+        adds the lineage information.
+
+        Features without featureSources, and sources without a
+        dataSource, are skipped.
+
+        :param model: MlModel containing EntityReferences
+        :return: Lineage of the MlModel, fetched with get_lineage_by_id
+        """
+
+        # Fetch all informed dataSource values. `or []` has to guard the
+        # iterables themselves: a trailing `if` in a comprehension only runs
+        # once iteration has started, so it cannot protect against
+        # mlFeatures / featureSources being None.
+        refs = [
+            source.dataSource
+            for feature in model.mlFeatures or []
+            for source in feature.featureSources or []
+            if source.dataSource
+        ]
+
+        # Iterate on the references to add lineage
+        for entity_ref in refs:
+            self.add_lineage(
+                AddLineage(
+                    description="MlModel uses FeatureSource",
+                    edge=EntitiesEdge(
+                        fromEntity=EntityReference(id=model.id, type="mlmodel"),
+                        toEntity=entity_ref,
+                    ),
+                )
+            )
+
+        # Read the lineage back once every edge has been pushed
+        return self.get_lineage_by_id(MlModel, str(model.id.__root__))
diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py
index 85caeb10df0..6dc72dc076a 100644
--- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py
+++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py
@@ -39,6 +39,7 @@ from metadata.generated.schema.type.entityHistory import EntityVersionHistory
 from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
 from metadata.ingestion.ometa.client import REST, APIError, ClientConfig
 from metadata.ingestion.ometa.mixins.lineageMixin import OMetaLineageMixin
+from metadata.ingestion.ometa.mixins.mlModelMixin import OMetaMlModelMixin
 from metadata.ingestion.ometa.mixins.tableMixin import OMetaTableMixin
 from metadata.ingestion.ometa.openmetadata_rest import (
     Auth0AuthenticationProvider,
@@ -75,7 +76,7 @@ class EntityList(Generic[T], BaseModel):
     after: str = None
 
 
-class OpenMetadata(OMetaLineageMixin, OMetaTableMixin, Generic[T, C]):
+class OpenMetadata(OMetaMlModelMixin, OMetaTableMixin, Generic[T, C]):
     """
     Generic interface to the OpenMetadata API
 
diff --git a/ingestion/tests/integration/ometa/test_ometa_model_api.py b/ingestion/tests/integration/ometa/test_ometa_model_api.py
index 9d327c3dabb..cf559f38b5b 100644
--- a/ingestion/tests/integration/ometa/test_ometa_model_api.py
+++ b/ingestion/tests/integration/ometa/test_ometa_model_api.py
@@ -15,8 +15,16 @@ OpenMetadata high-level API Model test
 import uuid
 from unittest import TestCase
 
+from metadata.generated.schema.api.data.createDatabase import (
+    CreateDatabaseEntityRequest,
+)
 from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
+from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest
+from metadata.generated.schema.api.services.createDatabaseService import (
+    CreateDatabaseServiceEntityRequest,
+)
 from metadata.generated.schema.api.teams.createUser import CreateUserEntityRequest
+from metadata.generated.schema.entity.data.database import Database
 from metadata.generated.schema.entity.data.mlmodel import (
     FeatureSource,
     FeatureSourceDataType,
@@ -25,7 +33,13 @@ from metadata.generated.schema.entity.data.mlmodel import (
     MlHyperParameter,
     MlModel,
 )
+from metadata.generated.schema.entity.data.table import Column, DataType, Table
+from metadata.generated.schema.entity.services.databaseService import (
+    DatabaseService,
+    DatabaseServiceType,
+)
 from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.generated.schema.type.jdbcConnection import JdbcInfo
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 
@@ -171,13 +185,42 @@ class OMetaModelTest(TestCase):
             None,
         )
 
-    def test_model_properties(self):
+    def test_mlmodel_properties(self):
         """
         Check that we can create models with MLFeatures and MLHyperParams
+
+        We can add lineage information
         """
 
+        service = CreateDatabaseServiceEntityRequest(
+            name="test-service-table",
+            serviceType=DatabaseServiceType.MySQL,
+            jdbc=JdbcInfo(driverClass="jdbc", connectionUrl="jdbc://localhost"),
+        )
+        service_entity = self.metadata.create_or_update(data=service)
+
+        create_db = CreateDatabaseEntityRequest(
+            name="test-db",
+            service=EntityReference(id=service_entity.id, type="databaseService"),
+        )
+        create_db_entity = self.metadata.create_or_update(data=create_db)
+
+        create_table1 = CreateTableEntityRequest(
+            name="test",
+            database=create_db_entity.id,
+            columns=[Column(name="education", dataType=DataType.STRING)],
+        )
+        table1_entity = self.metadata.create_or_update(data=create_table1)
+
+        create_table2 = CreateTableEntityRequest(
+            name="another_test",
+            database=create_db_entity.id,
+            columns=[Column(name="age", dataType=DataType.INT)],
+        )
+        table2_entity = self.metadata.create_or_update(data=create_table2)
+
         model = CreateMlModelEntityRequest(
-            name="test-model-properties",
+            name="test-model-lineage",
             algorithm="algo",
             mlFeatures=[
                 MlFeature(
@@ -187,7 +230,9 @@ class OMetaModelTest(TestCase):
                         FeatureSource(
                             name="age",
                             dataType=FeatureSourceDataType.integer,
-                            fullyQualifiedName="my_service.my_db.my_table.age",
+                            dataSource=EntityReference(
+                                id=table2_entity.id, type="table"
+                            ),
                         )
                     ],
                 ),
@@ -198,12 +243,19 @@ class OMetaModelTest(TestCase):
                         FeatureSource(
                             name="age",
                             dataType=FeatureSourceDataType.integer,
-                            fullyQualifiedName="my_service.my_db.my_table.age",
+                            dataSource=EntityReference(
+                                id=table2_entity.id, type="table"
+                            ),
                         ),
                         FeatureSource(
                             name="education",
                             dataType=FeatureSourceDataType.string,
-                            fullyQualifiedName="my_api.education",
+                            dataSource=EntityReference(
+                                id=table1_entity.id, type="table"
+                            ),
+                        ),
+                        FeatureSource(
+                            name="city", dataType=FeatureSourceDataType.string
                         ),
                     ],
                     featureAlgorithm="PCA",
@@ -219,3 +271,19 @@ class OMetaModelTest(TestCase):
 
         self.assertIsNotNone(res.mlFeatures)
         self.assertIsNotNone(res.mlHyperParameters)
+
+        lineage = self.metadata.add_mlmodel_lineage(model=res)
+
+        # The lineage must reference exactly the two source tables: the
+        # "city" source carries no dataSource and must not add a node.
+        nodes = {node["id"] for node in lineage["nodes"]}
+        self.assertEqual(
+            nodes,
+            {str(table1_entity.id.__root__), str(table2_entity.id.__root__)},
+        )
+
+        # Delete in reverse order of creation: tables, database, service
+        self.metadata.delete(entity=Table, entity_id=table1_entity.id)
+        self.metadata.delete(entity=Table, entity_id=table2_entity.id)
+        self.metadata.delete(entity=Database, entity_id=create_db_entity.id)
+        self.metadata.delete(entity=DatabaseService, entity_id=service_entity.id)