mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-10-27 08:44:49 +00:00
* Fix mlflow from UI * format
This commit is contained in:
parent
7a131b115f
commit
89a026b022
@ -19,13 +19,16 @@
|
||||
},
|
||||
{
|
||||
"$ref": "../../../entity/services/pipelineService.json#/definitions/pipelineConnection"
|
||||
},
|
||||
{
|
||||
"$ref": "../../../entity/services/mlmodelService.json#/definitions/mlModelConnection"
|
||||
}
|
||||
]
|
||||
},
|
||||
"connectionType": {
|
||||
"description": "Type of database service such as MySQL, BigQuery, Snowflake, Redshift, Postgres...",
|
||||
"type": "string",
|
||||
"enum": ["Database", "Dashboard", "Messaging", "Pipeline"],
|
||||
"enum": ["Database", "Dashboard", "Messaging", "Pipeline", "MlModel"],
|
||||
"javaEnums": [
|
||||
{
|
||||
"name": "Database"
|
||||
@ -38,6 +41,9 @@
|
||||
},
|
||||
{
|
||||
"name": "Pipeline"
|
||||
},
|
||||
{
|
||||
"name": "MlModel"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -122,7 +122,7 @@ plugins: Dict[str, Set[str]] = {
|
||||
"webhook-server": {},
|
||||
"salesforce": {"simple_salesforce~=1.11.4"},
|
||||
"okta": {"okta~=2.3.0"},
|
||||
"mlflow": {"mlflow-skinny~=1.22.0"},
|
||||
"mlflow": {"mlflow-skinny~=1.26.1"},
|
||||
"sklearn": {"scikit-learn==1.0.2"},
|
||||
"db2": {"ibm-db-sa==0.3.7"},
|
||||
"clickhouse": {"clickhouse-driver==0.2.3", "clickhouse-sqlalchemy==0.2.0"},
|
||||
|
||||
@ -326,7 +326,6 @@ class CommonDbSourceService(
|
||||
# Disable the DictConfigurator.configure method while importing LineageRunner
|
||||
configure = DictConfigurator.configure
|
||||
DictConfigurator.configure = lambda _: None
|
||||
from sqllineage.exceptions import SQLLineageException
|
||||
from sqllineage.runner import LineageRunner
|
||||
|
||||
# Reverting changes after import is done
|
||||
|
||||
@ -40,6 +40,7 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.utils.connections import get_connection
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
@ -92,10 +93,10 @@ class MlflowSource(Source[CreateMlModelRequest]):
|
||||
|
||||
self.metadata = OpenMetadata(metadata_config)
|
||||
|
||||
self.client = MlflowClient(
|
||||
tracking_uri=self.service_connection.trackingUri,
|
||||
registry_uri=self.service_connection.registryUri,
|
||||
)
|
||||
self.connection = get_connection(self.service_connection)
|
||||
self.test_connection()
|
||||
self.client = self.connection.client
|
||||
|
||||
self.status = MlFlowStatus()
|
||||
self.service = self.metadata.get_service_or_create(
|
||||
entity=MlModelService, config=config
|
||||
|
||||
@ -109,3 +109,9 @@ class AirByteClient:
|
||||
class ModeClient:
|
||||
def __init__(self, client) -> None:
|
||||
self.client = client
|
||||
|
||||
|
||||
@dataclass
|
||||
class MlflowClientWrapper:
|
||||
def __init__(self, client) -> None:
|
||||
self.client = client
|
||||
|
||||
@ -80,6 +80,9 @@ from metadata.generated.schema.entity.services.connections.database.snowflakeCon
|
||||
from metadata.generated.schema.entity.services.connections.messaging.kafkaConnection import (
|
||||
KafkaConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.mlmodel.mlflowConnection import (
|
||||
MlflowConnection,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.pipeline.airbyteConnection import (
|
||||
AirbyteConnection,
|
||||
)
|
||||
@ -103,6 +106,7 @@ from metadata.utils.connection_clients import (
|
||||
KafkaClient,
|
||||
LookerClient,
|
||||
MetabaseClient,
|
||||
MlflowClientWrapper,
|
||||
ModeClient,
|
||||
PowerBiClient,
|
||||
RedashClient,
|
||||
@ -754,6 +758,28 @@ def _(connection: ModeClient) -> None:
|
||||
)
|
||||
|
||||
|
||||
@get_connection.register
|
||||
def _(connection: MlflowConnection, verbose: bool = False):
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
return MlflowClientWrapper(
|
||||
MlflowClient(
|
||||
tracking_uri=connection.trackingUri,
|
||||
registry_uri=connection.registryUri,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@test_connection.register
|
||||
def _(connection: MlflowClientWrapper) -> None:
|
||||
try:
|
||||
connection.client.list_registered_models()
|
||||
except Exception as err:
|
||||
raise SourceConnectionException(
|
||||
f"Unknown error connecting with {connection} - {err}."
|
||||
)
|
||||
|
||||
|
||||
@get_connection.register
|
||||
def _(_: BackendConnection, verbose: bool = False):
|
||||
"""
|
||||
|
||||
@ -21,6 +21,7 @@ from airflow import DAG
|
||||
from metadata.generated.schema.entity.services.dashboardService import DashboardService
|
||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||
from metadata.generated.schema.entity.services.messagingService import MessagingService
|
||||
from metadata.generated.schema.entity.services.mlmodelService import MlModelService
|
||||
from metadata.generated.schema.entity.services.pipelineService import PipelineService
|
||||
from metadata.generated.schema.type import basic
|
||||
from metadata.ingestion.models.encoders import show_secrets_encoder
|
||||
@ -79,6 +80,10 @@ def build_source(ingestion_pipeline: IngestionPipeline) -> WorkflowSource:
|
||||
service: MessagingService = metadata.get_by_name(
|
||||
entity=MessagingService, fqn=ingestion_pipeline.service.name
|
||||
)
|
||||
elif service_type == "mlmodelService":
|
||||
service: MlModelService = metadata.get_by_name(
|
||||
entity=MlModelService, fqn=ingestion_pipeline.service.name
|
||||
)
|
||||
|
||||
if not service:
|
||||
raise ValueError(f"Could not get service from type {service_type}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user