116 lines
2.9 KiB
Python
Raw Permalink Normal View History

import uuid
from pathlib import Path
from typing import Any, Dict, TypeVar
import pytest
from mlflow import MlflowClient
from datahub.ingestion.run.pipeline import Pipeline
from tests.test_helpers import mce_helpers
T = TypeVar("T")
@pytest.fixture
def tracking_uri(tmp_path: Path) -> str:
return str(tmp_path / "mlruns")
@pytest.fixture
def sink_file_path(tmp_path: Path) -> str:
return str(tmp_path / "mlflow_source_mcps.json")
@pytest.fixture
def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]:
source_type = "mlflow"
return {
"run_id": "mlflow-source-test",
"source": {
"type": source_type,
"config": {
"tracking_uri": tracking_uri,
},
},
"sink": {
"type": "file",
"config": {
"filename": sink_file_path,
},
},
}
@pytest.fixture
def generate_mlflow_data(tracking_uri: str, monkeypatch: pytest.MonkeyPatch) -> None:
test_uuid = "02660a3bee9941ed983667f678ce5611"
monkeypatch.setattr(uuid, "uuid4", lambda: uuid.UUID(test_uuid))
client = MlflowClient(tracking_uri=tracking_uri)
experiment_name = "test-experiment"
run_name = "test-run"
model_name = "test-model"
test_experiment_id = client.create_experiment(experiment_name)
test_run = client.create_run(
experiment_id=test_experiment_id,
run_name=run_name,
)
client.log_param(
run_id=test_run.info.run_id,
key="p",
value=1,
)
client.log_metric(
run_id=test_run.info.run_id,
key="m",
value=0.85,
)
client.create_registered_model(
name=model_name,
tags=dict(
model_id=1,
model_env="test",
),
description="This a test registered model",
)
client.create_model_version(
name=model_name,
source="dummy_dir/dummy_file",
run_id=test_run.info.run_id,
tags=dict(model_version_id=1),
)
client.transition_model_version_stage(
name=model_name,
version="1",
stage="Archived",
)
def test_ingestion(
pytestconfig,
mock_time,
sink_file_path,
pipeline_config,
generate_mlflow_data,
):
print(f"MCPs file path: {sink_file_path}")
golden_file_path = (
pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json"
)
ignore_paths = [
r"root\[\d+\]\['aspect'\]\['json'\]\['customProperties'\]\['artifacts_location'\]",
r"root\[\d+\]\['aspect'\]\['json'\]\['outputUrls'\]",
]
pipeline = Pipeline.create(pipeline_config)
pipeline.run()
pipeline.pretty_print_summary()
pipeline.raise_from_status()
mce_helpers.check_golden_file(
pytestconfig=pytestconfig,
output_path=sink_file_path,
golden_path=golden_file_path,
ignore_paths=ignore_paths,
)