From d3b6c7cf270cb033243a251b6435f616886c2197 Mon Sep 17 00:00:00 2001
From: Pere Miquel Brull
Date: Mon, 10 Jan 2022 09:36:08 +0100
Subject: [PATCH] [issue-1973] - Python API from Sklearn to MlModel (#2119)
* Move staticmethods to utils
* Use functions from utils
* Convert sklearn to MlModel
* merge main
---
ingestion/setup.py | 4 +
.../ingestion/ometa/mixins/lineage_mixin.py | 3 +-
.../ingestion/ometa/mixins/mlmodel_mixin.py | 55 ++++++++++++-
.../ingestion/ometa/mixins/version_mixin.py | 5 +-
.../src/metadata/ingestion/ometa/ometa_api.py | 47 ++---------
.../src/metadata/ingestion/ometa/utils.py | 68 ++++++++++++++++
ingestion/tests/unit/test_ometa_mlmodel.py | 78 +++++++++++++++++++
ingestion/tests/unit/test_ometa_utils.py | 49 ++++++++++++
8 files changed, 264 insertions(+), 45 deletions(-)
create mode 100644 ingestion/src/metadata/ingestion/ometa/utils.py
create mode 100644 ingestion/tests/unit/test_ometa_mlmodel.py
create mode 100644 ingestion/tests/unit/test_ometa_utils.py
diff --git a/ingestion/setup.py b/ingestion/setup.py
index f571103f95c..150419b3a89 100644
--- a/ingestion/setup.py
+++ b/ingestion/setup.py
@@ -107,6 +107,7 @@ plugins: Dict[str, Set[str]] = {
"salesforce": {"simple_salesforce~=1.11.4"},
"okta": {"okta~=2.3.0"},
"mlflow": {"mlflow-skinny~=1.22.0"},
+ "sklearn": {"scikit-learn==1.0.2"},
}
dev = {
"boto3==1.20.14",
@@ -125,6 +126,9 @@ test = {
"pytest-cov",
"faker",
"coverage",
+ # sklearn integration
+ "scikit-learn==1.0.2",
+ "pandas==1.3.5",
}
build_options = {"includes": ["_cffi_backend"]}
diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py
index 1629936de04..254d6c6cde9 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel
from metadata.generated.schema.api.lineage.addLineage import AddLineage
from metadata.ingestion.ometa.client import REST, APIError
+from metadata.ingestion.ometa.utils import get_entity_type
logger = logging.getLogger(__name__)
@@ -97,7 +98,7 @@ class OMetaLineageMixin(Generic[T]):
:param up_depth: Upstream depth of lineage (default=1, min=0, max=3)"
:param down_depth: Downstream depth of lineage (default=1, min=0, max=3)
"""
- entity_name = self.get_entity_type(entity)
+ entity_name = get_entity_type(entity)
search = (
f"?upstreamDepth={min(up_depth, 3)}&downstreamDepth={min(down_depth, 3)}"
)
diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py
index 04d1383476c..62622c72584 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py
@@ -4,13 +4,19 @@ Mixin class containing Lineage specific methods
To be used by OpenMetadata class
"""
import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineage
-from metadata.generated.schema.entity.data.mlmodel import MlModel
+from metadata.generated.schema.entity.data.mlmodel import (
+ MlFeature,
+ MlHyperParameter,
+ MlModel,
+)
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.mixins.lineage_mixin import OMetaLineageMixin
+from metadata.ingestion.ometa.utils import format_name
logger = logging.getLogger(__name__)
@@ -59,3 +65,48 @@ class OMetaMlModelMixin(OMetaLineageMixin):
mlmodel_lineage = self.get_lineage_by_id(MlModel, str(model.id.__root__))
return mlmodel_lineage
+
+    @staticmethod
+    def get_mlmodel_sklearn(
+        name: str, model, description: Optional[str] = None
+    ) -> CreateMlModelEntityRequest:
+        """
+        Build a CreateMlModelEntityRequest from a scikit-learn estimator.
+
+        Sklearn estimators all extend BaseEstimator.
+        Features come from `feature_names_in_` and may be empty.
+        :param name: MlModel name
+        :param model: sklearn estimator
+        :param description: MlModel description
+        :return: OpenMetadata CreateMlModelEntityRequest Entity
+        :raises ModuleNotFoundError: if scikit-learn is not installed
+        :raises ValueError: if `model` is not a sklearn BaseEstimator
+        """
+        try:
+            # pylint: disable=import-outside-toplevel
+            from sklearn.base import BaseEstimator
+
+            # pylint: enable=import-outside-toplevel
+        except ModuleNotFoundError as exc:
+            logger.error(
+                "Cannot import BaseEstimator, please install sklearn plugin: "
+                + f"pip install openmetadata-ingestion[sklearn], {exc}"
+            )
+            raise
+
+        if not isinstance(model, BaseEstimator):
+            raise ValueError("Input model is not an instance of sklearn BaseEstimator")
+
+        # Unfitted estimators, or ones fit on unnamed data, lack this attribute
+        features = getattr(model, "feature_names_in_", ())
+        return CreateMlModelEntityRequest(
+            name=name,
+            description=description,
+            algorithm=model.__class__.__name__,
+            mlFeatures=[MlFeature(name=format_name(feat)) for feat in features],
+            mlHyperParameters=[
+                # Stringify: params can be arbitrary objects (e.g. estimators)
+                MlHyperParameter(name=k, value=None if v is None else str(v))
+                for k, v in model.get_params().items()
+            ],
+        )
diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py
index e5fd32f074e..cb88e406678 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py
@@ -13,6 +13,7 @@ from requests.models import Response
from metadata.generated.schema.type import basic
from metadata.generated.schema.type.entityHistory import EntityVersionHistory
from metadata.ingestion.ometa.client import REST
+from metadata.ingestion.ometa.utils import uuid_to_str
T = TypeVar("T", bound=BaseModel)
logger = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class OMetaVersionMixin(Generic[T]):
fields: List
List of fields to return
"""
- entity_id = self.uuid_to_str(entity_id)
+ entity_id = uuid_to_str(entity_id)
version = self.version_to_str(version)
path = f"{entity_id}/versions/{version}"
@@ -94,7 +95,7 @@ class OMetaVersionMixin(Generic[T]):
List
lists of available versions for a specific entity
"""
- path = f"{self.uuid_to_str(entity_id)}/versions"
+ path = f"{uuid_to_str(entity_id)}/versions"
resp = self.client.get(f"{self.get_suffix(entity)}/{path}")
diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py
index 097c2ae54b7..8c65d35bc56 100644
--- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py
+++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py
@@ -55,6 +55,7 @@ from metadata.ingestion.ometa.openmetadata_rest import (
NoOpAuthenticationProvider,
OktaAuthenticationProvider,
)
+from metadata.ingestion.ometa.utils import get_entity_type, uuid_to_str
logger = logging.getLogger(__name__)
@@ -253,27 +254,6 @@ class OpenMetadata(
f"Missing {entity} type when generating suffixes"
)
- @staticmethod
- def get_entity_type(
- entity: Union[Type[T], str],
- ) -> str:
- """
- Given an Entity T, return its type.
- E.g., Table returns table, Dashboard returns dashboard...
-
- Also allow to be the identity if we just receive a string
- """
- if isinstance(entity, str):
- return entity
-
- class_name: str = entity.__name__.lower()
-
- if "service" in class_name:
- # Capitalize service, e.g., pipelineService
- return class_name.replace("service", "Service")
-
- return class_name
-
def get_module_path(self, entity: Type[T]) -> str:
"""
Based on the entity, return the module path
@@ -355,20 +335,6 @@ class OpenMetadata(
resp = self.client.put(self.get_suffix(entity), data=data.json())
return entity_class(**resp)
- @staticmethod
- def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str:
- """
- Given an entity_id, that can be a str or our pydantic
- definition of Uuid, return a proper str to build
- the endpoint path
- :param entity_id: Entity ID to onvert to string
- :return: str for the ID
- """
- if isinstance(entity_id, basic.Uuid):
- return str(entity_id.__root__)
-
- return entity_id
-
def get_by_name(
self, entity: Type[T], fqdn: str, fields: Optional[List[str]] = None
) -> Optional[T]:
@@ -388,7 +354,7 @@ class OpenMetadata(
Return entity by ID or None
"""
- return self._get(entity=entity, path=self.uuid_to_str(entity_id), fields=fields)
+ return self._get(entity=entity, path=uuid_to_str(entity_id), fields=fields)
def _get(
self, entity: Type[T], path: str, fields: Optional[List[str]] = None
@@ -405,7 +371,8 @@ class OpenMetadata(
return entity(**resp)
except APIError as err:
logger.error(
- f"Creating new {entity.__class__.__name__} for {path}. Error {err.status_code}"
+ f"Creating new {entity.__class__.__name__} for {path}. "
+ + f"Error {err.status_code}"
)
return None
@@ -423,7 +390,7 @@ class OpenMetadata(
if instance:
return EntityReference(
id=instance.id,
- type=self.get_entity_type(entity),
+ type=get_entity_type(entity),
name=instance.fullyQualifiedName,
description=instance.description,
href=instance.href,
@@ -470,13 +437,13 @@ class OpenMetadata(
return [entity(**p) for p in resp["data"]]
def delete(self, entity: Type[T], entity_id: Union[str, basic.Uuid]) -> None:
- self.client.delete(f"{self.get_suffix(entity)}/{self.uuid_to_str(entity_id)}")
+ self.client.delete(f"{self.get_suffix(entity)}/{uuid_to_str(entity_id)}")
def compute_percentile(self, entity: Union[Type[T], str], date: str) -> None:
"""
Compute an entity usage percentile
"""
- entity_name = self.get_entity_type(entity)
+ entity_name = get_entity_type(entity)
resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}")
logger.debug("published compute percentile {}".format(resp))
diff --git a/ingestion/src/metadata/ingestion/ometa/utils.py b/ingestion/src/metadata/ingestion/ometa/utils.py
new file mode 100644
index 00000000000..a7f5371bfa6
--- /dev/null
+++ b/ingestion/src/metadata/ingestion/ometa/utils.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Collate
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Helper functions to handle OpenMetadata Entities' properties
+"""
+
+import re
+import string
+from typing import Type, TypeVar, Union
+
+from pydantic import BaseModel
+
+from metadata.generated.schema.type import basic
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def format_name(name: str) -> str:
+    """
+    Swap each ASCII punctuation character or space in `name` for `_`
+    :param name: name to format
+    :return: formatted string
+    """
+    special_chars = re.compile(f"[{re.escape(string.punctuation + ' ')}]")
+    return special_chars.sub("_", name)
+
+
+def get_entity_type(
+    entity: Union[Type[T], str],
+) -> str:
+    """
+    Resolve the OpenMetadata type name of an Entity class.
+    The class name is lowercased, e.g., Table gives table, and
+    any `service` in it is capitalized, e.g., pipelineService.
+
+    A string input is assumed to already be a type name
+    and is returned as is.
+    """
+    if not isinstance(entity, str):
+        lowered: str = entity.__name__.lower()
+        # str.replace is a no-op when `service` is not in the name,
+        # so no membership check is needed beforehand
+        return lowered.replace("service", "Service")
+
+    # Identity for plain strings
+    return entity
+
+
+def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str:
+    """
+    Normalize an entity ID into the plain str used to build
+    endpoint paths. A pydantic Uuid is unwrapped through its
+    `__root__`; any other value is returned untouched.
+    :param entity_id: Entity ID to convert to string
+    :return: str for the ID
+    """
+    if not isinstance(entity_id, basic.Uuid):
+        return entity_id
+
+    return str(entity_id.__root__)
diff --git a/ingestion/tests/unit/test_ometa_mlmodel.py b/ingestion/tests/unit/test_ometa_mlmodel.py
new file mode 100644
index 00000000000..f07f67b509d
--- /dev/null
+++ b/ingestion/tests/unit/test_ometa_mlmodel.py
@@ -0,0 +1,78 @@
+# Copyright 2021 Collate
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+OpenMetadata MlModel mixin test
+"""
+from unittest import TestCase
+
+import pandas as pd
+import sklearn.datasets as datasets
+from sklearn.model_selection import train_test_split
+from sklearn.tree import DecisionTreeClassifier
+
+from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest
+from metadata.generated.schema.entity.data.mlmodel import MlFeature, MlModel
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
+from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
+
+
+class OMetaModelMixinTest(TestCase):
+ """
+ Test the sklearn to MlModel conversion offered by the MlModel Mixin
+ """
+
+ server_config = MetadataServerConfig(api_endpoint="http://localhost:8585/api")
+ metadata = OpenMetadata(server_config)
+ # NOTE(review): targets a local server; presumably needs it running - confirm
+ iris = datasets.load_iris()
+
+ def test_get_sklearn(self):
+ """
+ Check that a fitted sklearn model converts to a request and is ingested
+ """
+ df = pd.DataFrame(self.iris.data, columns=self.iris.feature_names)
+ y = self.iris.target
+ # Hold out 25% of the rows; the test split is not used afterwards
+ x_train, x_test, y_train, y_test = train_test_split(
+ df, y, test_size=0.25, random_state=70
+ )
+ # Fit on the DataFrame: sklearn only sets feature_names_in_ for named columns
+ dtree = DecisionTreeClassifier()
+ dtree.fit(x_train, y_train)
+
+ entity_create: CreateMlModelEntityRequest = self.metadata.get_mlmodel_sklearn(
+ name="test-sklearn",
+ model=dtree,
+ description="Creating a test sklearn model",
+ )
+ # NOTE(review): presumably sends the request to the server - confirm
+ entity: MlModel = self.metadata.create_or_update(data=entity_create)
+ # The iris column names are expected with spaces and brackets turned into `_`
+ self.assertEqual(entity.name, entity_create.name)
+ self.assertEqual(entity.algorithm, "DecisionTreeClassifier")
+ self.assertEqual(
+ {feature.name.__root__ for feature in entity.mlFeatures},
+ {
+ "sepal_length__cm_",
+ "sepal_width__cm_",
+ "petal_length__cm_",
+ "petal_width__cm_",
+ },
+ )
+ # The stored hyperparameters should include one named `criterion`
+ hyper_param = next(
+ iter(
+ param for param in entity.mlHyperParameters if param.name == "criterion"
+ ),
+ None,
+ )
+ self.assertIsNotNone(hyper_param)
diff --git a/ingestion/tests/unit/test_ometa_utils.py b/ingestion/tests/unit/test_ometa_utils.py
new file mode 100644
index 00000000000..e7857140ab1
--- /dev/null
+++ b/ingestion/tests/unit/test_ometa_utils.py
@@ -0,0 +1,49 @@
+# Copyright 2021 Collate
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+OpenMetadata utils tests
+"""
+from unittest import TestCase
+
+from metadata.generated.schema.entity.data.mlmodel import MlModel
+from metadata.generated.schema.type import basic
+from metadata.ingestion.ometa.utils import format_name, get_entity_type, uuid_to_str
+
+
+class OMetaUtilsTest(TestCase):
+ """Test the helper functions of metadata.ingestion.ometa.utils"""
+
+ def test_format_name(self):
+ """
+ Check that spaces and punctuation are replaced by `_`
+ """
+ self.assertEqual(format_name("random"), "random")
+ self.assertEqual(format_name("ran dom"), "ran_dom")
+ self.assertEqual(format_name("ran_(dom"), "ran__dom")
+
+ def test_get_entity_type(self):
+ """
+ Check that a string is returned as is and a class gives its lowercased name
+ """
+ self.assertEqual(get_entity_type("hello"), "hello")
+ self.assertEqual(get_entity_type(MlModel), "mlmodel")
+
+ def test_uuid_to_str(self):
+ """
+ Return Uuid as str
+ """
+ # A str passes through; a Uuid yields the str of its `__root__`
+ self.assertEqual(uuid_to_str("random"), "random")
+ self.assertEqual(
+ uuid_to_str(basic.Uuid(__root__="9fc58e81-7412-4023-a298-59f2494aab9d")),
+ "9fc58e81-7412-4023-a298-59f2494aab9d",
+ )