diff --git a/catalog-rest-service/src/main/resources/json/schema/metadataIngestion/mlmodelServiceMetadataPipeline.json b/catalog-rest-service/src/main/resources/json/schema/metadataIngestion/mlmodelServiceMetadataPipeline.json index 9d5371fbb76..8781c847db1 100644 --- a/catalog-rest-service/src/main/resources/json/schema/metadataIngestion/mlmodelServiceMetadataPipeline.json +++ b/catalog-rest-service/src/main/resources/json/schema/metadataIngestion/mlmodelServiceMetadataPipeline.json @@ -16,6 +16,10 @@ "description": "Pipeline type", "$ref": "#/definitions/mlModelMetadataConfigType", "default": "MlModelMetadata" + }, + "mlModelFilterPattern": { + "description": "Regex to only fetch MlModels with names matching the pattern.", + "$ref": "../type/filterPattern.json#/definitions/filterPattern" } }, "additionalProperties": false diff --git a/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py b/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py index 5ddb0d39687..6f210b2a7e2 100644 --- a/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py +++ b/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Base class for ingesting database services +Base class for ingesting dashboard services """ from abc import ABC, abstractmethod from dataclasses import dataclass diff --git a/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py b/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py index 241dabb5f01..705703d9026 100644 --- a/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py +++ b/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py @@ -13,12 +13,10 @@ import ast import json import traceback -from dataclasses import dataclass, field -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple, cast from mlflow.entities import RunData -from mlflow.entities.model_registry import ModelVersion -from mlflow.tracking import MlflowClient +from mlflow.entities.model_registry import ModelVersion, RegisteredModel from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest from metadata.generated.schema.entity.data.mlmodel import ( @@ -33,52 +31,19 @@ from metadata.generated.schema.entity.services.connections.metadata.openMetadata from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import ( MlflowConnection, ) -from metadata.generated.schema.entity.services.mlmodelService import MlModelService from metadata.generated.schema.metadataIngestion.workflow import ( Source as WorkflowSource, ) from metadata.generated.schema.type.entityReference import EntityReference -from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus -from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.utils.connections import get_connection +from metadata.ingestion.api.source import InvalidSourceException +from metadata.ingestion.source.mlmodel.mlmodel_service import MlModelServiceSource +from metadata.utils.filters import filter_by_mlmodel from metadata.utils.logger import ingestion_logger logger = ingestion_logger() -@dataclass -class MlFlowStatus(SourceStatus): - """ - ML Model specific Status - """ - - success: List[str] = field(default_factory=list) - failures: List[str] = field(default_factory=list) - warnings: List[str] = field(default_factory=list) - - def scanned(self, record: str) -> None: - """ - Log successful ML Model scans - """ - self.success.append(record) - logger.info("ML Model scanned: %s", record) - - def failed(self, model_name: str, reason: str) -> None: - """ - Log failed ML Model scans - """ - self.failures.append(model_name) - logger.error("ML Model failed: %s - %s", model_name, reason) - - def warned(self, model_name: str, reason: str) -> None: - """ - Log Ml Model with warnings - """ - self.warnings.append(model_name) - logger.warning("ML Model warning: %s - %s", model_name, reason) - - -class MlflowSource(Source[CreateMlModelRequest]): +class MlflowSource(MlModelServiceSource): """ Source implementation to ingest MLFlow data. @@ -86,25 +51,6 @@ class MlflowSource(Source[CreateMlModelRequest]): and prepare an iterator of CreateMlModelRequest """ - def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection): - super().__init__() - self.config = config - self.service_connection = self.config.serviceConnection.__root__.config - - self.metadata = OpenMetadata(metadata_config) - - self.connection = get_connection(self.service_connection) - self.test_connection() - self.client = self.connection.client - - self.status = MlFlowStatus() - self.service = self.metadata.get_service_or_create( - entity=MlModelService, config=config - ) - - def prepare(self): - pass - @classmethod def create(cls, config_dict, metadata_config: OpenMetadataConnection): config: WorkflowSource = WorkflowSource.parse_obj(config_dict) @@ -116,15 +62,20 @@ class MlflowSource(Source[CreateMlModelRequest]): return cls(config, metadata_config) - def next_record(self) -> Iterable[CreateMlModelRequest]: + def get_mlmodels(self) -> Iterable[Tuple[RegisteredModel, ModelVersion]]: """ - Fetch all registered models from MlFlow. + List and filters models from the registry + """ + for model in cast(RegisteredModel, self.client.list_registered_models()): - We are setting the `algorithm` to a constant - as there is not a non-trivial generic approach - for retrieving the algorithm from the registry. - """ - for model in self.client.list_registered_models(): + if filter_by_mlmodel( + self.source_config.mlModelFilterPattern, mlmodel_name=model.name + ): + self.status.filter( + f"{self.config.serviceName}.{model.name}", + "MlModel name pattern not allowed", + ) + continue # Get the latest version latest_version: Optional[ModelVersion] = next( @@ -139,21 +90,35 @@ class MlflowSource(Source[CreateMlModelRequest]): self.status.failed(model.name, reason="Invalid version") continue - run = self.client.get_run(latest_version.run_id) + yield model, latest_version - self.status.scanned(model.name) + def _get_algorithm(self) -> str: + return "mlmodel" - yield CreateMlModelRequest( - name=model.name, - description=model.description, - algorithm="mlflow", # Setting this to a constant - mlHyperParameters=self._get_hyper_params(run.data), - mlFeatures=self._get_ml_features( - run.data, latest_version.run_id, model.name - ), - mlStore=self._get_ml_store(latest_version), - service=EntityReference(id=self.service.id, type="mlmodelService"), - ) + def yield_mlmodel( + self, model_and_version: Tuple[RegisteredModel, ModelVersion] + ) -> Iterable[CreateMlModelRequest]: + """ + Prepare the Request model + """ + model, latest_version = model_and_version + self.status.scanned(model.name) + + run = self.client.get_run(latest_version.run_id) + + yield CreateMlModelRequest( + name=model.name, + description=model.description, + algorithm=self._get_algorithm(), # Setting this to a constant + mlHyperParameters=self._get_hyper_params(run.data), + mlFeatures=self._get_ml_features( + run.data, latest_version.run_id, model.name + ), + mlStore=self._get_ml_store(latest_version), + service=EntityReference( + id=self.context.mlmodel_service.id, type="mlmodelService" + ), + ) @staticmethod def _get_hyper_params(data: RunData) -> Optional[List[MlHyperParameter]]: @@ -222,14 +187,3 @@ class MlflowSource(Source[CreateMlModelRequest]): self.status.warned(model_name, reason) return None - - def get_status(self) -> SourceStatus: - return self.status - - def close(self) -> None: - """ - Don't need to close the client - """ - - def test_connection(self) -> None: - pass diff --git a/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py b/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py new file mode 100644 index 00000000000..063e0bdfdc6 --- /dev/null +++ b/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py @@ -0,0 +1,208 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base class for ingesting mlmodel services +""" +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Iterable, List, Optional + +from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest +from metadata.generated.schema.entity.data.mlmodel import ( + MlFeature, + MlHyperParameter, + MlModel, + MlStore, +) +from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( + OpenMetadataConnection, +) +from metadata.generated.schema.entity.services.mlmodelService import ( + MlModelConnection, + MlModelService, +) +from metadata.generated.schema.metadataIngestion.mlmodelServiceMetadataPipeline import ( + MlModelServiceMetadataPipeline, +) +from metadata.generated.schema.metadataIngestion.workflow import ( + Source as WorkflowSource, +) +from metadata.ingestion.api.source import Source, SourceStatus +from metadata.ingestion.api.topology_runner import TopologyRunnerMixin +from metadata.ingestion.models.topology import ( + NodeStage, + ServiceTopology, + TopologyNode, + create_source_context, +) +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.utils.connections import get_connection, test_connection +from metadata.utils.logger import ingestion_logger + +logger = ingestion_logger() + + +class MlModelServiceTopology(ServiceTopology): + """ + Defines the hierarchy in MlModel Services. + service -> MlModel + + We could have a topology validator. We can only consume + data that has been produced by any parent node. + """ + + root = TopologyNode( + producer="get_services", + stages=[ + NodeStage( + type_=MlModelService, + context="mlmodel_service", + processor="yield_mlmodel_service", + ), + ], + children=["mlmodel"], + ) + mlmodel = TopologyNode( + producer="get_mlmodels", + stages=[ + NodeStage( + type_=MlModel, + context="mlmodels", + processor="yield_mlmodel", + consumer=["mlmodel_service"], + ), + ], + ) + + +@dataclass +class MlModelSourceStatus(SourceStatus): + """ + ML Model specific Status + """ + + success: List[str] = field(default_factory=list) + failures: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + + def scanned(self, record: str) -> None: + """ + Log successful ML Model scans + """ + self.success.append(record) + logger.info("ML Model scanned: %s", record) + + def failed(self, model_name: str, reason: str) -> None: + """ + Log failed ML Model scans + """ + self.failures.append(model_name) + logger.error("ML Model failed: %s - %s", model_name, reason) + + def warned(self, model_name: str, reason: str) -> None: + """ + Log Ml Model with warnings + """ + self.warnings.append(model_name) + logger.warning("ML Model warning: %s - %s", model_name, reason) + + +class MlModelServiceSource(TopologyRunnerMixin, Source, ABC): + """ + Base class for MlModel services. + It implements the topology and context + """ + + status: MlModelSourceStatus + source_config: MlModelServiceMetadataPipeline + config: WorkflowSource + metadata: OpenMetadata + # Big union of types we want to fetch dynamically + service_connection: MlModelConnection.__fields__["config"].type_ + + topology = MlModelServiceTopology() + context = create_source_context(topology) + + def __init__( + self, + config: WorkflowSource, + metadata_config: OpenMetadataConnection, + ): + super().__init__() + self.config = config + self.metadata_config = metadata_config + self.metadata = OpenMetadata(metadata_config) + self.service_connection = self.config.serviceConnection.__root__.config + self.source_config: MlModelServiceMetadataPipeline = ( + self.config.sourceConfig.config + ) + self.connection = get_connection(self.service_connection) + self.test_connection() + self.status = MlModelSourceStatus() + + self.client = self.connection.client + + def get_services(self) -> Iterable[WorkflowSource]: + yield self.config + + def yield_mlmodel_service(self, config: WorkflowSource): + yield self.metadata.get_create_service_from_source( + entity=MlModelService, config=config + ) + + @abstractmethod + def get_mlmodels(self, *args, **kwargs) -> Iterable[Any]: + """ + Method to list all models to process. + Here is where filtering happens + """ + + @abstractmethod + def yield_mlmodel(self, *args, **kwargs) -> Iterable[CreateMlModelRequest]: + """ + Method to return MlModel Entities + """ + + @abstractmethod + def _get_hyper_params(self, *args, **kwargs) -> Optional[List[MlHyperParameter]]: + """ + Get the Hyper Parameters from the MlModel + """ + + @abstractmethod + def _get_ml_store(self, *args, **kwargs) -> Optional[MlStore]: + """ + Get the Ml Store from the model version object + """ + + @abstractmethod + def _get_ml_features(self, *args, **kwargs) -> Optional[List[MlFeature]]: + """ + Pick up features + """ + + @abstractmethod + def _get_algorithm(self, *args, **kwargs) -> str: + """ + Return the algorithm for a given model + """ + + def get_status(self) -> SourceStatus: + return self.status + + def close(self): + pass + + def test_connection(self) -> None: + test_connection(self.connection) + + def prepare(self): + pass diff --git a/ingestion/src/metadata/utils/filters.py b/ingestion/src/metadata/utils/filters.py index e56527c7450..5a791035382 100644 --- a/ingestion/src/metadata/utils/filters.py +++ b/ingestion/src/metadata/utils/filters.py @@ -190,3 +190,18 @@ def filter_by_pipeline( :return: True for filtering, False otherwise """ return _filter(pipeline_filter_pattern, pipeline_name) + + +def filter_by_mlmodel( + mlmodel_filter_pattern: Optional[FilterPattern], mlmodel_name: str +) -> bool: + """ + Return True if the mlmodel needs to be filtered, False otherwise + + Include takes precedence over exclude + + :param mlmodel_filter_pattern: Model defining the mlmodel filtering logic + :param mlmodel_name: mlmodel name + :return: True for filtering, False otherwise + """ + return _filter(mlmodel_filter_pattern, mlmodel_name)