Fix #6005 - Create ML Model Topology (#6080)

Fix #6005 - Create ML Model Topology  (#6080)
This commit is contained in:
Pere Miquel Brull 2022-07-14 15:07:53 +02:00 committed by GitHub
parent 479a8de486
commit 0a921abf8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 273 additions and 92 deletions

View File

@ -16,6 +16,10 @@
"description": "Pipeline type",
"$ref": "#/definitions/mlModelMetadataConfigType",
"default": "MlModelMetadata"
},
"mlModelFilterPattern": {
"description": "Regex to only fetch MlModels with names matching the pattern.",
"$ref": "../type/filterPattern.json#/definitions/filterPattern"
}
},
"additionalProperties": false

View File

@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for ingesting database services
Base class for ingesting dashboard services
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass

View File

@ -13,12 +13,10 @@
import ast
import json
import traceback
from dataclasses import dataclass, field
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Tuple, cast
from mlflow.entities import RunData
from mlflow.entities.model_registry import ModelVersion
from mlflow.tracking import MlflowClient
from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest
from metadata.generated.schema.entity.data.mlmodel import (
@ -33,52 +31,19 @@ from metadata.generated.schema.entity.services.connections.metadata.openMetadata
from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import (
MlflowConnection,
)
from metadata.generated.schema.entity.services.mlmodelService import MlModelService
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.connections import get_connection
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.mlmodel.mlmodel_service import MlModelServiceSource
from metadata.utils.filters import filter_by_mlmodel
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
@dataclass
class MlFlowStatus(SourceStatus):
    """
    Ingestion status tracker specific to the MLFlow source.

    Accumulates the names of models that were scanned successfully,
    failed, or produced warnings, and mirrors each event to the logger.
    """

    # Names of models processed in each category; default_factory avoids
    # sharing one list across instances.
    success: List[str] = field(default_factory=list)
    failures: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def scanned(self, record: str) -> None:
        """
        Log successful ML Model scans

        :param record: model name that was ingested successfully
        """
        self.success.append(record)
        logger.info("ML Model scanned: %s", record)

    def failed(self, model_name: str, reason: str) -> None:
        """
        Log failed ML Model scans

        :param model_name: model that could not be ingested
        :param reason: human-readable failure reason (logged, not stored)
        """
        self.failures.append(model_name)
        logger.error("ML Model failed: %s - %s", model_name, reason)

    def warned(self, model_name: str, reason: str) -> None:
        """
        Log Ml Model with warnings

        :param model_name: model that raised a warning
        :param reason: human-readable warning reason (logged, not stored)
        """
        self.warnings.append(model_name)
        logger.warning("ML Model warning: %s - %s", model_name, reason)
class MlflowSource(Source[CreateMlModelRequest]):
class MlflowSource(MlModelServiceSource):
"""
Source implementation to ingest MLFlow data.
@ -86,25 +51,6 @@ class MlflowSource(Source[CreateMlModelRequest]):
and prepare an iterator of CreateMlModelRequest
"""
def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection):
super().__init__()
self.config = config
self.service_connection = self.config.serviceConnection.__root__.config
self.metadata = OpenMetadata(metadata_config)
self.connection = get_connection(self.service_connection)
self.test_connection()
self.client = self.connection.client
self.status = MlFlowStatus()
self.service = self.metadata.get_service_or_create(
entity=MlModelService, config=config
)
def prepare(self):
pass
@classmethod
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
@ -116,15 +62,20 @@ class MlflowSource(Source[CreateMlModelRequest]):
return cls(config, metadata_config)
def next_record(self) -> Iterable[CreateMlModelRequest]:
def get_mlmodels(self) -> Iterable[Tuple[RegisteredModel, ModelVersion]]:
"""
Fetch all registered models from MlFlow.
List and filters models from the registry
"""
for model in cast(RegisteredModel, self.client.list_registered_models()):
We are setting the `algorithm` to a constant
as there is not a non-trivial generic approach
for retrieving the algorithm from the registry.
"""
for model in self.client.list_registered_models():
if filter_by_mlmodel(
self.source_config.mlModelFilterPattern, mlmodel_name=model.name
):
self.status.filter(
f"{self.config.serviceName}.{model.name}",
"MlModel name pattern not allowed",
)
continue
# Get the latest version
latest_version: Optional[ModelVersion] = next(
@ -139,21 +90,35 @@ class MlflowSource(Source[CreateMlModelRequest]):
self.status.failed(model.name, reason="Invalid version")
continue
run = self.client.get_run(latest_version.run_id)
yield model, latest_version
self.status.scanned(model.name)
def _get_algorithm(self) -> str:
    """
    Return the algorithm label for MLFlow models.

    MLFlow's registry does not expose the algorithm in a generic way,
    so a constant is returned for every model.
    """
    return "mlmodel"
yield CreateMlModelRequest(
name=model.name,
description=model.description,
algorithm="mlflow", # Setting this to a constant
mlHyperParameters=self._get_hyper_params(run.data),
mlFeatures=self._get_ml_features(
run.data, latest_version.run_id, model.name
),
mlStore=self._get_ml_store(latest_version),
service=EntityReference(id=self.service.id, type="mlmodelService"),
)
def yield_mlmodel(
    self, model_and_version: Tuple[RegisteredModel, ModelVersion]
) -> Iterable[CreateMlModelRequest]:
    """
    Build the CreateMlModelRequest for one registered model.

    :param model_and_version: pair of the registered model and its
        latest version, as produced by ``get_mlmodels``
    :return: iterable yielding a single CreateMlModelRequest
    """
    model, latest_version = model_and_version
    self.status.scanned(model.name)
    # Fetch the run backing this version to extract params and features
    run = self.client.get_run(latest_version.run_id)
    yield CreateMlModelRequest(
        name=model.name,
        description=model.description,
        algorithm=self._get_algorithm(),  # Setting this to a constant
        mlHyperParameters=self._get_hyper_params(run.data),
        mlFeatures=self._get_ml_features(
            run.data, latest_version.run_id, model.name
        ),
        mlStore=self._get_ml_store(latest_version),
        # Service reference comes from the topology context populated
        # by yield_mlmodel_service
        service=EntityReference(
            id=self.context.mlmodel_service.id, type="mlmodelService"
        ),
    )
@staticmethod
def _get_hyper_params(data: RunData) -> Optional[List[MlHyperParameter]]:
@ -222,14 +187,3 @@ class MlflowSource(Source[CreateMlModelRequest]):
self.status.warned(model_name, reason)
return None
def get_status(self) -> SourceStatus:
return self.status
def close(self) -> None:
"""
Don't need to close the client
"""
def test_connection(self) -> None:
pass

View File

@ -0,0 +1,208 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for ingesting mlmodel services
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional
from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest
from metadata.generated.schema.entity.data.mlmodel import (
MlFeature,
MlHyperParameter,
MlModel,
MlStore,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.mlmodelService import (
MlModelConnection,
MlModelService,
)
from metadata.generated.schema.metadataIngestion.mlmodelServiceMetadataPipeline import (
MlModelServiceMetadataPipeline,
)
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.ingestion.api.source import Source, SourceStatus
from metadata.ingestion.api.topology_runner import TopologyRunnerMixin
from metadata.ingestion.models.topology import (
NodeStage,
ServiceTopology,
TopologyNode,
create_source_context,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.connections import get_connection, test_connection
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
class MlModelServiceTopology(ServiceTopology):
    """
    Defines the hierarchy in MlModel Services.
    service -> MlModel

    We could have a topology validator. We can only consume
    data that has been produced by any parent node.

    ``producer`` and ``processor`` values are method names resolved on
    the source class (e.g. ``get_services`` / ``yield_mlmodel`` below in
    ``MlModelServiceSource``); ``context`` names the attribute under
    which each stage's result is stored.
    """

    # Root node: creates the MlModel service entity from the workflow config
    root = TopologyNode(
        producer="get_services",
        stages=[
            NodeStage(
                type_=MlModelService,
                context="mlmodel_service",
                processor="yield_mlmodel_service",
            ),
        ],
        children=["mlmodel"],
    )
    # Child node: one CreateMlModelRequest per model listed by the source
    mlmodel = TopologyNode(
        producer="get_mlmodels",
        stages=[
            NodeStage(
                type_=MlModel,
                context="mlmodels",
                processor="yield_mlmodel",
                consumer=["mlmodel_service"],
            ),
        ],
    )
@dataclass
class MlModelSourceStatus(SourceStatus):
    """Ingestion status tracker for ML Model sources.

    Keeps separate lists of scanned, failed and warned model names and
    mirrors every event to the ingestion logger.
    """

    # One list per outcome; default_factory keeps instances independent.
    success: List[str] = field(default_factory=list)
    failures: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def scanned(self, record: str) -> None:
        """Register *record* as a successfully ingested ML Model."""
        logger.info("ML Model scanned: %s", record)
        self.success.append(record)

    def failed(self, model_name: str, reason: str) -> None:
        """Register *model_name* as failed; *reason* goes to the logs only."""
        logger.error("ML Model failed: %s - %s", model_name, reason)
        self.failures.append(model_name)

    def warned(self, model_name: str, reason: str) -> None:
        """Register a warning for *model_name*; *reason* goes to the logs only."""
        logger.warning("ML Model warning: %s - %s", model_name, reason)
        self.warnings.append(model_name)
class MlModelServiceSource(TopologyRunnerMixin, Source, ABC):
    """
    Base class for MlModel services.
    It implements the topology and context.

    Concrete sources (e.g. the MLFlow connector) implement the abstract
    producer/processor methods referenced by ``MlModelServiceTopology``.
    """

    # Instance attributes assigned in __init__; declared here for typing.
    status: MlModelSourceStatus
    source_config: MlModelServiceMetadataPipeline
    config: WorkflowSource
    metadata: OpenMetadata
    # Big union of types we want to fetch dynamically
    service_connection: MlModelConnection.__fields__["config"].type_

    # Class-level topology shared by all MlModel sources; the context is
    # derived from it at class-definition time, so order matters here.
    topology = MlModelServiceTopology()
    context = create_source_context(topology)

    def __init__(
        self,
        config: WorkflowSource,
        metadata_config: OpenMetadataConnection,
    ):
        """
        :param config: workflow source configuration (service + pipeline)
        :param metadata_config: connection details for the OpenMetadata server
        """
        super().__init__()
        self.config = config
        self.metadata_config = metadata_config
        self.metadata = OpenMetadata(metadata_config)
        self.service_connection = self.config.serviceConnection.__root__.config
        self.source_config: MlModelServiceMetadataPipeline = (
            self.config.sourceConfig.config
        )
        # Build the service connection and validate it before ingesting
        self.connection = get_connection(self.service_connection)
        self.test_connection()

        self.status = MlModelSourceStatus()
        self.client = self.connection.client

    def get_services(self) -> Iterable[WorkflowSource]:
        # Topology root producer: a single service per workflow config
        yield self.config

    def yield_mlmodel_service(self, config: WorkflowSource):
        # Topology root processor: create (or fetch) the service entity
        yield self.metadata.get_create_service_from_source(
            entity=MlModelService, config=config
        )

    @abstractmethod
    def get_mlmodels(self, *args, **kwargs) -> Iterable[Any]:
        """
        Method to list all models to process.
        Here is where filtering happens
        """

    @abstractmethod
    def yield_mlmodel(self, *args, **kwargs) -> Iterable[CreateMlModelRequest]:
        """
        Method to return MlModel Entities
        """

    @abstractmethod
    def _get_hyper_params(self, *args, **kwargs) -> Optional[List[MlHyperParameter]]:
        """
        Get the Hyper Parameters from the MlModel
        """

    @abstractmethod
    def _get_ml_store(self, *args, **kwargs) -> Optional[MlStore]:
        """
        Get the Ml Store from the model version object
        """

    @abstractmethod
    def _get_ml_features(self, *args, **kwargs) -> Optional[List[MlFeature]]:
        """
        Pick up features
        """

    @abstractmethod
    def _get_algorithm(self, *args, **kwargs) -> str:
        """
        Return the algorithm for a given model
        """

    def get_status(self) -> SourceStatus:
        """Return the accumulated ingestion status."""
        return self.status

    def close(self):
        # Nothing to release by default; connectors may override.
        pass

    def test_connection(self) -> None:
        """Validate the service connection, raising on failure."""
        test_connection(self.connection)

    def prepare(self):
        # No preparation step needed by default.
        pass

View File

@ -190,3 +190,18 @@ def filter_by_pipeline(
:return: True for filtering, False otherwise
"""
return _filter(pipeline_filter_pattern, pipeline_name)
def filter_by_mlmodel(
    mlmodel_filter_pattern: Optional[FilterPattern], mlmodel_name: str
) -> bool:
    """
    Decide whether an MlModel should be filtered out of ingestion.

    Include patterns take precedence over exclude patterns.

    :param mlmodel_filter_pattern: Model defining the mlmodel filtering logic
    :param mlmodel_name: mlmodel name
    :return: True if the mlmodel must be filtered out, False otherwise
    """
    # Delegate to the shared filtering helper used by all entity filters
    return _filter(mlmodel_filter_pattern, mlmodel_name)