mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-01 02:56:10 +00:00
parent
479a8de486
commit
0a921abf8b
@ -16,6 +16,10 @@
|
||||
"description": "Pipeline type",
|
||||
"$ref": "#/definitions/mlModelMetadataConfigType",
|
||||
"default": "MlModelMetadata"
|
||||
},
|
||||
"mlModelFilterPattern": {
|
||||
"description": "Regex to only fetch MlModels with names matching the pattern.",
|
||||
"$ref": "../type/filterPattern.json#/definitions/filterPattern"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Base class for ingesting database services
|
||||
Base class for ingesting dashboard services
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -13,12 +13,10 @@
|
||||
import ast
|
||||
import json
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple, cast
|
||||
|
||||
from mlflow.entities import RunData
|
||||
from mlflow.entities.model_registry import ModelVersion
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities.model_registry import ModelVersion, RegisteredModel
|
||||
|
||||
from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest
|
||||
from metadata.generated.schema.entity.data.mlmodel import (
|
||||
@ -33,52 +31,19 @@ from metadata.generated.schema.entity.services.connections.metadata.openMetadata
|
||||
from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import (
|
||||
MlflowConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.mlmodelService import MlModelService
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
Source as WorkflowSource,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.utils.connections import get_connection
|
||||
from metadata.ingestion.api.source import InvalidSourceException
|
||||
from metadata.ingestion.source.mlmodel.mlmodel_service import MlModelServiceSource
|
||||
from metadata.utils.filters import filter_by_mlmodel
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class MlFlowStatus(SourceStatus):
|
||||
"""
|
||||
ML Model specific Status
|
||||
"""
|
||||
|
||||
success: List[str] = field(default_factory=list)
|
||||
failures: List[str] = field(default_factory=list)
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
|
||||
def scanned(self, record: str) -> None:
|
||||
"""
|
||||
Log successful ML Model scans
|
||||
"""
|
||||
self.success.append(record)
|
||||
logger.info("ML Model scanned: %s", record)
|
||||
|
||||
def failed(self, model_name: str, reason: str) -> None:
|
||||
"""
|
||||
Log failed ML Model scans
|
||||
"""
|
||||
self.failures.append(model_name)
|
||||
logger.error("ML Model failed: %s - %s", model_name, reason)
|
||||
|
||||
def warned(self, model_name: str, reason: str) -> None:
|
||||
"""
|
||||
Log Ml Model with warnings
|
||||
"""
|
||||
self.warnings.append(model_name)
|
||||
logger.warning("ML Model warning: %s - %s", model_name, reason)
|
||||
|
||||
|
||||
class MlflowSource(Source[CreateMlModelRequest]):
|
||||
class MlflowSource(MlModelServiceSource):
|
||||
"""
|
||||
Source implementation to ingest MLFlow data.
|
||||
|
||||
@ -86,25 +51,6 @@ class MlflowSource(Source[CreateMlModelRequest]):
|
||||
and prepare an iterator of CreateMlModelRequest
|
||||
"""
|
||||
|
||||
def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.service_connection = self.config.serviceConnection.__root__.config
|
||||
|
||||
self.metadata = OpenMetadata(metadata_config)
|
||||
|
||||
self.connection = get_connection(self.service_connection)
|
||||
self.test_connection()
|
||||
self.client = self.connection.client
|
||||
|
||||
self.status = MlFlowStatus()
|
||||
self.service = self.metadata.get_service_or_create(
|
||||
entity=MlModelService, config=config
|
||||
)
|
||||
|
||||
def prepare(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
|
||||
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
|
||||
@ -116,15 +62,20 @@ class MlflowSource(Source[CreateMlModelRequest]):
|
||||
|
||||
return cls(config, metadata_config)
|
||||
|
||||
def next_record(self) -> Iterable[CreateMlModelRequest]:
|
||||
def get_mlmodels(self) -> Iterable[Tuple[RegisteredModel, ModelVersion]]:
|
||||
"""
|
||||
Fetch all registered models from MlFlow.
|
||||
List and filters models from the registry
|
||||
"""
|
||||
for model in cast(RegisteredModel, self.client.list_registered_models()):
|
||||
|
||||
We are setting the `algorithm` to a constant
|
||||
as there is not a non-trivial generic approach
|
||||
for retrieving the algorithm from the registry.
|
||||
"""
|
||||
for model in self.client.list_registered_models():
|
||||
if filter_by_mlmodel(
|
||||
self.source_config.mlModelFilterPattern, mlmodel_name=model.name
|
||||
):
|
||||
self.status.filter(
|
||||
f"{self.config.serviceName}.{model.name}",
|
||||
"MlModel name pattern not allowed",
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the latest version
|
||||
latest_version: Optional[ModelVersion] = next(
|
||||
@ -139,21 +90,35 @@ class MlflowSource(Source[CreateMlModelRequest]):
|
||||
self.status.failed(model.name, reason="Invalid version")
|
||||
continue
|
||||
|
||||
run = self.client.get_run(latest_version.run_id)
|
||||
yield model, latest_version
|
||||
|
||||
self.status.scanned(model.name)
|
||||
def _get_algorithm(self) -> str:
|
||||
return "mlmodel"
|
||||
|
||||
yield CreateMlModelRequest(
|
||||
name=model.name,
|
||||
description=model.description,
|
||||
algorithm="mlflow", # Setting this to a constant
|
||||
mlHyperParameters=self._get_hyper_params(run.data),
|
||||
mlFeatures=self._get_ml_features(
|
||||
run.data, latest_version.run_id, model.name
|
||||
),
|
||||
mlStore=self._get_ml_store(latest_version),
|
||||
service=EntityReference(id=self.service.id, type="mlmodelService"),
|
||||
)
|
||||
def yield_mlmodel(
|
||||
self, model_and_version: Tuple[RegisteredModel, ModelVersion]
|
||||
) -> Iterable[CreateMlModelRequest]:
|
||||
"""
|
||||
Prepare the Request model
|
||||
"""
|
||||
model, latest_version = model_and_version
|
||||
self.status.scanned(model.name)
|
||||
|
||||
run = self.client.get_run(latest_version.run_id)
|
||||
|
||||
yield CreateMlModelRequest(
|
||||
name=model.name,
|
||||
description=model.description,
|
||||
algorithm=self._get_algorithm(), # Setting this to a constant
|
||||
mlHyperParameters=self._get_hyper_params(run.data),
|
||||
mlFeatures=self._get_ml_features(
|
||||
run.data, latest_version.run_id, model.name
|
||||
),
|
||||
mlStore=self._get_ml_store(latest_version),
|
||||
service=EntityReference(
|
||||
id=self.context.mlmodel_service.id, type="mlmodelService"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_hyper_params(data: RunData) -> Optional[List[MlHyperParameter]]:
|
||||
@ -222,14 +187,3 @@ class MlflowSource(Source[CreateMlModelRequest]):
|
||||
self.status.warned(model_name, reason)
|
||||
|
||||
return None
|
||||
|
||||
def get_status(self) -> SourceStatus:
|
||||
return self.status
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Don't need to close the client
|
||||
"""
|
||||
|
||||
def test_connection(self) -> None:
|
||||
pass
|
||||
|
||||
@ -0,0 +1,208 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Base class for ingesting mlmodel services
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from metadata.generated.schema.api.data.createMlModel import CreateMlModelRequest
|
||||
from metadata.generated.schema.entity.data.mlmodel import (
|
||||
MlFeature,
|
||||
MlHyperParameter,
|
||||
MlModel,
|
||||
MlStore,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
|
||||
OpenMetadataConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.mlmodelService import (
|
||||
MlModelConnection,
|
||||
MlModelService,
|
||||
)
|
||||
from metadata.generated.schema.metadataIngestion.mlmodelServiceMetadataPipeline import (
|
||||
MlModelServiceMetadataPipeline,
|
||||
)
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
Source as WorkflowSource,
|
||||
)
|
||||
from metadata.ingestion.api.source import Source, SourceStatus
|
||||
from metadata.ingestion.api.topology_runner import TopologyRunnerMixin
|
||||
from metadata.ingestion.models.topology import (
|
||||
NodeStage,
|
||||
ServiceTopology,
|
||||
TopologyNode,
|
||||
create_source_context,
|
||||
)
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.utils.connections import get_connection, test_connection
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
class MlModelServiceTopology(ServiceTopology):
|
||||
"""
|
||||
Defines the hierarchy in MlModel Services.
|
||||
service -> MlModel
|
||||
|
||||
We could have a topology validator. We can only consume
|
||||
data that has been produced by any parent node.
|
||||
"""
|
||||
|
||||
root = TopologyNode(
|
||||
producer="get_services",
|
||||
stages=[
|
||||
NodeStage(
|
||||
type_=MlModelService,
|
||||
context="mlmodel_service",
|
||||
processor="yield_mlmodel_service",
|
||||
),
|
||||
],
|
||||
children=["mlmodel"],
|
||||
)
|
||||
mlmodel = TopologyNode(
|
||||
producer="get_mlmodels",
|
||||
stages=[
|
||||
NodeStage(
|
||||
type_=MlModel,
|
||||
context="mlmodels",
|
||||
processor="yield_mlmodel",
|
||||
consumer=["mlmodel_service"],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MlModelSourceStatus(SourceStatus):
|
||||
"""
|
||||
ML Model specific Status
|
||||
"""
|
||||
|
||||
success: List[str] = field(default_factory=list)
|
||||
failures: List[str] = field(default_factory=list)
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
|
||||
def scanned(self, record: str) -> None:
|
||||
"""
|
||||
Log successful ML Model scans
|
||||
"""
|
||||
self.success.append(record)
|
||||
logger.info("ML Model scanned: %s", record)
|
||||
|
||||
def failed(self, model_name: str, reason: str) -> None:
|
||||
"""
|
||||
Log failed ML Model scans
|
||||
"""
|
||||
self.failures.append(model_name)
|
||||
logger.error("ML Model failed: %s - %s", model_name, reason)
|
||||
|
||||
def warned(self, model_name: str, reason: str) -> None:
|
||||
"""
|
||||
Log Ml Model with warnings
|
||||
"""
|
||||
self.warnings.append(model_name)
|
||||
logger.warning("ML Model warning: %s - %s", model_name, reason)
|
||||
|
||||
|
||||
class MlModelServiceSource(TopologyRunnerMixin, Source, ABC):
|
||||
"""
|
||||
Base class for MlModel services.
|
||||
It implements the topology and context
|
||||
"""
|
||||
|
||||
status: MlModelSourceStatus
|
||||
source_config: MlModelServiceMetadataPipeline
|
||||
config: WorkflowSource
|
||||
metadata: OpenMetadata
|
||||
# Big union of types we want to fetch dynamically
|
||||
service_connection: MlModelConnection.__fields__["config"].type_
|
||||
|
||||
topology = MlModelServiceTopology()
|
||||
context = create_source_context(topology)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: WorkflowSource,
|
||||
metadata_config: OpenMetadataConnection,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.metadata_config = metadata_config
|
||||
self.metadata = OpenMetadata(metadata_config)
|
||||
self.service_connection = self.config.serviceConnection.__root__.config
|
||||
self.source_config: MlModelServiceMetadataPipeline = (
|
||||
self.config.sourceConfig.config
|
||||
)
|
||||
self.connection = get_connection(self.service_connection)
|
||||
self.test_connection()
|
||||
self.status = MlModelSourceStatus()
|
||||
|
||||
self.client = self.connection.client
|
||||
|
||||
def get_services(self) -> Iterable[WorkflowSource]:
|
||||
yield self.config
|
||||
|
||||
def yield_mlmodel_service(self, config: WorkflowSource):
|
||||
yield self.metadata.get_create_service_from_source(
|
||||
entity=MlModelService, config=config
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_mlmodels(self, *args, **kwargs) -> Iterable[Any]:
|
||||
"""
|
||||
Method to list all models to process.
|
||||
Here is where filtering happens
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def yield_mlmodel(self, *args, **kwargs) -> Iterable[CreateMlModelRequest]:
|
||||
"""
|
||||
Method to return MlModel Entities
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_hyper_params(self, *args, **kwargs) -> Optional[List[MlHyperParameter]]:
|
||||
"""
|
||||
Get the Hyper Parameters from the MlModel
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_ml_store(self, *args, **kwargs) -> Optional[MlStore]:
|
||||
"""
|
||||
Get the Ml Store from the model version object
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_ml_features(self, *args, **kwargs) -> Optional[List[MlFeature]]:
|
||||
"""
|
||||
Pick up features
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_algorithm(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Return the algorithm for a given model
|
||||
"""
|
||||
|
||||
def get_status(self) -> SourceStatus:
|
||||
return self.status
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def test_connection(self) -> None:
|
||||
test_connection(self.connection)
|
||||
|
||||
def prepare(self):
|
||||
pass
|
||||
@ -190,3 +190,18 @@ def filter_by_pipeline(
|
||||
:return: True for filtering, False otherwise
|
||||
"""
|
||||
return _filter(pipeline_filter_pattern, pipeline_name)
|
||||
|
||||
|
||||
def filter_by_mlmodel(
|
||||
mlmodel_filter_pattern: Optional[FilterPattern], mlmodel_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if the mlmodel needs to be filtered, False otherwise
|
||||
|
||||
Include takes precedence over exclude
|
||||
|
||||
:param mlmodel_filter_pattern: Model defining the mlmodel filtering logic
|
||||
:param mlmodel_name: mlmodel name
|
||||
:return: True for filtering, False otherwise
|
||||
"""
|
||||
return _filter(mlmodel_filter_pattern, mlmodel_name)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user