diff --git a/ingestion/examples/sample_data/models/models.json b/ingestion/examples/sample_data/models/models.json index e7d4e1c0b25..06c42db8f38 100644 --- a/ingestion/examples/sample_data/models/models.json +++ b/ingestion/examples/sample_data/models/models.json @@ -4,7 +4,55 @@ "displayName": "ETA Predictions", "description": "ETA Predictions Model", "algorithm": "Neural Network", - "dashboard": "sample_superset.eta_predictions_performance" + "dashboard": "sample_superset.eta_predictions_performance", + "mlStore": { + "storage": "s3://path-to-pickle", + "imageRepository": "https://docker.hub.com/image" + }, + "server": "http://my-server.ai", + "mlFeatures": [ + { + "name": "sales", + "dataType": "numerical", + "description": "Sales amount", + "featureSources": [ + { + "name": "gross_sales", + "dataType": "integer", + "dataSource": "sample_data.ecommerce_db.shopify.fact_sale" + } + ] + }, + { + "name": "persona", + "dataType": "categorical", + "description": "type of buyer", + "featureAlgorithm": "PCA", + "featureSources": [ + { + "name": "membership", + "dataType": "string", + "dataSource": "sample_data.ecommerce_db.shopify.raw_customer" + }, + { + "name": "platform", + "dataType": "string", + "dataSource": "sample_data.ecommerce_db.shopify.raw_customer" + } + ] + } + ], + "mlHyperParameters": [ + { + "name": "regularisation", + "value": "0.5" + }, + { + "name": "random", + "value": "hello" + } + ], + "target": "ETA_time" }, { "name": "forecast_sales", diff --git a/ingestion/src/metadata/ingestion/source/sample_data.py b/ingestion/src/metadata/ingestion/source/sample_data.py index d9238cc3e2f..472b2373d15 100644 --- a/ingestion/src/metadata/ingestion/source/sample_data.py +++ b/ingestion/src/metadata/ingestion/source/sample_data.py @@ -32,6 +32,12 @@ from metadata.generated.schema.api.tests.createTableTest import CreateTableTestR from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema from metadata.generated.schema.entity.data.location import Location, LocationType +from metadata.generated.schema.entity.data.mlmodel import ( + FeatureSource, + MlFeature, + MlHyperParameter, + MlStore, +) from metadata.generated.schema.entity.data.pipeline import Pipeline, PipelineStatus from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.data.topic import Topic @@ -545,6 +551,36 @@ class SampleDataSource(Source[Entity]): pipeline_status=PipelineStatus(**status), ) + def get_ml_feature_sources(self, feature: dict) -> List[FeatureSource]: + """ + Build FeatureSources from sample data + """ + return [ + FeatureSource( + name=source["name"], + dataType=source["dataType"], + dataSource=self.metadata.get_entity_reference( + entity=Table, fqdn=source["dataSource"] + ), + ) + for source in feature.get("featureSources", []) + ] + + def get_ml_features(self, model: dict) -> List[MlFeature]: + """ + Build MlFeatures from sample data + """ + return [ + MlFeature( + name=feature["name"], + dataType=feature["dataType"], + description=feature.get("description"), + featureAlgorithm=feature.get("featureAlgorithm"), + featureSources=self.get_ml_feature_sources(feature), + ) + for feature in model.get("mlFeatures", []) + ] + def ingest_mlmodels(self) -> Iterable[CreateMlModelRequest]: """ Convert sample model data into a Model Entity @@ -571,6 +607,19 @@ class SampleDataSource(Source[Entity]): description=model["description"], algorithm=model["algorithm"], dashboard=EntityReference(id=dashboard_id, type="dashboard"), + mlStore=MlStore( + storage=model["mlStore"]["storage"], + imageRepository=model["mlStore"]["imageRepository"], + ) + if model.get("mlStore") + else None, + server=model.get("server"), + target=model.get("target"), + mlFeatures=self.get_ml_features(model), + mlHyperParameters=[ + MlHyperParameter(name=param["name"], value=param["value"]) + for param in model.get("mlHyperParameters", []) + ], ) yield model_ev except Exception as err: