From 89a026b022e34e461228cedcb460baaba501a13d Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Tue, 28 Jun 2022 14:58:38 +0200 Subject: [PATCH] Fix #5688 - MlFlow deployment fails from UI (#5691) * Fix mlflow from UI * format --- .../testServiceConnection.json | 8 +++++- ingestion/setup.py | 2 +- .../source/database/common_db_source.py | 1 - .../ingestion/source/mlmodel/mlflow.py | 9 ++++--- .../src/metadata/utils/connection_clients.py | 6 +++++ ingestion/src/metadata/utils/connections.py | 26 +++++++++++++++++++ .../workflows/ingestion/common.py | 5 ++++ 7 files changed, 50 insertions(+), 7 deletions(-) diff --git a/catalog-rest-service/src/main/resources/json/schema/api/services/ingestionPipelines/testServiceConnection.json b/catalog-rest-service/src/main/resources/json/schema/api/services/ingestionPipelines/testServiceConnection.json index 1b8c2548aaf..ec44d26cf0c 100644 --- a/catalog-rest-service/src/main/resources/json/schema/api/services/ingestionPipelines/testServiceConnection.json +++ b/catalog-rest-service/src/main/resources/json/schema/api/services/ingestionPipelines/testServiceConnection.json @@ -19,13 +19,16 @@ }, { "$ref": "../../../entity/services/pipelineService.json#/definitions/pipelineConnection" + }, + { + "$ref": "../../../entity/services/mlmodelService.json#/definitions/mlModelConnection" } ] }, "connectionType": { "description": "Type of database service such as MySQL, BigQuery, Snowflake, Redshift, Postgres...", "type": "string", - "enum": ["Database", "Dashboard", "Messaging", "Pipeline"], + "enum": ["Database", "Dashboard", "Messaging", "Pipeline", "MlModel"], "javaEnums": [ { "name": "Database" @@ -38,6 +41,9 @@ }, { "name": "Pipeline" + }, + { + "name": "MlModel" } ] } diff --git a/ingestion/setup.py b/ingestion/setup.py index b8adcae19d0..3f63a9e06da 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -122,7 +122,7 @@ plugins: Dict[str, Set[str]] = { "webhook-server": {}, "salesforce": {"simple_salesforce~=1.11.4"}, "okta": {"okta~=2.3.0"}, - "mlflow": {"mlflow-skinny~=1.22.0"}, + "mlflow": {"mlflow-skinny~=1.26.1"}, "sklearn": {"scikit-learn==1.0.2"}, "db2": {"ibm-db-sa==0.3.7"}, "clickhouse": {"clickhouse-driver==0.2.3", "clickhouse-sqlalchemy==0.2.0"}, diff --git a/ingestion/src/metadata/ingestion/source/database/common_db_source.py b/ingestion/src/metadata/ingestion/source/database/common_db_source.py index 7c9c59e1095..e43eae905dc 100644 --- a/ingestion/src/metadata/ingestion/source/database/common_db_source.py +++ b/ingestion/src/metadata/ingestion/source/database/common_db_source.py @@ -326,7 +326,6 @@ class CommonDbSourceService( # Disable the DictConfigurator.configure method while importing LineageRunner configure = DictConfigurator.configure DictConfigurator.configure = lambda _: None - from sqllineage.exceptions import SQLLineageException from sqllineage.runner import LineageRunner # Reverting changes after import is done diff --git a/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py b/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py index 2abe16024a1..241dabb5f01 100644 --- a/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py +++ b/ingestion/src/metadata/ingestion/source/mlmodel/mlflow.py @@ -40,6 +40,7 @@ from metadata.generated.schema.metadataIngestion.workflow import ( from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.utils.connections import get_connection from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -92,10 +93,10 @@ class MlflowSource(Source[CreateMlModelRequest]): self.metadata = OpenMetadata(metadata_config) - self.client = MlflowClient( - tracking_uri=self.service_connection.trackingUri, - registry_uri=self.service_connection.registryUri, - ) + self.connection = get_connection(self.service_connection) + self.test_connection() + self.client = self.connection.client + self.status = MlFlowStatus() self.service = self.metadata.get_service_or_create( entity=MlModelService, config=config diff --git a/ingestion/src/metadata/utils/connection_clients.py b/ingestion/src/metadata/utils/connection_clients.py index bc902890563..7a2abed58a3 100644 --- a/ingestion/src/metadata/utils/connection_clients.py +++ b/ingestion/src/metadata/utils/connection_clients.py @@ -109,3 +109,9 @@ class AirByteClient: class ModeClient: def __init__(self, client) -> None: self.client = client + + +@dataclass +class MlflowClientWrapper: + def __init__(self, client) -> None: + self.client = client diff --git a/ingestion/src/metadata/utils/connections.py b/ingestion/src/metadata/utils/connections.py index 6346d3a81b3..46fc1526955 100644 --- a/ingestion/src/metadata/utils/connections.py +++ b/ingestion/src/metadata/utils/connections.py @@ -80,6 +80,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon from metadata.generated.schema.entity.services.connections.messaging.kafkaConnection import ( KafkaConnection, ) +from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import ( + MlflowConnection, +) from metadata.generated.schema.entity.services.connections.pipeline.airbyteConnection import ( AirbyteConnection, ) @@ -103,6 +106,7 @@ from metadata.utils.connection_clients import ( KafkaClient, LookerClient, MetabaseClient, + MlflowClientWrapper, ModeClient, PowerBiClient, RedashClient, @@ -754,6 +758,28 @@ def _(connection: ModeClient) -> None: ) +@get_connection.register +def _(connection: MlflowConnection, verbose: bool = False): + from mlflow.tracking import MlflowClient + + return MlflowClientWrapper( + MlflowClient( + tracking_uri=connection.trackingUri, + registry_uri=connection.registryUri, + ) + ) + + +@test_connection.register +def _(connection: MlflowClientWrapper) -> None: + try: + connection.client.list_registered_models() + except Exception as err: + raise SourceConnectionException( + f"Unknown error connecting with {connection} - {err}." + ) + + @get_connection.register def _(_: BackendConnection, verbose: bool = False): """ diff --git a/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py b/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py index 54a9f3994cd..2b8f78bb7b2 100644 --- a/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py +++ b/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py @@ -21,6 +21,7 @@ from airflow import DAG from metadata.generated.schema.entity.services.dashboardService import DashboardService from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.messagingService import MessagingService +from metadata.generated.schema.entity.services.mlmodelService import MlModelService from metadata.generated.schema.entity.services.pipelineService import PipelineService from metadata.generated.schema.type import basic from metadata.ingestion.models.encoders import show_secrets_encoder @@ -79,6 +80,10 @@ def build_source(ingestion_pipeline: IngestionPipeline) -> WorkflowSource: service: MessagingService = metadata.get_by_name( entity=MessagingService, fqn=ingestion_pipeline.service.name ) + elif service_type == "mlmodelService": + service: MlModelService = metadata.get_by_name( + entity=MlModelService, fqn=ingestion_pipeline.service.name + ) if not service: raise ValueError(f"Could not get service from type {service_type}")