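"""Tests for the MLModel entity in the DataHub Python SDK."""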

from __future__ import annotations

import re
from datetime import datetime, timezone
from unittest import mock

import pytest

from datahub.errors import ItemNotFoundError
from datahub.metadata.schema_classes import (
    MLHyperParamClass,
    MLMetricClass,
)
from datahub.metadata.urns import (
    DataPlatformUrn,
    DataProcessInstanceUrn,
    MlModelGroupUrn,
    MlModelUrn,
)
from datahub.sdk.mlmodel import MLModel
from datahub.utilities.urns.error import InvalidUrnError


def test_mlmodel() -> None:
    """Test MLModel functionality with all essential features."""
    # Create model with basic properties
    model = MLModel(
        id="test_model",
        platform="mlflow",
        name="test_model",
    )

    # Test basic properties
    assert model.urn == MlModelUrn("mlflow", "test_model")
    assert model.name == "test_model"
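    # Note: DataPlatformUrn appears to accept and normalize the
    # fully-qualified urn string here; constructing it from the bare
    # name "mlflow" should compare equal as well.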
    assert model.platform == DataPlatformUrn("urn:li:dataPlatform:mlflow")

    # Test description and URL
    model.set_description("A test model")
    assert model.description == "A test model"
    model.set_external_url("https://example.com/model")
    assert model.external_url == "https://example.com/model"

    # Test dates
    test_date = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    model.set_created(test_date)
    model.set_last_modified(test_date)
    assert model.created == test_date
    assert model.last_modified == test_date

    # Test version and aliases
    model.set_version("1.0.0")
    assert model.version == "1.0.0"
    model.add_version_alias("stable")
    model.add_version_alias("production")
    aliases = model.version_aliases
    assert aliases is not None
    assert "stable" in aliases
    assert "production" in aliases
    model.remove_version_alias("stable")
    assert "stable" not in model.version_aliases

    # Test metrics - both individual and bulk operations
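    # Metric and hyperparameter values are strings in the underlying
    # schema classes, hence "0.92" rather than 0.92 below.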
    model.set_training_metrics(
        {
            "precision": "0.92",
            "recall": "0.88",
        }
    )
    model.add_training_metrics([MLMetricClass(name="f1_score", value="0.90")])
    model.add_training_metrics([MLMetricClass(name="accuracy", value="0.95")])
    metrics = model.training_metrics
    assert metrics is not None
    assert len(metrics) == 4
    metric_values = {
        m.name: m.value for m in metrics if hasattr(m, "name") and hasattr(m, "value")
    }
    assert metric_values["precision"] == "0.92"
    assert metric_values["accuracy"] == "0.95"

    # Test hyperparameters
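    # add_hyper_params accepts either a plain dict or a list of
    # MLHyperParamClass instances, as exercised below.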
    model.set_hyper_params({"learning_rate": "0.001", "num_layers": "3"})
    model.add_hyper_params({"batch_size": "32"})
    model.add_hyper_params([MLHyperParamClass(name="num_epochs", value="10")])
    params = model.hyper_params
    assert params is not None
    assert len(params) == 4
    param_values = {
        p.name: p.value for p in params if hasattr(p, "name") and hasattr(p, "value")
    }
    assert param_values["learning_rate"] == "0.001"
    assert param_values["num_epochs"] == "10"

    # Test custom properties
    model.set_custom_properties(
        {
            "framework": "pytorch",
            "task": "classification",
        }
    )
    assert model.custom_properties == {
        "framework": "pytorch",
        "task": "classification",
    }

    # Test relationships
    # Model group
    group_urn = MlModelGroupUrn("mlflow", "test_group")
    model.set_model_group(group_urn)
    assert model.model_group is not None
    assert model.model_group == str(group_urn)
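    # Note: model_group appears to be exposed as the stringified urn
    # rather than an MlModelGroupUrn instance.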

    # Training and downstream jobs
    job1 = DataProcessInstanceUrn("job1")
    job2 = DataProcessInstanceUrn("job2")

    # Add and remove jobs
    model.add_training_job(job1)
    assert model.training_jobs is not None
    assert str(job1) in model.training_jobs
    model.remove_training_job(job1)
    assert model.training_jobs is not None
    assert len(model.training_jobs) == 0

    # Test bulk job operations
    model.set_training_jobs([job1, job2])
    model.set_downstream_jobs([job1, job2])
    assert model.training_jobs is not None
    assert model.downstream_jobs is not None
    assert len(model.training_jobs) == 2
    assert len(model.downstream_jobs) == 2


def test_mlmodel_complex_initialization() -> None:
    """Test MLModel with initialization of all properties at once."""
    test_date = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    model = MLModel(
        id="complex_model",
        platform="mlflow",
        name="Complex Model",
        description="A model with all properties",
        external_url="https://example.com/model",
        version="2.0.0",
        created=test_date,
        last_modified=test_date,
        training_metrics=[
            MLMetricClass(name="accuracy", value="0.95"),
            MLMetricClass(name="loss", value="0.1"),
        ],
        hyper_params=[
            MLHyperParamClass(name="learning_rate", value="0.001"),
            MLHyperParamClass(name="batch_size", value="32"),
        ],
        custom_properties={
            "framework": "pytorch",
            "task": "classification",
        },
        model_group=MlModelGroupUrn("mlflow", "test_group"),
        training_jobs=[DataProcessInstanceUrn("training_job")],
        downstream_jobs=[DataProcessInstanceUrn("inference_job")],
    )
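
    # The constructor accepts keyword equivalents for the set_* mutators
    # exercised in test_mlmodel, so a fully-populated model can be built
    # in a single call.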

    # Verify properties
    assert model.name == "Complex Model"
    assert model.description == "A model with all properties"
    assert model.version == "2.0.0"
    assert model.created == test_date

    # Verify collections
    assert model.training_metrics is not None and len(model.training_metrics) == 2
    assert model.hyper_params is not None and len(model.hyper_params) == 2
    assert model.custom_properties is not None
    assert model.training_jobs is not None and len(model.training_jobs) == 1
    assert model.downstream_jobs is not None and len(model.downstream_jobs) == 1


def test_mlmodel_validation() -> None:
    """Test MLModel validation errors."""
    # Test invalid platform
    with pytest.raises(InvalidUrnError):
        MLModel(id="test", platform="")

    # Test invalid ID
    with pytest.raises(InvalidUrnError):
        MLModel(id="", platform="test_platform")


def test_client_get_mlmodel() -> None:
    """Test retrieving MLModels using client.entities.get()."""
    # Set up mock
    mock_client = mock.MagicMock()
    mock_entities = mock.MagicMock()
    mock_client.entities = mock_entities
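
    # The client here is a MagicMock, so these checks exercise the call
    # contract of entities.get() (arguments and return value) without a
    # live DataHub server.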

    # Basic retrieval
    model_urn = MlModelUrn("mlflow", "test_model", "PROD")
    expected_model = MLModel(
        id="test_model",
        platform="mlflow",
        name="Test Model",
        description="A test model",
    )
    mock_entities.get.return_value = expected_model
    result = mock_client.entities.get(model_urn)
    assert result == expected_model
    mock_entities.get.assert_called_once_with(model_urn)
    mock_entities.get.reset_mock()

    # String URN
    urn_str = "urn:li:mlModel:(urn:li:dataPlatform:mlflow,string_model,PROD)"
    mock_entities.get.return_value = MLModel(id="string_model", platform="mlflow")
    result = mock_client.entities.get(urn_str)
    mock_entities.get.assert_called_once_with(urn_str)
    mock_entities.get.reset_mock()

    # Complex model with properties
    complex_model = MLModel(
        id="complex_model",
        platform="mlflow",
        name="Complex Model",
        description="Complex test model",
        training_metrics=[MLMetricClass(name="accuracy", value="0.95")],
        hyper_params=[MLHyperParamClass(name="learning_rate", value="0.001")],
    )
    test_date = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    complex_model.set_created(test_date)
    complex_model.set_version("1.0.0")

    # Set relationships
    group_urn = MlModelGroupUrn("mlflow", "test_group")
    complex_model.set_model_group(group_urn)
    complex_model.set_training_jobs([DataProcessInstanceUrn("job1")])

    model_urn = MlModelUrn("mlflow", "complex_model", "PROD")
    mock_entities.get.return_value = complex_model
    result = mock_client.entities.get(model_urn)
    assert result.name == "Complex Model"
    assert result.version == "1.0.0"
    assert result.created == test_date
    assert result.training_metrics is not None
    assert result.model_group is not None
    mock_entities.get.assert_called_once_with(model_urn)
    mock_entities.get.reset_mock()

    # Not found case
    error_message = f"Entity {model_urn} not found"
    mock_entities.get.side_effect = ItemNotFoundError(error_message)
    with pytest.raises(ItemNotFoundError, match=re.escape(error_message)):
        mock_client.entities.get(model_urn)