# Mirror of https://github.com/datahub-project/datahub.git
# (synced 2025-07-07 17:23:11 +00:00)
# 116 lines, 2.9 KiB, Python
import uuid
|
|
from pathlib import Path
|
|
from typing import Any, Dict, TypeVar
|
|
|
|
import pytest
|
|
from mlflow import MlflowClient
|
|
|
|
from datahub.ingestion.run.pipeline import Pipeline
|
|
from tests.test_helpers import mce_helpers
|
|
|
|
# NOTE(review): T is not referenced by any code in this file as seen here; it
# looks like a leftover and may be a candidate for removal — confirm first.
T = TypeVar("T")
|
|
|
|
|
|
@pytest.fixture
def tracking_uri(tmp_path: Path) -> str:
    """Return the MLflow tracking URI for this test.

    The URI is a plain filesystem path, ``mlruns`` inside pytest's per-test
    temporary directory. It is handed to ``MlflowClient`` and to the
    ``mlflow`` source config.
    """
    mlruns_dir = tmp_path.joinpath("mlruns")
    return str(mlruns_dir)
|
|
|
|
|
|
@pytest.fixture
def sink_file_path(tmp_path: Path) -> str:
    """Return the path of the JSON file that the pipeline's ``file`` sink writes.

    The file lives inside pytest's per-test temporary directory.
    """
    output_file = tmp_path.joinpath("mlflow_source_mcps.json")
    return str(output_file)
|
|
|
|
|
|
@pytest.fixture
def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]:
    """Build the ingestion recipe: an ``mlflow`` source feeding a ``file`` sink.

    Args:
        tracking_uri: Location of the MLflow store to read from.
        sink_file_path: Output file that the file sink writes to.

    Returns:
        A recipe dict suitable for ``Pipeline.create``.
    """
    source = {"type": "mlflow", "config": {"tracking_uri": tracking_uri}}
    sink = {"type": "file", "config": {"filename": sink_file_path}}
    return {"run_id": "mlflow-source-test", "source": source, "sink": sink}
|
|
|
|
|
|
@pytest.fixture
def generate_mlflow_data(tracking_uri: str, monkeypatch: pytest.MonkeyPatch) -> None:
    """Seed the MLflow store at ``tracking_uri`` with the objects under test.

    The fixture creates:
      * one experiment, "test-experiment", with one run, "test-run";
      * on that run, a parameter ``p`` and a metric ``m``;
      * a registered model, "test-model", with tags and a description;
      * version "1" of that model, linked to the run and then moved to the
        "Archived" stage.

    ``uuid.uuid4`` is patched to return a constant UUID for the rest of the
    test. Presumably this makes IDs that MLflow generates through ``uuid4``
    (e.g. run ids) deterministic so the output matches the golden file —
    confirm against the MLflow store implementation.
    """
    fixed_uuid_hex = "02660a3bee9941ed983667f678ce5611"
    # The patch must be applied before any MLflow object is created.
    monkeypatch.setattr(uuid, "uuid4", lambda: uuid.UUID(fixed_uuid_hex))

    client = MlflowClient(tracking_uri=tracking_uri)
    model_name = "test-model"

    experiment_id = client.create_experiment("test-experiment")
    run_id = client.create_run(
        experiment_id=experiment_id,
        run_name="test-run",
    ).info.run_id

    client.log_param(run_id=run_id, key="p", value=1)
    client.log_metric(run_id=run_id, key="m", value=0.85)

    client.create_registered_model(
        name=model_name,
        tags={"model_id": 1, "model_env": "test"},
        description="This a test registered model",
    )
    client.create_model_version(
        name=model_name,
        source="dummy_dir/dummy_file",
        run_id=run_id,
        tags={"model_version_id": 1},
    )
    client.transition_model_version_stage(
        name=model_name,
        version="1",
        stage="Archived",
    )
|
|
|
|
|
|
def test_ingestion(
    pytestconfig,
    mock_time,
    sink_file_path,
    pipeline_config,
    generate_mlflow_data,
):
    """Run the MLflow source end to end and compare its output to a golden file.

    ``mock_time`` and ``generate_mlflow_data`` are requested only for their
    side effects: ``generate_mlflow_data`` seeds the MLflow store, and
    ``mock_time`` is defined elsewhere — presumably it freezes the clock.

    The comparison ignores two kinds of fields:
      * ``customProperties.artifacts_location`` in aspect JSON;
      * ``outputUrls`` in aspect JSON.
    Both can differ between environments (likely because they embed the
    temporary path).
    """
    print(f"MCPs file path: {sink_file_path}")

    golden = pytestconfig.rootpath.joinpath(
        "tests/integration/mlflow/mlflow_mcps_golden.json"
    )
    volatile_paths = [
        r"root\[\d+\]\['aspect'\]\['json'\]\['customProperties'\]\['artifacts_location'\]",
        r"root\[\d+\]\['aspect'\]\['json'\]\['outputUrls'\]",
    ]

    ingestion = Pipeline.create(pipeline_config)
    ingestion.run()
    ingestion.pretty_print_summary()
    # Fail the test here if the pipeline run reported errors.
    ingestion.raise_from_status()

    mce_helpers.check_golden_file(
        pytestconfig=pytestconfig,
        output_path=sink_file_path,
        golden_path=golden,
        ignore_paths=volatile_paths,
    )
|