mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-12-05 03:54:23 +00:00
[issue-1973] - Python API from Sklearn to MlModel (#2119)
* Move staticmethods to utils * Use functions from utils * Convert sklearn to MlModel * merge main
This commit is contained in:
parent
178315d68a
commit
d3b6c7cf27
@ -107,6 +107,7 @@ plugins: Dict[str, Set[str]] = {
|
||||
"salesforce": {"simple_salesforce~=1.11.4"},
|
||||
"okta": {"okta~=2.3.0"},
|
||||
"mlflow": {"mlflow-skinny~=1.22.0"},
|
||||
"sklearn": {"scikit-learn==1.0.2"},
|
||||
}
|
||||
dev = {
|
||||
"boto3==1.20.14",
|
||||
@ -125,6 +126,9 @@ test = {
|
||||
"pytest-cov",
|
||||
"faker",
|
||||
"coverage",
|
||||
# sklearn integration
|
||||
"scikit-learn==1.0.2",
|
||||
"pandas==1.3.5",
|
||||
}
|
||||
|
||||
build_options = {"includes": ["_cffi_backend"]}
|
||||
|
||||
@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from metadata.generated.schema.api.lineage.addLineage import AddLineage
|
||||
from metadata.ingestion.ometa.client import REST, APIError
|
||||
from metadata.ingestion.ometa.utils import get_entity_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -97,7 +98,7 @@ class OMetaLineageMixin(Generic[T]):
|
||||
:param up_depth: Upstream depth of lineage (default=1, min=0, max=3)"
|
||||
:param down_depth: Downstream depth of lineage (default=1, min=0, max=3)
|
||||
"""
|
||||
entity_name = self.get_entity_type(entity)
|
||||
entity_name = get_entity_type(entity)
|
||||
search = (
|
||||
f"?upstreamDepth={min(up_depth, 3)}&downstreamDepth={min(down_depth, 3)}"
|
||||
)
|
||||
|
||||
@ -4,13 +4,19 @@ Mixin class containing Lineage specific methods
|
||||
To be used by OpenMetadata class
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
|
||||
from metadata.generated.schema.api.lineage.addLineage import AddLineage
|
||||
from metadata.generated.schema.entity.data.mlmodel import MlModel
|
||||
from metadata.generated.schema.entity.data.mlmodel import (
|
||||
MlFeature,
|
||||
MlHyperParameter,
|
||||
MlModel,
|
||||
)
|
||||
from metadata.generated.schema.type.entityLineage import EntitiesEdge
|
||||
from metadata.ingestion.ometa.client import REST
|
||||
from metadata.ingestion.ometa.mixins.lineage_mixin import OMetaLineageMixin
|
||||
from metadata.ingestion.ometa.utils import format_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -59,3 +65,48 @@ class OMetaMlModelMixin(OMetaLineageMixin):
|
||||
mlmodel_lineage = self.get_lineage_by_id(MlModel, str(model.id.__root__))
|
||||
|
||||
return mlmodel_lineage
|
||||
|
||||
@staticmethod
|
||||
def get_mlmodel_sklearn(
|
||||
name: str, model, description: Optional[str] = None
|
||||
) -> CreateMlModelEntityRequest:
|
||||
"""
|
||||
Get an MlModel Entity instance from a scikit-learn model.
|
||||
|
||||
Sklearn estimators all extend BaseEstimator.
|
||||
:param name: MlModel name
|
||||
:param model: sklearn estimator
|
||||
:param description: MlModel description
|
||||
:return: OpenMetadata CreateMlModelEntityRequest Entity
|
||||
"""
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
# pylint: enable=import-outside-toplevel
|
||||
except ModuleNotFoundError as exc:
|
||||
logger.error(
|
||||
"Cannot import BaseEstimator, please install sklearn plugin: "
|
||||
+ f"pip install openmetadata-ingestion[sklearn], {exc}"
|
||||
)
|
||||
raise exc
|
||||
|
||||
if not isinstance(model, BaseEstimator):
|
||||
raise ValueError("Input model is not an instance of sklearn BaseEstimator")
|
||||
|
||||
return CreateMlModelEntityRequest(
|
||||
name=name,
|
||||
description=description,
|
||||
algorithm=model.__class__.__name__,
|
||||
mlFeatures=[
|
||||
MlFeature(name=format_name(feature))
|
||||
for feature in model.feature_names_in_
|
||||
],
|
||||
mlHyperParameters=[
|
||||
MlHyperParameter(
|
||||
name=key,
|
||||
value=value,
|
||||
)
|
||||
for key, value in model.get_params().items()
|
||||
],
|
||||
)
|
||||
|
||||
@ -13,6 +13,7 @@ from requests.models import Response
|
||||
from metadata.generated.schema.type import basic
|
||||
from metadata.generated.schema.type.entityHistory import EntityVersionHistory
|
||||
from metadata.ingestion.ometa.client import REST
|
||||
from metadata.ingestion.ometa.utils import uuid_to_str
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -67,7 +68,7 @@ class OMetaVersionMixin(Generic[T]):
|
||||
fields: List
|
||||
List of fields to return
|
||||
"""
|
||||
entity_id = self.uuid_to_str(entity_id)
|
||||
entity_id = uuid_to_str(entity_id)
|
||||
version = self.version_to_str(version)
|
||||
|
||||
path = f"{entity_id}/versions/{version}"
|
||||
@ -94,7 +95,7 @@ class OMetaVersionMixin(Generic[T]):
|
||||
List
|
||||
lists of available versions for a specific entity
|
||||
"""
|
||||
path = f"{self.uuid_to_str(entity_id)}/versions"
|
||||
path = f"{uuid_to_str(entity_id)}/versions"
|
||||
|
||||
resp = self.client.get(f"{self.get_suffix(entity)}/{path}")
|
||||
|
||||
|
||||
@ -55,6 +55,7 @@ from metadata.ingestion.ometa.openmetadata_rest import (
|
||||
NoOpAuthenticationProvider,
|
||||
OktaAuthenticationProvider,
|
||||
)
|
||||
from metadata.ingestion.ometa.utils import get_entity_type, uuid_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -253,27 +254,6 @@ class OpenMetadata(
|
||||
f"Missing {entity} type when generating suffixes"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_entity_type(
|
||||
entity: Union[Type[T], str],
|
||||
) -> str:
|
||||
"""
|
||||
Given an Entity T, return its type.
|
||||
E.g., Table returns table, Dashboard returns dashboard...
|
||||
|
||||
Also allow to be the identity if we just receive a string
|
||||
"""
|
||||
if isinstance(entity, str):
|
||||
return entity
|
||||
|
||||
class_name: str = entity.__name__.lower()
|
||||
|
||||
if "service" in class_name:
|
||||
# Capitalize service, e.g., pipelineService
|
||||
return class_name.replace("service", "Service")
|
||||
|
||||
return class_name
|
||||
|
||||
def get_module_path(self, entity: Type[T]) -> str:
|
||||
"""
|
||||
Based on the entity, return the module path
|
||||
@ -355,20 +335,6 @@ class OpenMetadata(
|
||||
resp = self.client.put(self.get_suffix(entity), data=data.json())
|
||||
return entity_class(**resp)
|
||||
|
||||
@staticmethod
|
||||
def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str:
|
||||
"""
|
||||
Given an entity_id, that can be a str or our pydantic
|
||||
definition of Uuid, return a proper str to build
|
||||
the endpoint path
|
||||
:param entity_id: Entity ID to onvert to string
|
||||
:return: str for the ID
|
||||
"""
|
||||
if isinstance(entity_id, basic.Uuid):
|
||||
return str(entity_id.__root__)
|
||||
|
||||
return entity_id
|
||||
|
||||
def get_by_name(
|
||||
self, entity: Type[T], fqdn: str, fields: Optional[List[str]] = None
|
||||
) -> Optional[T]:
|
||||
@ -388,7 +354,7 @@ class OpenMetadata(
|
||||
Return entity by ID or None
|
||||
"""
|
||||
|
||||
return self._get(entity=entity, path=self.uuid_to_str(entity_id), fields=fields)
|
||||
return self._get(entity=entity, path=uuid_to_str(entity_id), fields=fields)
|
||||
|
||||
def _get(
|
||||
self, entity: Type[T], path: str, fields: Optional[List[str]] = None
|
||||
@ -405,7 +371,8 @@ class OpenMetadata(
|
||||
return entity(**resp)
|
||||
except APIError as err:
|
||||
logger.error(
|
||||
f"Creating new {entity.__class__.__name__} for {path}. Error {err.status_code}"
|
||||
f"Creating new {entity.__class__.__name__} for {path}. "
|
||||
+ f"Error {err.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
@ -423,7 +390,7 @@ class OpenMetadata(
|
||||
if instance:
|
||||
return EntityReference(
|
||||
id=instance.id,
|
||||
type=self.get_entity_type(entity),
|
||||
type=get_entity_type(entity),
|
||||
name=instance.fullyQualifiedName,
|
||||
description=instance.description,
|
||||
href=instance.href,
|
||||
@ -470,13 +437,13 @@ class OpenMetadata(
|
||||
return [entity(**p) for p in resp["data"]]
|
||||
|
||||
def delete(self, entity: Type[T], entity_id: Union[str, basic.Uuid]) -> None:
|
||||
self.client.delete(f"{self.get_suffix(entity)}/{self.uuid_to_str(entity_id)}")
|
||||
self.client.delete(f"{self.get_suffix(entity)}/{uuid_to_str(entity_id)}")
|
||||
|
||||
def compute_percentile(self, entity: Union[Type[T], str], date: str) -> None:
|
||||
"""
|
||||
Compute an entity usage percentile
|
||||
"""
|
||||
entity_name = self.get_entity_type(entity)
|
||||
entity_name = get_entity_type(entity)
|
||||
resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}")
|
||||
logger.debug("published compute percentile {}".format(resp))
|
||||
|
||||
|
||||
68
ingestion/src/metadata/ingestion/ometa/utils.py
Normal file
68
ingestion/src/metadata/ingestion/ometa/utils.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Helper functions to handle OpenMetadata Entities' properties
|
||||
"""
|
||||
|
||||
import re
|
||||
import string
|
||||
from typing import Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metadata.generated.schema.type import basic
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def format_name(name: str) -> str:
|
||||
"""
|
||||
Given a name, replace all special characters by `_`
|
||||
:param name: name to format
|
||||
:return: formatted string
|
||||
"""
|
||||
subs = re.escape(string.punctuation + " ")
|
||||
return re.sub(r"[" + subs + "]", "_", name)
|
||||
|
||||
|
||||
def get_entity_type(
|
||||
entity: Union[Type[T], str],
|
||||
) -> str:
|
||||
"""
|
||||
Given an Entity T, return its type.
|
||||
E.g., Table returns table, Dashboard returns dashboard...
|
||||
|
||||
Also allow to be the identity if we just receive a string
|
||||
"""
|
||||
if isinstance(entity, str):
|
||||
return entity
|
||||
|
||||
class_name: str = entity.__name__.lower()
|
||||
|
||||
if "service" in class_name:
|
||||
# Capitalize service, e.g., pipelineService
|
||||
return class_name.replace("service", "Service")
|
||||
|
||||
return class_name
|
||||
|
||||
|
||||
def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str:
|
||||
"""
|
||||
Given an entity_id, that can be a str or our pydantic
|
||||
definition of Uuid, return a proper str to build
|
||||
the endpoint path
|
||||
:param entity_id: Entity ID to onvert to string
|
||||
:return: str for the ID
|
||||
"""
|
||||
if isinstance(entity_id, basic.Uuid):
|
||||
return str(entity_id.__root__)
|
||||
|
||||
return entity_id
|
||||
78
ingestion/tests/unit/test_ometa_mlmodel.py
Normal file
78
ingestion/tests/unit/test_ometa_mlmodel.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenMetadata MlModel mixin test
|
||||
"""
|
||||
from unittest import TestCase
|
||||
|
||||
import pandas as pd
|
||||
import sklearn.datasets as datasets
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
|
||||
from metadata.generated.schema.entity.data.mlmodel import MlFeature, MlModel
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
|
||||
|
||||
|
||||
class OMetaModelMixinTest(TestCase):
|
||||
"""
|
||||
Test the MlModel integrations from MlModel Mixin
|
||||
"""
|
||||
|
||||
server_config = MetadataServerConfig(api_endpoint="http://localhost:8585/api")
|
||||
metadata = OpenMetadata(server_config)
|
||||
|
||||
iris = datasets.load_iris()
|
||||
|
||||
def test_get_sklearn(self):
|
||||
"""
|
||||
Check that we can ingest an SKlearn model
|
||||
"""
|
||||
df = pd.DataFrame(self.iris.data, columns=self.iris.feature_names)
|
||||
y = self.iris.target
|
||||
|
||||
x_train, x_test, y_train, y_test = train_test_split(
|
||||
df, y, test_size=0.25, random_state=70
|
||||
)
|
||||
|
||||
dtree = DecisionTreeClassifier()
|
||||
dtree.fit(x_train, y_train)
|
||||
|
||||
entity_create: CreateMlModelEntityRequest = self.metadata.get_mlmodel_sklearn(
|
||||
name="test-sklearn",
|
||||
model=dtree,
|
||||
description="Creating a test sklearn model",
|
||||
)
|
||||
|
||||
entity: MlModel = self.metadata.create_or_update(data=entity_create)
|
||||
|
||||
self.assertEqual(entity.name, entity_create.name)
|
||||
self.assertEqual(entity.algorithm, "DecisionTreeClassifier")
|
||||
self.assertEqual(
|
||||
{feature.name.__root__ for feature in entity.mlFeatures},
|
||||
{
|
||||
"sepal_length__cm_",
|
||||
"sepal_width__cm_",
|
||||
"petal_length__cm_",
|
||||
"petal_width__cm_",
|
||||
},
|
||||
)
|
||||
|
||||
hyper_param = next(
|
||||
iter(
|
||||
param for param in entity.mlHyperParameters if param.name == "criterion"
|
||||
),
|
||||
None,
|
||||
)
|
||||
self.assertIsNotNone(hyper_param)
|
||||
49
ingestion/tests/unit/test_ometa_utils.py
Normal file
49
ingestion/tests/unit/test_ometa_utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenMetadata utils tests
|
||||
"""
|
||||
from unittest import TestCase
|
||||
|
||||
from metadata.generated.schema.entity.data.mlmodel import MlModel
|
||||
from metadata.generated.schema.type import basic
|
||||
from metadata.ingestion.ometa.utils import format_name, get_entity_type, uuid_to_str
|
||||
|
||||
|
||||
class OMetaUtilsTest(TestCase):
|
||||
def test_format_name(self):
|
||||
"""
|
||||
Check we are properly formatting names
|
||||
"""
|
||||
|
||||
self.assertEqual(format_name("random"), "random")
|
||||
self.assertEqual(format_name("ran dom"), "ran_dom")
|
||||
self.assertEqual(format_name("ran_(dom"), "ran__dom")
|
||||
|
||||
def test_get_entity_type(self):
|
||||
"""
|
||||
Check that we return a string or the class name
|
||||
"""
|
||||
|
||||
self.assertEqual(get_entity_type("hello"), "hello")
|
||||
self.assertEqual(get_entity_type(MlModel), "mlmodel")
|
||||
|
||||
def test_uuid_to_str(self):
|
||||
"""
|
||||
Return Uuid as str
|
||||
"""
|
||||
|
||||
self.assertEqual(uuid_to_str("random"), "random")
|
||||
self.assertEqual(
|
||||
uuid_to_str(basic.Uuid(__root__="9fc58e81-7412-4023-a298-59f2494aab9d")),
|
||||
"9fc58e81-7412-4023-a298-59f2494aab9d",
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user