diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py index ecb77999039..f67dfd8c48d 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py @@ -55,11 +55,14 @@ class OMetaMlModelMixin(OMetaLineageMixin): client: REST - def add_mlmodel_lineage(self, model: MlModel) -> Dict[str, Any]: + def add_mlmodel_lineage( + self, model: MlModel, description: Optional[str] = None + ) -> Dict[str, Any]: """ Iterates over MlModel's Feature Sources and add the lineage information. :param model: MlModel containing EntityReferences + :param description: Lineage description :return: List of added lineage information """ @@ -77,8 +80,8 @@ class OMetaMlModelMixin(OMetaLineageMixin): for entity_ref in refs: self.add_lineage( AddLineageRequest( - description="MlModel uses FeatureSource", edge=EntitiesEdge( + description=description, fromEntity=entity_ref, toEntity=self.get_entity_reference( entity=MlModel, fqn=model.fullyQualifiedName diff --git a/ingestion/tests/integration/ometa/test_ometa_model_api.py b/ingestion/tests/integration/ometa/test_ometa_mlmodel_api.py similarity index 94% rename from ingestion/tests/integration/ometa/test_ometa_model_api.py rename to ingestion/tests/integration/ometa/test_ometa_mlmodel_api.py index 3814807a556..687a9dbd617 100644 --- a/ingestion/tests/integration/ometa/test_ometa_model_api.py +++ b/ingestion/tests/integration/ometa/test_ometa_mlmodel_api.py @@ -62,6 +62,7 @@ from metadata.generated.schema.entity.services.mlmodelService import ( from metadata.generated.schema.security.client.openMetadataJWTClientConfig import ( OpenMetadataJWTClientConfig, ) +from metadata.generated.schema.type.entityLineage import EntitiesEdge from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.ometa.ometa_api import OpenMetadata @@ -373,6 +374,24 @@ class OMetaModelTest(TestCase): nodes = {node["id"] for node in lineage["nodes"]} assert nodes == {str(table1_entity.id.__root__), str(table2_entity.id.__root__)} + # If we delete the lineage, the `add_mlmodel_lineage` will take care of it too + for edge in lineage.get("upstreamEdges") or []: + self.metadata.delete_lineage_edge( + edge=EntitiesEdge( + fromEntity=EntityReference(id=edge["fromEntity"], type="table"), + toEntity=EntityReference(id=edge["toEntity"], type="mlmodel"), + ) + ) + + self.metadata.add_mlmodel_lineage(model=res) + + lineage = self.metadata.get_lineage_by_id( + entity=MlModel, entity_id=str(res.id.__root__) + ) + + nodes = {node["id"] for node in lineage["nodes"]} + assert nodes == {str(table1_entity.id.__root__), str(table2_entity.id.__root__)} + self.metadata.delete( entity=DatabaseService, entity_id=service_entity.id,