From d3b6c7cf270cb033243a251b6435f616886c2197 Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Mon, 10 Jan 2022 09:36:08 +0100 Subject: [PATCH] [issue-1973] - Python API from Sklearn to MlModel (#2119) * Move staticmethods to utils * Use functions from utils * Convert sklearn to MlModel * merge main --- ingestion/setup.py | 4 + .../ingestion/ometa/mixins/lineage_mixin.py | 3 +- .../ingestion/ometa/mixins/mlmodel_mixin.py | 55 ++++++++++++- .../ingestion/ometa/mixins/version_mixin.py | 5 +- .../src/metadata/ingestion/ometa/ometa_api.py | 47 ++--------- .../src/metadata/ingestion/ometa/utils.py | 68 ++++++++++++++++ ingestion/tests/unit/test_ometa_mlmodel.py | 78 +++++++++++++++++++ ingestion/tests/unit/test_ometa_utils.py | 49 ++++++++++++ 8 files changed, 264 insertions(+), 45 deletions(-) create mode 100644 ingestion/src/metadata/ingestion/ometa/utils.py create mode 100644 ingestion/tests/unit/test_ometa_mlmodel.py create mode 100644 ingestion/tests/unit/test_ometa_utils.py diff --git a/ingestion/setup.py b/ingestion/setup.py index f571103f95c..150419b3a89 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -107,6 +107,7 @@ plugins: Dict[str, Set[str]] = { "salesforce": {"simple_salesforce~=1.11.4"}, "okta": {"okta~=2.3.0"}, "mlflow": {"mlflow-skinny~=1.22.0"}, + "sklearn": {"scikit-learn==1.0.2"}, } dev = { "boto3==1.20.14", @@ -125,6 +126,9 @@ test = { "pytest-cov", "faker", "coverage", + # sklearn integration + "scikit-learn==1.0.2", + "pandas==1.3.5", } build_options = {"includes": ["_cffi_backend"]} diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py index 1629936de04..254d6c6cde9 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/lineage_mixin.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from metadata.generated.schema.api.lineage.addLineage import AddLineage from metadata.ingestion.ometa.client import REST, APIError +from metadata.ingestion.ometa.utils import get_entity_type logger = logging.getLogger(__name__) @@ -97,7 +98,7 @@ class OMetaLineageMixin(Generic[T]): :param up_depth: Upstream depth of lineage (default=1, min=0, max=3)" :param down_depth: Downstream depth of lineage (default=1, min=0, max=3) """ - entity_name = self.get_entity_type(entity) + entity_name = get_entity_type(entity) search = ( f"?upstreamDepth={min(up_depth, 3)}&downstreamDepth={min(down_depth, 3)}" ) diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py index 04d1383476c..62622c72584 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py @@ -4,13 +4,19 @@ Mixin class containing Lineage specific methods To be used by OpenMetadata class """ import logging -from typing import Any, Dict +from typing import Any, Dict, Optional +from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest from metadata.generated.schema.api.lineage.addLineage import AddLineage -from metadata.generated.schema.entity.data.mlmodel import MlModel +from metadata.generated.schema.entity.data.mlmodel import ( + MlFeature, + MlHyperParameter, + MlModel, +) from metadata.generated.schema.type.entityLineage import EntitiesEdge from metadata.ingestion.ometa.client import REST from metadata.ingestion.ometa.mixins.lineage_mixin import OMetaLineageMixin +from metadata.ingestion.ometa.utils import format_name logger = logging.getLogger(__name__) @@ -59,3 +65,48 @@ class OMetaMlModelMixin(OMetaLineageMixin): mlmodel_lineage = self.get_lineage_by_id(MlModel, str(model.id.__root__)) return mlmodel_lineage + + @staticmethod + def get_mlmodel_sklearn( + name: str, model, description: Optional[str] = None + ) -> CreateMlModelEntityRequest: + """ + Get an MlModel Entity instance from a scikit-learn model. + + Sklearn estimators all extend BaseEstimator. + :param name: MlModel name + :param model: sklearn estimator + :param description: MlModel description + :return: OpenMetadata CreateMlModelEntityRequest Entity + """ + try: + # pylint: disable=import-outside-toplevel + from sklearn.base import BaseEstimator + + # pylint: enable=import-outside-toplevel + except ModuleNotFoundError as exc: + logger.error( + "Cannot import BaseEstimator, please install sklearn plugin: " + + f"pip install openmetadata-ingestion[sklearn], {exc}" + ) + raise exc + + if not isinstance(model, BaseEstimator): + raise ValueError("Input model is not an instance of sklearn BaseEstimator") + + return CreateMlModelEntityRequest( + name=name, + description=description, + algorithm=model.__class__.__name__, + mlFeatures=[ + MlFeature(name=format_name(feature)) + for feature in model.feature_names_in_ + ], + mlHyperParameters=[ + MlHyperParameter( + name=key, + value=value, + ) + for key, value in model.get_params().items() + ], + ) diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py index e5fd32f074e..cb88e406678 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/version_mixin.py @@ -13,6 +13,7 @@ from requests.models import Response from metadata.generated.schema.type import basic from metadata.generated.schema.type.entityHistory import EntityVersionHistory from metadata.ingestion.ometa.client import REST +from metadata.ingestion.ometa.utils import uuid_to_str T = TypeVar("T", bound=BaseModel) logger = logging.getLogger(__name__) @@ -67,7 +68,7 @@ class OMetaVersionMixin(Generic[T]): fields: List List of fields to return """ - entity_id = self.uuid_to_str(entity_id) + entity_id = uuid_to_str(entity_id) version = self.version_to_str(version) path = f"{entity_id}/versions/{version}" @@ -94,7 +95,7 @@ class OMetaVersionMixin(Generic[T]): List lists of available versions for a specific entity """ - path = f"{self.uuid_to_str(entity_id)}/versions" + path = f"{uuid_to_str(entity_id)}/versions" resp = self.client.get(f"{self.get_suffix(entity)}/{path}") diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py index 097c2ae54b7..8c65d35bc56 100644 --- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py +++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py @@ -55,6 +55,7 @@ from metadata.ingestion.ometa.openmetadata_rest import ( NoOpAuthenticationProvider, OktaAuthenticationProvider, ) +from metadata.ingestion.ometa.utils import get_entity_type, uuid_to_str logger = logging.getLogger(__name__) @@ -253,27 +254,6 @@ class OpenMetadata( f"Missing {entity} type when generating suffixes" ) - @staticmethod - def get_entity_type( - entity: Union[Type[T], str], - ) -> str: - """ - Given an Entity T, return its type. - E.g., Table returns table, Dashboard returns dashboard... - - Also allow to be the identity if we just receive a string - """ - if isinstance(entity, str): - return entity - - class_name: str = entity.__name__.lower() - - if "service" in class_name: - # Capitalize service, e.g., pipelineService - return class_name.replace("service", "Service") - - return class_name - def get_module_path(self, entity: Type[T]) -> str: """ Based on the entity, return the module path @@ -355,20 +335,6 @@ class OpenMetadata( resp = self.client.put(self.get_suffix(entity), data=data.json()) return entity_class(**resp) - @staticmethod - def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str: - """ - Given an entity_id, that can be a str or our pydantic - definition of Uuid, return a proper str to build - the endpoint path - :param entity_id: Entity ID to onvert to string - :return: str for the ID - """ - if isinstance(entity_id, basic.Uuid): - return str(entity_id.__root__) - - return entity_id - def get_by_name( self, entity: Type[T], fqdn: str, fields: Optional[List[str]] = None ) -> Optional[T]: @@ -388,7 +354,7 @@ class OpenMetadata( Return entity by ID or None """ - return self._get(entity=entity, path=self.uuid_to_str(entity_id), fields=fields) + return self._get(entity=entity, path=uuid_to_str(entity_id), fields=fields) def _get( self, entity: Type[T], path: str, fields: Optional[List[str]] = None @@ -405,7 +371,8 @@ class OpenMetadata( return entity(**resp) except APIError as err: logger.error( - f"Creating new {entity.__class__.__name__} for {path}. Error {err.status_code}" + f"Creating new {entity.__class__.__name__} for {path}. " + + f"Error {err.status_code}" ) return None @@ -423,7 +390,7 @@ class OpenMetadata( if instance: return EntityReference( id=instance.id, - type=self.get_entity_type(entity), + type=get_entity_type(entity), name=instance.fullyQualifiedName, description=instance.description, href=instance.href, @@ -470,13 +437,13 @@ class OpenMetadata( return [entity(**p) for p in resp["data"]] def delete(self, entity: Type[T], entity_id: Union[str, basic.Uuid]) -> None: - self.client.delete(f"{self.get_suffix(entity)}/{self.uuid_to_str(entity_id)}") + self.client.delete(f"{self.get_suffix(entity)}/{uuid_to_str(entity_id)}") def compute_percentile(self, entity: Union[Type[T], str], date: str) -> None: """ Compute an entity usage percentile """ - entity_name = self.get_entity_type(entity) + entity_name = get_entity_type(entity) resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}") logger.debug("published compute percentile {}".format(resp)) diff --git a/ingestion/src/metadata/ingestion/ometa/utils.py b/ingestion/src/metadata/ingestion/ometa/utils.py new file mode 100644 index 00000000000..a7f5371bfa6 --- /dev/null +++ b/ingestion/src/metadata/ingestion/ometa/utils.py @@ -0,0 +1,68 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Helper functions to handle OpenMetadata Entities' properties +""" + +import re +import string +from typing import Type, TypeVar, Union + +from pydantic import BaseModel + +from metadata.generated.schema.type import basic + +T = TypeVar("T", bound=BaseModel) + + +def format_name(name: str) -> str: + """ + Given a name, replace all special characters by `_` + :param name: name to format + :return: formatted string + """ + subs = re.escape(string.punctuation + " ") + return re.sub(r"[" + subs + "]", "_", name) + + +def get_entity_type( + entity: Union[Type[T], str], +) -> str: + """ + Given an Entity T, return its type. + E.g., Table returns table, Dashboard returns dashboard... + + Also allow to be the identity if we just receive a string + """ + if isinstance(entity, str): + return entity + + class_name: str = entity.__name__.lower() + + if "service" in class_name: + # Capitalize service, e.g., pipelineService + return class_name.replace("service", "Service") + + return class_name + + +def uuid_to_str(entity_id: Union[str, basic.Uuid]) -> str: + """ + Given an entity_id, that can be a str or our pydantic + definition of Uuid, return a proper str to build + the endpoint path + :param entity_id: Entity ID to onvert to string + :return: str for the ID + """ + if isinstance(entity_id, basic.Uuid): + return str(entity_id.__root__) + + return entity_id diff --git a/ingestion/tests/unit/test_ometa_mlmodel.py b/ingestion/tests/unit/test_ometa_mlmodel.py new file mode 100644 index 00000000000..f07f67b509d --- /dev/null +++ b/ingestion/tests/unit/test_ometa_mlmodel.py @@ -0,0 +1,78 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +OpenMetadata MlModel mixin test +""" +from unittest import TestCase + +import pandas as pd +import sklearn.datasets as datasets +from sklearn.model_selection import train_test_split +from sklearn.tree import DecisionTreeClassifier + +from metadata.generated.schema.api.data.createMlModel import CreateMlModelEntityRequest +from metadata.generated.schema.entity.data.mlmodel import MlFeature, MlModel +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig + + +class OMetaModelMixinTest(TestCase): + """ + Test the MlModel integrations from MlModel Mixin + """ + + server_config = MetadataServerConfig(api_endpoint="http://localhost:8585/api") + metadata = OpenMetadata(server_config) + + iris = datasets.load_iris() + + def test_get_sklearn(self): + """ + Check that we can ingest an SKlearn model + """ + df = pd.DataFrame(self.iris.data, columns=self.iris.feature_names) + y = self.iris.target + + x_train, x_test, y_train, y_test = train_test_split( + df, y, test_size=0.25, random_state=70 + ) + + dtree = DecisionTreeClassifier() + dtree.fit(x_train, y_train) + + entity_create: CreateMlModelEntityRequest = self.metadata.get_mlmodel_sklearn( + name="test-sklearn", + model=dtree, + description="Creating a test sklearn model", + ) + + entity: MlModel = self.metadata.create_or_update(data=entity_create) + + self.assertEqual(entity.name, entity_create.name) + self.assertEqual(entity.algorithm, "DecisionTreeClassifier") + self.assertEqual( + {feature.name.__root__ for feature in entity.mlFeatures}, + { + "sepal_length__cm_", + "sepal_width__cm_", + "petal_length__cm_", + "petal_width__cm_", + }, + ) + + hyper_param = next( + iter( + param for param in entity.mlHyperParameters if param.name == "criterion" + ), + None, + ) + self.assertIsNotNone(hyper_param) diff --git a/ingestion/tests/unit/test_ometa_utils.py b/ingestion/tests/unit/test_ometa_utils.py new file mode 100644 index 00000000000..e7857140ab1 --- /dev/null +++ b/ingestion/tests/unit/test_ometa_utils.py @@ -0,0 +1,49 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +OpenMetadata utils tests +""" +from unittest import TestCase + +from metadata.generated.schema.entity.data.mlmodel import MlModel +from metadata.generated.schema.type import basic +from metadata.ingestion.ometa.utils import format_name, get_entity_type, uuid_to_str + + +class OMetaUtilsTest(TestCase): + def test_format_name(self): + """ + Check we are properly formatting names + """ + + self.assertEqual(format_name("random"), "random") + self.assertEqual(format_name("ran dom"), "ran_dom") + self.assertEqual(format_name("ran_(dom"), "ran__dom") + + def test_get_entity_type(self): + """ + Check that we return a string or the class name + """ + + self.assertEqual(get_entity_type("hello"), "hello") + self.assertEqual(get_entity_type(MlModel), "mlmodel") + + def test_uuid_to_str(self): + """ + Return Uuid as str + """ + + self.assertEqual(uuid_to_str("random"), "random") + self.assertEqual( + uuid_to_str(basic.Uuid(__root__="9fc58e81-7412-4023-a298-59f2494aab9d")), + "9fc58e81-7412-4023-a298-59f2494aab9d", + )