Fix #5688 - MlFlow deployment fails from UI (#5691)

* Fix mlflow from UI

* format
This commit is contained in:
Pere Miquel Brull 2022-06-28 14:58:38 +02:00 committed by GitHub
parent 7a131b115f
commit 89a026b022
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 7 deletions

View File

@ -19,13 +19,16 @@
},
{
"$ref": "../../../entity/services/pipelineService.json#/definitions/pipelineConnection"
},
{
"$ref": "../../../entity/services/mlmodelService.json#/definitions/mlModelConnection"
}
]
},
"connectionType": {
"description": "Type of service the connection targets, such as Database, Dashboard, Messaging, Pipeline or MlModel.",
"type": "string",
"enum": ["Database", "Dashboard", "Messaging", "Pipeline"],
"enum": ["Database", "Dashboard", "Messaging", "Pipeline", "MlModel"],
"javaEnums": [
{
"name": "Database"
@ -38,6 +41,9 @@
},
{
"name": "Pipeline"
},
{
"name": "MlModel"
}
]
}

View File

@ -122,7 +122,7 @@ plugins: Dict[str, Set[str]] = {
"webhook-server": {},
"salesforce": {"simple_salesforce~=1.11.4"},
"okta": {"okta~=2.3.0"},
"mlflow": {"mlflow-skinny~=1.22.0"},
"mlflow": {"mlflow-skinny~=1.26.1"},
"sklearn": {"scikit-learn==1.0.2"},
"db2": {"ibm-db-sa==0.3.7"},
"clickhouse": {"clickhouse-driver==0.2.3", "clickhouse-sqlalchemy==0.2.0"},

View File

@ -326,7 +326,6 @@ class CommonDbSourceService(
# Disable the DictConfigurator.configure method while importing LineageRunner
configure = DictConfigurator.configure
DictConfigurator.configure = lambda _: None
from sqllineage.exceptions import SQLLineageException
from sqllineage.runner import LineageRunner
# Reverting changes after import is done

View File

@ -40,6 +40,7 @@ from metadata.generated.schema.metadataIngestion.workflow import (
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.connections import get_connection
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
@ -92,10 +93,10 @@ class MlflowSource(Source[CreateMlModelRequest]):
self.metadata = OpenMetadata(metadata_config)
self.client = MlflowClient(
tracking_uri=self.service_connection.trackingUri,
registry_uri=self.service_connection.registryUri,
)
self.connection = get_connection(self.service_connection)
self.test_connection()
self.client = self.connection.client
self.status = MlFlowStatus()
self.service = self.metadata.get_service_or_create(
entity=MlModelService, config=config

View File

@ -109,3 +109,9 @@ class AirByteClient:
class ModeClient:
    """
    Plain holder that stores the client object it is given.
    """

    def __init__(self, client) -> None:
        # Expose the received object unchanged under `.client`.
        self.client = client
@dataclass
class MlflowClientWrapper:
    """
    Plain holder that stores the client object it is given.

    Having a dedicated wrapper type lets single-dispatch functions
    register a handler for it.
    """

    # NOTE: no annotated fields are declared, so the dataclass-generated
    # __repr__ and __eq__ do not take `client` into account.
    def __init__(self, client) -> None:
        # Expose the received object unchanged under `.client`.
        self.client = client

View File

@ -80,6 +80,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon
from metadata.generated.schema.entity.services.connections.messaging.kafkaConnection import (
KafkaConnection,
)
from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import (
MlflowConnection,
)
from metadata.generated.schema.entity.services.connections.pipeline.airbyteConnection import (
AirbyteConnection,
)
@ -103,6 +106,7 @@ from metadata.utils.connection_clients import (
KafkaClient,
LookerClient,
MetabaseClient,
MlflowClientWrapper,
ModeClient,
PowerBiClient,
RedashClient,
@ -754,6 +758,28 @@ def _(connection: ModeClient) -> None:
)
@get_connection.register
def _(connection: MlflowConnection, verbose: bool = False):
    """
    Build an MLflow client from the connection settings and wrap it.

    :param connection: provides `trackingUri` and `registryUri`
    :param verbose: accepted for signature parity; not used here
    :return: MlflowClientWrapper holding the created MlflowClient
    """
    # Function-scope import: presumably keeps this module importable when the
    # optional mlflow dependency is not installed - TODO confirm.
    from mlflow.tracking import MlflowClient

    mlflow_client = MlflowClient(
        tracking_uri=connection.trackingUri,
        registry_uri=connection.registryUri,
    )
    return MlflowClientWrapper(mlflow_client)
@test_connection.register
def _(connection: MlflowClientWrapper) -> None:
    """
    Check that the wrapped client can be used by issuing a probe call.

    Listing the registered models is the probe; its result is discarded.

    :param connection: wrapper exposing the client under `.client`
    :raises SourceConnectionException: if the probe call raises any exception
    """
    try:
        connection.client.list_registered_models()
    except Exception as err:
        # Chain the caught error so the root-cause traceback is not lost.
        raise SourceConnectionException(
            f"Unknown error connecting with {connection} - {err}."
        ) from err
@get_connection.register
def _(_: BackendConnection, verbose: bool = False):
"""

View File

@ -21,6 +21,7 @@ from airflow import DAG
from metadata.generated.schema.entity.services.dashboardService import DashboardService
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.messagingService import MessagingService
from metadata.generated.schema.entity.services.mlmodelService import MlModelService
from metadata.generated.schema.entity.services.pipelineService import PipelineService
from metadata.generated.schema.type import basic
from metadata.ingestion.models.encoders import show_secrets_encoder
@ -79,6 +80,10 @@ def build_source(ingestion_pipeline: IngestionPipeline) -> WorkflowSource:
service: MessagingService = metadata.get_by_name(
entity=MessagingService, fqn=ingestion_pipeline.service.name
)
elif service_type == "mlmodelService":
service: MlModelService = metadata.get_by_name(
entity=MlModelService, fqn=ingestion_pipeline.service.name
)
if not service:
raise ValueError(f"Could not get service from type {service_type}")