Update mlmodel sample data (#5060)

This commit is contained in:
Pere Miquel Brull 2022-05-20 08:45:26 +02:00 committed by GitHub
parent fbf9b8609a
commit 9effaa8037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 1 deletion

View File

@ -4,7 +4,55 @@
"displayName": "ETA Predictions",
"description": "ETA Predictions Model",
"algorithm": "Neural Network",
"dashboard": "sample_superset.eta_predictions_performance"
"dashboard": "sample_superset.eta_predictions_performance",
"mlStore": {
"storage": "s3://path-to-pickle",
"imageRepository": "https://docker.hub.com/image"
},
"server": "http://my-server.ai",
"mlFeatures": [
{
"name": "sales",
"dataType": "numerical",
"description": "Sales amount",
"featureSources": [
{
"name": "gross_sales",
"dataType": "integer",
"dataSource": "sample_data.ecommerce_db.shopify.fact_sale"
}
]
},
{
"name": "persona",
"dataType": "categorical",
"description": "type of buyer",
"featureAlgorithm": "PCA",
"featureSources": [
{
"name": "membership",
"dataType": "string",
"dataSource": "sample_data.ecommerce_db.shopify.raw_customer"
},
{
"name": "platform",
"dataType": "string",
"dataSource": "sample_data.ecommerce_db.shopify.raw_customer"
}
]
}
],
"mlHyperParameters": [
{
"name": "regularisation",
"value": "0.5"
},
{
"name": "random",
"value": "hello"
}
],
"target": "ETA_time"
},
{
"name": "forecast_sales",

View File

@ -32,6 +32,12 @@ from metadata.generated.schema.api.tests.createTableTest import CreateTableTestR
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.location import Location, LocationType
from metadata.generated.schema.entity.data.mlmodel import (
FeatureSource,
MlFeature,
MlHyperParameter,
MlStore,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline, PipelineStatus
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.data.topic import Topic
@ -545,6 +551,36 @@ class SampleDataSource(Source[Entity]):
pipeline_status=PipelineStatus(**status),
)
def get_ml_feature_sources(self, feature: dict) -> List[FeatureSource]:
    """
    Build FeatureSources from sample data

    :param feature: sample feature dict; its optional "featureSources"
        key holds a list of dicts with "name", "dataType" and
        "dataSource" entries
    :return: one FeatureSource per entry, or an empty list when the key
        is missing, null or empty
    """
    # `.get(key, [])` only covers a missing key; `or []` also covers an
    # explicit null value, which would otherwise fail on iteration.
    raw_sources = feature.get("featureSources") or []
    return [
        FeatureSource(
            name=source["name"],
            dataType=source["dataType"],
            # The "dataSource" string is used as the fqdn of the Table
            # whose entity reference is requested from self.metadata.
            dataSource=self.metadata.get_entity_reference(
                entity=Table, fqdn=source["dataSource"]
            ),
        )
        for source in raw_sources
    ]
def get_ml_features(self, model: dict) -> List[MlFeature]:
    """
    Build MlFeatures from sample data

    :param model: sample model dict; its optional "mlFeatures" key holds
        a list of feature dicts with required "name" and "dataType" and
        optional "description", "featureAlgorithm" and "featureSources"
    :return: one MlFeature per entry, or an empty list when the key is
        missing, null or empty
    """
    # `.get(key, [])` only covers a missing key; `or []` also covers an
    # explicit null value, which would otherwise fail on iteration.
    raw_features = model.get("mlFeatures") or []
    return [
        MlFeature(
            name=feature["name"],
            dataType=feature["dataType"],
            # Optional fields: absent keys are forwarded as None.
            description=feature.get("description"),
            featureAlgorithm=feature.get("featureAlgorithm"),
            featureSources=self.get_ml_feature_sources(feature),
        )
        for feature in raw_features
    ]
def ingest_mlmodels(self) -> Iterable[CreateMlModelRequest]:
"""
Convert sample model data into a Model Entity
@ -571,6 +607,19 @@ class SampleDataSource(Source[Entity]):
description=model["description"],
algorithm=model["algorithm"],
dashboard=EntityReference(id=dashboard_id, type="dashboard"),
mlStore=MlStore(
storage=model["mlStore"]["storage"],
imageRepository=model["mlStore"]["imageRepository"],
)
if model.get("mlStore")
else None,
server=model.get("server"),
target=model.get("target"),
mlFeatures=self.get_ml_features(model),
mlHyperParameters=[
MlHyperParameter(name=param["name"], value=param["value"])
for param in model.get("mlHyperParameters", [])
],
)
yield model_ev
except Exception as err: