[issue-1973] - Python API from Sklearn to MlModel (#2119)

* Move staticmethods to utils

* Use functions from utils

* Convert sklearn to MlModel

* merge main
This commit is contained in:
Pere Miquel Brull 2022-01-10 09:36:08 +01:00 committed by GitHub
parent 178315d68a
commit d3b6c7cf27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 264 additions and 45 deletions

View File

@ -107,6 +107,7 @@ plugins: Dict[str, Set[str]] = {
"salesforce": {"simple_salesforce~=1.11.4"},
"okta": {"okta~=2.3.0"},
"mlflow": {"mlflow-skinny~=1.22.0"},
"sklearn": {"scikit-learn==1.0.2"},
}
dev = {
"boto3==1.20.14",
@ -125,6 +126,9 @@ test = {
"pytest-cov",
"faker",
"coverage",
# sklearn integration
"scikit-learn==1.0.2",
"pandas==1.3.5",
}
build_options = {"includes": ["_cffi_backend"]}

View File

@ -10,6 +10,7 @@ from pydantic import BaseModel
from metadata.generated.schema.api.lineage.addLineage import AddLineage
from metadata.ingestion.ometa.client import REST, APIError
from metadata.ingestion.ometa.utils import get_entity_type
logger = logging.getLogger(__name__)
@ -97,7 +98,7 @@ class OMetaLineageMixin(Generic[T]):
:param up_depth: Upstream depth of lineage (default=1, min=0, max=3)
:param down_depth: Downstream depth of lineage (default=1, min=0, max=3)
"""
entity_name = self.get_entity_type(entity)
entity_name = get_entity_type(entity)
search = (
f"?upstreamDepth={min(up_depth, 3)}&downstreamDepth={min(down_depth, 3)}"
)

View File

@ -4,13 +4,19 @@ Mixin class containing Lineage specific methods
To be used by OpenMetadata class
"""
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineage
from metadata.generated.schema.entity.data.mlmodel import MlModel
from metadata.generated.schema.entity.data.mlmodel import (
MlFeature,
MlHyperParameter,
MlModel,
)
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.mixins.lineage_mixin import OMetaLineageMixin
from metadata.ingestion.ometa.utils import format_name
logger = logging.getLogger(__name__)
@ -59,3 +65,48 @@ class OMetaMlModelMixin(OMetaLineageMixin):
mlmodel_lineage = self.get_lineage_by_id(MlModel, str(model.id.__root__))
return mlmodel_lineage
@staticmethod
def get_mlmodel_sklearn(
    name: str, model, description: Optional[str] = None
) -> CreateMlModelEntityRequest:
    """
    Get an MlModel Entity instance from a scikit-learn model.

    The estimator class name is used as the algorithm, the entries of
    `feature_names_in_` (passed through `format_name`) become the
    MlFeatures and every item of `get_params()` becomes an
    MlHyperParameter.

    :param name: MlModel name
    :param model: sklearn estimator, it must extend BaseEstimator
    :param description: MlModel description
    :return: OpenMetadata CreateMlModelEntityRequest Entity
    :raises ModuleNotFoundError: if scikit-learn cannot be imported
    :raises ValueError: if `model` is not a BaseEstimator instance
    """
    try:
        # sklearn is only needed by this method, so it is imported lazily
        # pylint: disable=import-outside-toplevel
        from sklearn.base import BaseEstimator

        # pylint: enable=import-outside-toplevel
    except ModuleNotFoundError as exc:
        logger.error(
            "Cannot import BaseEstimator, please install sklearn plugin: "
            + f"pip install openmetadata-ingestion[sklearn], {exc}"
        )
        # Bare raise keeps the original traceback untouched
        raise

    if not isinstance(model, BaseEstimator):
        raise ValueError("Input model is not an instance of sklearn BaseEstimator")

    # sklearn only sets `feature_names_in_` when the estimator has been
    # fitted on data whose feature names are all strings (e.g. a DataFrame).
    # Fall back to no features instead of failing with AttributeError.
    feature_names = getattr(model, "feature_names_in_", None)
    if feature_names is None:
        logger.warning(
            "Model %s does not expose feature_names_in_, "
            "creating the MlModel without features",
            name,
        )
        feature_names = ()

    return CreateMlModelEntityRequest(
        name=name,
        description=description,
        algorithm=model.__class__.__name__,
        mlFeatures=[
            MlFeature(name=format_name(feature)) for feature in feature_names
        ],
        mlHyperParameters=[
            MlHyperParameter(
                name=key,
                value=value,
            )
            for key, value in model.get_params().items()
        ],
    )

View File

@ -13,6 +13,7 @@ from requests.models import Response
from metadata.generated.schema.type import basic
from metadata.generated.schema.type.entityHistory import EntityVersionHistory
from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.utils import uuid_to_str
T = TypeVar("T", bound=BaseModel)
logger = logging.getLogger(__name__)
@ -67,7 +68,7 @@ class OMetaVersionMixin(Generic[T]):
fields: List
List of fields to return
"""
entity_id = self.uuid_to_str(entity_id)
entity_id = uuid_to_str(entity_id)
version = self.version_to_str(version)
path = f"{entity_id}/versions/{version}"
@ -94,7 +95,7 @@ class OMetaVersionMixin(Generic[T]):
List
lists of available versions for a specific entity
"""
path = f"{self.uuid_to_str(entity_id)}/versions"
path = f"{uuid_to_str(entity_id)}/versions"
resp = self.client.get(f"{self.get_suffix(entity)}/{path}")

View File

@ -55,6 +55,7 @@ from metadata.ingestion.ometa.openmetadata_rest import (
NoOpAuthenticationProvider,
OktaAuthenticationProvider,
)
from metadata.ingestion.ometa.utils import get_entity_type, uuid_to_str
logger = logging.getLogger(__name__)
@ -253,27 +254,6 @@ class OpenMetadata(
f"Missing {entity} type when generating suffixes"
)
@staticmethod
def get_entity_type(
entity: Union[Type[T], str],
) -> str:
"""
Given an Entity T, return its type.
E.g., Table returns table, Dashboard returns dashboard...
Also allow to be the identity if we just receive a string
"""
if isinstance(entity, str):
return entity
class_name: str = entity.__name__.lower()
if "service" in class_name:
# Capitalize service, e.g., pipelineService
return class_name.replace("service", "Service")
return class_name
def get_module_path(self, entity: Type[T]) -> str:
"""
Based on the entity, return the module path
@ -355,20 +335,6 @@ class OpenMetadata(
resp = self.client.put(self.get_suffix(entity), data=data.json())
return entity_class(**resp)
@staticmethod
def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str:
"""
Given an entity_id, that can be a str or our pydantic
definition of Uuid, return a proper str to build
the endpoint path
:param entity_id: Entity ID to onvert to string
:return: str for the ID
"""
if isinstance(entity_id, basic.Uuid):
return str(entity_id.__root__)
return entity_id
def get_by_name(
self, entity: Type[T], fqdn: str, fields: Optional[List[str]] = None
) -> Optional[T]:
@ -388,7 +354,7 @@ class OpenMetadata(
Return entity by ID or None
"""
return self._get(entity=entity, path=self.uuid_to_str(entity_id), fields=fields)
return self._get(entity=entity, path=uuid_to_str(entity_id), fields=fields)
def _get(
self, entity: Type[T], path: str, fields: Optional[List[str]] = None
@ -405,7 +371,8 @@ class OpenMetadata(
return entity(**resp)
except APIError as err:
logger.error(
f"Creating new {entity.__class__.__name__} for {path}. Error {err.status_code}"
f"Creating new {entity.__class__.__name__} for {path}. "
+ f"Error {err.status_code}"
)
return None
@ -423,7 +390,7 @@ class OpenMetadata(
if instance:
return EntityReference(
id=instance.id,
type=self.get_entity_type(entity),
type=get_entity_type(entity),
name=instance.fullyQualifiedName,
description=instance.description,
href=instance.href,
@ -470,13 +437,13 @@ class OpenMetadata(
return [entity(**p) for p in resp["data"]]
def delete(self, entity: Type[T], entity_id: Union[str, basic.Uuid]) -> None:
self.client.delete(f"{self.get_suffix(entity)}/{self.uuid_to_str(entity_id)}")
self.client.delete(f"{self.get_suffix(entity)}/{uuid_to_str(entity_id)}")
def compute_percentile(self, entity: Union[Type[T], str], date: str) -> None:
"""
Compute an entity usage percentile
"""
entity_name = self.get_entity_type(entity)
entity_name = get_entity_type(entity)
resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}")
logger.debug("published compute percentile {}".format(resp))

View File

@ -0,0 +1,68 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper functions to handle OpenMetadata Entities' properties
"""
import re
import string
from typing import Type, TypeVar, Union
from pydantic import BaseModel
from metadata.generated.schema.type import basic
T = TypeVar("T", bound=BaseModel)
def format_name(name: str) -> str:
    """
    Normalize a name by turning every ASCII punctuation character
    and every space into `_`

    :param name: raw name to normalize
    :return: name with each special character replaced by `_`
    """
    special_chars = string.punctuation + " "
    # Escape the characters so they are safe inside a regex character class
    pattern = "[" + re.escape(special_chars) + "]"
    return re.sub(pattern, "_", name)
def get_entity_type(
    entity: Union[Type[T], str],
) -> str:
    """
    Resolve the type name of an Entity.

    A string is handed back untouched. For an Entity class the
    lower-cased class name is used (Table -> table,
    Dashboard -> dashboard...), with `service` re-capitalized
    so that service entities read e.g. pipelineService.
    """
    if not isinstance(entity, str):
        # replace() leaves names without "service" as they are
        return entity.__name__.lower().replace("service", "Service")

    return entity
def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str:
    """
    Turn an entity ID, given either as a str or as our pydantic
    definition of Uuid, into the value used to build
    the endpoint path

    :param entity_id: Entity ID to convert to string
    :return: str for the ID
    """
    if not isinstance(entity_id, basic.Uuid):
        # Anything that is not our pydantic Uuid is returned as received
        return entity_id

    return str(entity_id.__root__)

View File

@ -0,0 +1,78 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenMetadata MlModel mixin test
"""
from unittest import TestCase
import pandas as pd
import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
from metadata.generated.schema.entity.data.mlmodel import MlFeature, MlModel
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
class OMetaModelMixinTest(TestCase):
    """
    Check the sklearn integration offered by the MlModel Mixin.
    The client is configured against http://localhost:8585/api
    """

    server_config = MetadataServerConfig(api_endpoint="http://localhost:8585/api")
    metadata = OpenMetadata(server_config)

    iris = datasets.load_iris()

    def test_get_sklearn(self):
        """
        Fit a decision tree on the iris dataset, convert it to a
        CreateMlModelEntityRequest, send it and validate the result
        """
        features = pd.DataFrame(self.iris.data, columns=self.iris.feature_names)
        target = self.iris.target

        # Only the training split is needed to fit the model
        x_train, _x_test, y_train, _y_test = train_test_split(
            features, target, test_size=0.25, random_state=70
        )

        classifier = DecisionTreeClassifier()
        classifier.fit(x_train, y_train)

        entity_create: CreateMlModelEntityRequest = self.metadata.get_mlmodel_sklearn(
            name="test-sklearn",
            model=classifier,
            description="Creating a test sklearn model",
        )

        entity: MlModel = self.metadata.create_or_update(data=entity_create)

        self.assertEqual(entity.name, entity_create.name)
        self.assertEqual(entity.algorithm, "DecisionTreeClassifier")

        # Iris column names after format_name replaced spaces and parentheses
        expected_features = {
            "sepal_length__cm_",
            "sepal_width__cm_",
            "petal_length__cm_",
            "petal_width__cm_",
        }
        received_features = {feature.name.__root__ for feature in entity.mlFeatures}
        self.assertEqual(received_features, expected_features)

        # First hyperparameter called "criterion", None when there is no match
        matching_params = (
            param for param in entity.mlHyperParameters if param.name == "criterion"
        )
        hyper_param = next(matching_params, None)
        self.assertIsNotNone(hyper_param)

View File

@ -0,0 +1,49 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenMetadata utils tests
"""
from unittest import TestCase
from metadata.generated.schema.entity.data.mlmodel import MlModel
from metadata.generated.schema.type import basic
from metadata.ingestion.ometa.utils import format_name, get_entity_type, uuid_to_str
class OMetaUtilsTest(TestCase):
    """
    Tests for the helper functions in metadata.ingestion.ometa.utils
    """

    def test_format_name(self):
        """
        Spaces and special characters are replaced by `_`
        """
        expectations = (
            ("random", "random"),
            ("ran dom", "ran_dom"),
            ("ran_(dom", "ran__dom"),
        )
        for raw_name, formatted in expectations:
            self.assertEqual(format_name(raw_name), formatted)

    def test_get_entity_type(self):
        """
        A string is returned as is, a class gives its lower-cased name
        """
        from_string = get_entity_type("hello")
        self.assertEqual(from_string, "hello")

        from_class = get_entity_type(MlModel)
        self.assertEqual(from_class, "mlmodel")

    def test_uuid_to_str(self):
        """
        Both str and Uuid inputs end up as str
        """
        raw_id = "9fc58e81-7412-4023-a298-59f2494aab9d"

        self.assertEqual(uuid_to_str("random"), "random")
        self.assertEqual(
            uuid_to_str(basic.Uuid(__root__="9fc58e81-7412-4023-a298-59f2494aab9d")),
            raw_id,
        )