mirror of
https://github.com/datahub-project/datahub.git
synced 2025-07-04 23:57:03 +00:00
446 lines
13 KiB
Python
446 lines
13 KiB
Python
import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Union
|
|
|
|
import pytest
|
|
from mlflow import MlflowClient
|
|
from mlflow.entities import (
|
|
Dataset as MLflowDataset,
|
|
DatasetInput,
|
|
InputTag,
|
|
Run,
|
|
RunData,
|
|
RunInfo,
|
|
RunInputs,
|
|
)
|
|
from mlflow.entities.model_registry import RegisteredModel
|
|
from mlflow.entities.model_registry.model_version import ModelVersion
|
|
from mlflow.store.entities import PagedList
|
|
|
|
from datahub.ingestion.api.common import PipelineContext
|
|
from datahub.ingestion.source.mlflow import MLflowConfig, MLflowSource
|
|
from datahub.metadata.schema_classes import UpstreamLineageClass
|
|
|
|
|
|
@pytest.fixture
def tracking_uri(tmp_path: Path) -> str:
    """Return, as a string, the ``mlruns`` path located under ``tmp_path``."""
    mlruns_dir = tmp_path.joinpath("mlruns")
    return str(mlruns_dir)
|
|
|
|
|
|
@pytest.fixture
def source(tracking_uri: str) -> MLflowSource:
    """Build an ``MLflowSource`` whose config points at the ``tracking_uri`` fixture."""
    pipeline_ctx = PipelineContext(run_id="mlflow-source-test")
    mlflow_config = MLflowConfig(tracking_uri=tracking_uri)
    return MLflowSource(ctx=pipeline_ctx, config=mlflow_config)
|
|
|
|
|
|
@pytest.fixture
def registered_model(source: MLflowSource) -> RegisteredModel:
    """Provide a ``RegisteredModel`` entity named ``"abc"``."""
    # ``source`` is requested but not read here; it is kept so the fixture
    # dependency chain stays exactly as it was.
    return RegisteredModel(name="abc")
|
|
|
|
|
|
@pytest.fixture
def model_version(
    source: MLflowSource,
    registered_model: RegisteredModel,
) -> ModelVersion:
    """Provide version ``"1"`` of the ``registered_model`` fixture.

    ``creation_timestamp`` is passed as integer milliseconds since the Unix
    epoch, which is the unit MLflow documents for this field, rather than as
    a ``datetime`` object.
    """
    # ``source`` is requested but not read here; kept so fixture wiring is
    # unchanged.
    created_at_ms = int(datetime.datetime.now().timestamp() * 1000)
    return ModelVersion(
        name=registered_model.name,
        version="1",
        creation_timestamp=created_at_ms,
    )
|
|
|
|
|
|
@pytest.fixture
def mlflow_dataset() -> MLflowDataset:
    """Provide an MLflow dataset entity whose ``source_type`` is ``"snowflake"``."""
    dataset_fields = {
        "name": "test_dataset",
        "digest": "test_digest",
        "source": "test_source",
        "source_type": "snowflake",
        "schema": '{"mlflow_colspec":[{"name":"col1", "type":"string"}]}',
    }
    return MLflowDataset(**dataset_fields)
|
|
|
|
|
|
@pytest.fixture
def mlflow_local_dataset() -> MLflowDataset:
    """Provide an MLflow dataset entity whose ``source_type`` is ``"local"``."""
    dataset_fields = {
        "name": "test_local_dataset",
        "digest": "test_digest",
        "source": "test_source",
        "source_type": "local",
        "schema": '{"mlflow_colspec":[{"name":"col1", "type":"string"}]}',
    }
    return MLflowDataset(**dataset_fields)
|
|
|
|
|
|
@pytest.fixture
def mlflow_unsupported_dataset() -> MLflowDataset:
    """Provide an MLflow dataset entity whose ``source_type`` is ``"unsupported_platform"``."""
    dataset_fields = {
        "name": "test_unsupported_dataset",
        "digest": "test_digest",
        "source": "test_source",
        "source_type": "unsupported_platform",
        "schema": '{"mlflow_colspec":[{"name":"col1", "type":"string"}]}',
    }
    return MLflowDataset(**dataset_fields)
|
|
|
|
|
|
def dummy_search_func(page_token: Union[None, str], **kwargs: Any) -> PagedList[str]:
    """Fake paginated search function backed by three fixed pages.

    The pages hold ``["a", "b"]``, ``["c", "d"]`` and ``["e"]`` and are chained
    through their tokens (``page_2``, ``page_3``, then ``None``).

    Args:
        page_token: Key of the page to return; ``None`` selects the first page.
            A token that matches no page raises ``KeyError``.
        **kwargs: Only ``case`` is inspected. When it equals ``"upper"`` the
            returned page carries upper-cased items and the same token.

    Returns:
        The selected page, or an upper-cased copy of it.
    """
    pages = {
        "page_1": PagedList(items=["a", "b"], token="page_2"),
        "page_2": PagedList(items=["c", "d"], token="page_3"),
        "page_3": PagedList(items=["e"], token=None),
    }
    selected = pages["page_1" if page_token is None else page_token]

    if kwargs.get("case", "") != "upper":
        return selected

    return PagedList(
        items=[item.upper() for item in selected.to_list()],
        token=selected.token,
    )
|
|
|
|
|
|
def test_stages(source):
    """Check that the tag workunits carry one ``mlflow_<stage>`` name per stage.

    Stage names are lower-cased; the ``None`` stage is expected as
    ``mlflow_none``.
    """
    expected_stages = {"Production", "Staging", "Archived", None}
    expected_tag_names = {f"mlflow_{str(stage).lower()}" for stage in expected_stages}

    tag_names = [wu.metadata.aspect.name for wu in source._get_tags_workunits()]

    # The length check guards against duplicate names that the set comparison
    # alone would hide.
    assert len(tag_names) == len(expected_stages)
    assert set(tag_names) == expected_tag_names
|
|
|
|
|
|
def test_config_model_name_separator(source, model_version):
    """Check that the configured separator joins model name and version in the URN."""
    separator = "+"
    source.config.model_name_separator = separator

    qualified_name = f"{model_version.name}{separator}{model_version.version}"
    expected_urn = (
        "urn:li:mlModel:(urn:li:dataPlatform:mlflow,"
        f"{qualified_name},{source.config.env})"
    )

    actual_urn = source._make_ml_model_urn(model_version)

    assert actual_urn == expected_urn
|
|
|
|
|
|
def test_model_without_run(source, registered_model, model_version):
    """Check that the model-properties aspect has no hyper-params or training metrics.

    The ``model_version`` fixture is built without a run id, so whatever
    ``_get_mlflow_run`` returns for it is passed straight through.
    """
    properties_wu = source._get_ml_model_properties_workunit(
        registered_model=registered_model,
        model_version=model_version,
        run=source._get_mlflow_run(model_version),
    )
    properties = properties_wu.metadata.aspect

    assert properties.hyperParams is None
    assert properties.trainingMetrics is None
|
|
|
|
|
|
def test_traverse_mlflow_search_func(source):
    """Check that traversal yields the items of every dummy page, in order."""
    traversed = list(source._traverse_mlflow_search_func(dummy_search_func))

    assert traversed == ["a", "b", "c", "d", "e"]
|
|
|
|
|
|
def test_traverse_mlflow_search_func_with_kwargs(source):
    """Check that extra keyword arguments reach the search function on every page."""
    traversed = list(
        source._traverse_mlflow_search_func(dummy_search_func, case="upper")
    )

    assert traversed == ["A", "B", "C", "D", "E"]
|
|
|
|
|
|
def test_make_external_link_local(source, model_version):
    """Check that no external URL is produced with the default (path-based) tracking URI."""
    external_url = source._make_external_url(model_version)

    assert external_url is None
|
|
|
|
|
|
def test_make_external_link_remote(source, model_version):
    """Check that the external URL is rooted at the client's remote tracking URI."""
    remote_uri = "https://dummy-mlflow-tracking-server.org"
    source.client = MlflowClient(tracking_uri=remote_uri)

    version_path = f"/#/models/{model_version.name}/versions/{model_version.version}"
    expected_url = remote_uri + version_path

    actual_url = source._make_external_url(model_version)

    assert actual_url == expected_url
|
|
|
|
|
|
def test_make_external_link_remote_via_config(source, model_version):
    """Check that ``base_external_url`` from config wins over the client's tracking URI."""
    configured_base = "https://custom-server.org"
    source.config.base_external_url = configured_base
    # The client points at a different host, so a match on ``configured_base``
    # can only come from the config value.
    source.client = MlflowClient(
        tracking_uri="https://dummy-mlflow-tracking-server.org"
    )

    version_path = f"/#/models/{model_version.name}/versions/{model_version.version}"
    expected_url = configured_base + version_path

    actual_url = source._make_external_url(model_version)

    assert actual_url == expected_url
|
|
|
|
|
|
def test_local_dataset_reference_creation(source, mlflow_local_dataset):
    """A run input with a ``local`` source type must yield an mlflow-platform dataset."""
    run_info = RunInfo(
        run_id="test_run",
        run_uuid="test_uuid",
        experiment_id="test_exp",
        user_id="test_user",
        status="FINISHED",
        start_time=1234567890,
        end_time=1234567891,
        artifact_uri="s3://test-bucket/test",
        lifecycle_stage="active",
    )
    run_data = RunData(metrics={}, params={}, tags=[])
    dataset_input = DatasetInput(
        dataset=mlflow_local_dataset,
        tags=[InputTag(key="key", value="value")],
    )
    run = Run(run_info, run_data, RunInputs(dataset_inputs=[dataset_input]))

    emitted = list(source._get_dataset_input_workunits(run))

    # A local dataset reference is expected among the emitted workunits.
    assert len(emitted) > 0
    mlflow_urn_prefix = "urn:li:dataset:(urn:li:dataPlatform:mlflow"
    assert any(mlflow_urn_prefix in wu.id for wu in emitted)
|
|
|
|
|
|
def test_materialization_disabled_with_supported_platform(source, mlflow_dataset):
    """
    With materialize_dataset_inputs=False and a supported platform, expect:
    - a dataset reference to be created
    - no hosted dataset to be created
    - upstream lineage to be attached when the upstream dataset exists
    """
    source.config.materialize_dataset_inputs = False

    class ExistingDatasetGraph:
        """Graph stub whose aspect lookup always returns a non-None value."""

        def get_aspect_v2(self, *args, **kwargs):
            # A non-None result simulates the dataset already existing.
            # NOTE(review): ``{}`` is non-None but falsy; presumably the source
            # compares against None rather than testing truthiness - confirm.
            return {}

    source.ctx.graph = ExistingDatasetGraph()

    run_info = RunInfo(
        run_id="test_run",
        run_uuid="test_uuid",
        experiment_id="test_exp",
        user_id="test_user",
        status="FINISHED",
        start_time=1234567890,
        end_time=1234567891,
        artifact_uri="s3://test-bucket/test",
        lifecycle_stage="active",
    )
    run_data = RunData(metrics={}, params={}, tags=[])
    dataset_input = DatasetInput(
        dataset=mlflow_dataset,
        tags=[InputTag(key="key", value="value")],
    )
    run = Run(run_info, run_data, RunInputs(dataset_inputs=[dataset_input]))

    emitted = list(source._get_dataset_input_workunits(run))

    # A dataset reference must be produced.
    assert len(emitted) > 0
    # At least one workunit must carry upstream lineage.
    assert any(isinstance(wu.metadata.aspect, UpstreamLineageClass) for wu in emitted)
    # No hosted dataset: none of the workunit ids may be a snowflake dataset URN.
    snowflake_urn_prefix = "urn:li:dataset:(urn:li:dataPlatform:snowflake"
    assert all(snowflake_urn_prefix not in wu.id for wu in emitted)
|
|
|
|
|
|
def test_materialization_disabled_with_unsupported_platform(
    source, mlflow_unsupported_dataset
):
    """
    With materialize_dataset_inputs=False and an unsupported platform, expect:
    - a dataset reference to be created
    - no hosted dataset to be created
    - no upstream lineage to be attached
    """
    source.config.materialize_dataset_inputs = False

    run_info = RunInfo(
        run_id="test_run",
        run_uuid="test_uuid",
        experiment_id="test_exp",
        user_id="test_user",
        status="FINISHED",
        start_time=1234567890,
        end_time=1234567891,
        artifact_uri="s3://test-bucket/test",
        lifecycle_stage="active",
    )
    run_data = RunData(metrics={}, params={}, tags=[])
    dataset_input = DatasetInput(
        dataset=mlflow_unsupported_dataset,
        tags=[InputTag(key="key", value="value")],
    )
    run = Run(run_info, run_data, RunInputs(dataset_inputs=[dataset_input]))

    emitted = list(source._get_dataset_input_workunits(run))

    # A dataset reference must be produced.
    assert len(emitted) > 0
    # None of the workunits may carry upstream lineage.
    assert all(
        not isinstance(wu.metadata.aspect, UpstreamLineageClass) for wu in emitted
    )
|
|
|
|
|
|
def test_materialization_enabled_with_supported_platform(source, mlflow_dataset):
    """
    With materialize_dataset_inputs=True and a supported platform, expect:
    - a dataset reference to be created
    - a hosted dataset to be created
    """
    source.config.materialize_dataset_inputs = True
    source.config.source_mapping_to_platform = {"snowflake": "snowflake"}

    run_info = RunInfo(
        run_id="test_run",
        run_uuid="test_uuid",
        experiment_id="test_exp",
        user_id="test_user",
        status="FINISHED",
        start_time=1234567890,
        end_time=1234567891,
        artifact_uri="s3://test-bucket/test",
        lifecycle_stage="active",
    )
    run_data = RunData(metrics={}, params={}, tags=[])
    dataset_input = DatasetInput(
        dataset=mlflow_dataset,
        tags=[InputTag(key="key", value="value")],
    )
    run = Run(run_info, run_data, RunInputs(dataset_inputs=[dataset_input]))

    emitted = list(source._get_dataset_input_workunits(run))

    # Workunits must be produced, including one for a snowflake-hosted dataset.
    assert len(emitted) > 0
    snowflake_urn_prefix = "urn:li:dataset:(urn:li:dataPlatform:snowflake"
    assert any(snowflake_urn_prefix in wu.id for wu in emitted)
|
|
|
|
|
|
def test_materialization_enabled_with_unsupported_platform(
    source, mlflow_unsupported_dataset
):
    """
    With materialize_dataset_inputs=True and no platform mapping for the
    dataset's source type, expect:
    - no workunits to be produced
    - a failure about the missing platform mapping to be reported
    """
    source.config.materialize_dataset_inputs = True
    source.config.source_mapping_to_platform = {}

    run_info = RunInfo(
        run_id="test_run",
        run_uuid="test_uuid",
        experiment_id="test_exp",
        user_id="test_user",
        status="FINISHED",
        start_time=1234567890,
        end_time=1234567891,
        artifact_uri="s3://test-bucket/test",
        lifecycle_stage="active",
    )
    run_data = RunData(metrics={}, params={}, tags=[])
    dataset_input = DatasetInput(
        dataset=mlflow_unsupported_dataset,
        tags=[InputTag(key="key", value="value")],
    )
    run = Run(run_info, run_data, RunInputs(dataset_inputs=[dataset_input]))

    emitted = list(source._get_dataset_input_workunits(run))

    # The error path must not emit anything.
    assert len(emitted) == 0
    # The report must contain a failure naming the missing platform mapping.
    failure_messages = (failure.message for failure in source.report.failures)
    assert any(
        "No mapping dataPlatform found" in message for message in failure_messages
    )
|
|
|
|
|
|
def test_materialization_enabled_with_custom_mapping(
    source, mlflow_unsupported_dataset
):
    """
    With materialize_dataset_inputs=True and a custom platform mapping, expect:
    - a dataset reference to be created
    - a hosted dataset to be created on the mapped platform
    """
    source.config.materialize_dataset_inputs = True
    source.config.source_mapping_to_platform = {"unsupported_platform": "snowflake"}

    run_info = RunInfo(
        run_id="test_run",
        run_uuid="test_uuid",
        experiment_id="test_exp",
        user_id="test_user",
        status="FINISHED",
        start_time=1234567890,
        end_time=1234567891,
        artifact_uri="s3://test-bucket/test",
        lifecycle_stage="active",
    )
    run_data = RunData(metrics={}, params={}, tags=[])
    dataset_input = DatasetInput(
        dataset=mlflow_unsupported_dataset,
        tags=[InputTag(key="key", value="value")],
    )
    run = Run(run_info, run_data, RunInputs(dataset_inputs=[dataset_input]))

    emitted = list(source._get_dataset_input_workunits(run))

    # Workunits must be produced, including one for the mapped (snowflake) platform.
    assert len(emitted) > 0
    snowflake_urn_prefix = "urn:li:dataset:(urn:li:dataPlatform:snowflake"
    assert any(snowflake_urn_prefix in wu.id for wu in emitted)
|