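"""Tests for the DataHub SDK MLModel entity: construction and basic properties,
metrics and hyperparameters, custom properties, relationships, URN validation,
and retrieval via client.entities.get()."""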
from __future__ import annotations

import re
from datetime import datetime, timezone
from unittest import mock

import pytest

from datahub.errors import ItemNotFoundError
from datahub.metadata.schema_classes import (
    MLHyperParamClass,
    MLMetricClass,
)
from datahub.metadata.urns import (
    DataPlatformUrn,
    DataProcessInstanceUrn,
    MlModelGroupUrn,
    MlModelUrn,
)
from datahub.sdk.mlmodel import MLModel
from datahub.utilities.urns.error import InvalidUrnError


def test_mlmodel() -> None:
    """Test MLModel functionality with all essential features."""
    # Create model with basic properties
    model = MLModel(
        id="test_model",
        platform="mlflow",
        name="test_model",
    )
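    # Note: the bare "mlflow" string is coerced to a DataPlatformUrn, and the
    # DataHub URN classes default the env to PROD when it is not specified.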

    # Test basic properties
    assert model.urn == MlModelUrn("mlflow", "test_model")
    assert model.name == "test_model"
    assert model.platform == DataPlatformUrn("urn:li:dataPlatform:mlflow")

    # Test description and URL
    model.set_description("A test model")
    assert model.description == "A test model"
    model.set_external_url("https://example.com/model")
    assert model.external_url == "https://example.com/model"

    # Test dates
    test_date = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    model.set_created(test_date)
    model.set_last_modified(test_date)
    assert model.created == test_date
    assert model.last_modified == test_date

    # Test version and aliases
    model.set_version("1.0.0")
    assert model.version == "1.0.0"
    model.add_version_alias("stable")
    model.add_version_alias("production")
    aliases = model.version_aliases
    assert aliases is not None
    assert "stable" in aliases
    assert "production" in aliases
    model.remove_version_alias("stable")
    assert "stable" not in model.version_aliases

    # Test metrics - both individual and bulk operations
    model.set_training_metrics(
        {
            "precision": "0.92",
            "recall": "0.88",
        }
    )
    model.add_training_metrics([MLMetricClass(name="f1_score", value="0.90")])
    model.add_training_metrics([MLMetricClass(name="accuracy", value="0.95")])
    metrics = model.training_metrics
    assert metrics is not None
    assert len(metrics) == 4
    metric_values = {
        m.name: m.value for m in metrics if hasattr(m, "name") and hasattr(m, "value")
    }
    assert metric_values["precision"] == "0.92"
    assert metric_values["accuracy"] == "0.95"
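    # set_training_metrics() replaces the whole list, while add_training_metrics()
    # appends: two metrics set plus two appended gives four in total.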

    # Test hyperparameters
    model.set_hyper_params({"learning_rate": "0.001", "num_layers": "3"})
    model.add_hyper_params({"batch_size": "32"})
    model.add_hyper_params([MLHyperParamClass(name="num_epochs", value="10")])
    params = model.hyper_params
    assert params is not None
    assert len(params) == 4
    param_values = {
        p.name: p.value for p in params if hasattr(p, "name") and hasattr(p, "value")
    }
    assert param_values["learning_rate"] == "0.001"
    assert param_values["num_epochs"] == "10"
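    # add_hyper_params() accepts either a plain dict or a list of
    # MLHyperParamClass objects; both forms land in the same hyper_params list.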

    # Test custom properties
    model.set_custom_properties(
        {
            "framework": "pytorch",
            "task": "classification",
        }
    )
    assert model.custom_properties == {
        "framework": "pytorch",
        "task": "classification",
    }

    # Test relationships
    # Model group
    group_urn = MlModelGroupUrn("mlflow", "test_group")
    model.set_model_group(group_urn)
    assert model.model_group is not None
    assert str(group_urn) == model.model_group
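    # model_group appears to be surfaced as a string URN rather than an
    # MlModelGroupUrn, hence the str() comparison above.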

    # Training and downstream jobs
    job1 = DataProcessInstanceUrn("job1")
    job2 = DataProcessInstanceUrn("job2")

    # Add and remove jobs
    model.add_training_job(job1)
    assert model.training_jobs is not None
    assert str(job1) in model.training_jobs

    model.remove_training_job(job1)
    assert model.training_jobs is not None
    assert len(model.training_jobs) == 0

    # Test bulk job operations
    model.set_training_jobs([job1, job2])
    model.set_downstream_jobs([job1, job2])
    assert model.training_jobs is not None
    assert model.downstream_jobs is not None
    assert len(model.training_jobs) == 2
    assert len(model.downstream_jobs) == 2


def test_mlmodel_complex_initialization() -> None:
    """Test MLModel with initialization of all properties at once."""
    test_date = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    model = MLModel(
        id="complex_model",
        platform="mlflow",
        name="Complex Model",
        description="A model with all properties",
        external_url="https://example.com/model",
        version="2.0.0",
        created=test_date,
        last_modified=test_date,
        training_metrics=[
            MLMetricClass(name="accuracy", value="0.95"),
            MLMetricClass(name="loss", value="0.1"),
        ],
        hyper_params=[
            MLHyperParamClass(name="learning_rate", value="0.001"),
            MLHyperParamClass(name="batch_size", value="32"),
        ],
        custom_properties={
            "framework": "pytorch",
            "task": "classification",
        },
        model_group=MlModelGroupUrn("mlflow", "test_group"),
        training_jobs=[DataProcessInstanceUrn("training_job")],
        downstream_jobs=[DataProcessInstanceUrn("inference_job")],
    )

    # Verify properties
    assert model.name == "Complex Model"
    assert model.description == "A model with all properties"
    assert model.version == "2.0.0"
    assert model.created == test_date

    # Verify collections
    assert model.training_metrics is not None and len(model.training_metrics) == 2
    assert model.hyper_params is not None and len(model.hyper_params) == 2
    assert model.custom_properties is not None
    assert model.training_jobs is not None and len(model.training_jobs) == 1
    assert model.downstream_jobs is not None and len(model.downstream_jobs) == 1


def test_mlmodel_validation() -> None:
    """Test MLModel validation errors."""
    # Test invalid platform
    with pytest.raises(InvalidUrnError):
        MLModel(id="test", platform="")

    # Test invalid ID
    with pytest.raises(InvalidUrnError):
        MLModel(id="", platform="test_platform")


def test_client_get_mlmodel() -> None:
    """Test retrieving MLModels using client.entities.get()."""
    # Set up mock
    mock_client = mock.MagicMock()
    mock_entities = mock.MagicMock()
    mock_client.entities = mock_entities
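    # Everything below runs against a MagicMock, so these cases verify the
    # client.entities.get() call patterns rather than live-server behavior.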

    # Basic retrieval
    model_urn = MlModelUrn("mlflow", "test_model", "PROD")
    expected_model = MLModel(
        id="test_model",
        platform="mlflow",
        name="Test Model",
        description="A test model",
    )
    mock_entities.get.return_value = expected_model

    result = mock_client.entities.get(model_urn)
    assert result == expected_model
    mock_entities.get.assert_called_once_with(model_urn)
    mock_entities.get.reset_mock()

    # String URN
    urn_str = "urn:li:mlModel:(urn:li:dataPlatform:mlflow,string_model,PROD)"
    mock_entities.get.return_value = MLModel(id="string_model", platform="mlflow")
    result = mock_client.entities.get(urn_str)
    mock_entities.get.assert_called_once_with(urn_str)
    mock_entities.get.reset_mock()

    # Complex model with properties
    complex_model = MLModel(
        id="complex_model",
        platform="mlflow",
        name="Complex Model",
        description="Complex test model",
        training_metrics=[MLMetricClass(name="accuracy", value="0.95")],
        hyper_params=[MLHyperParamClass(name="learning_rate", value="0.001")],
    )
    test_date = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    complex_model.set_created(test_date)
    complex_model.set_version("1.0.0")

    # Set relationships
    group_urn = MlModelGroupUrn("mlflow", "test_group")
    complex_model.set_model_group(group_urn)
    complex_model.set_training_jobs([DataProcessInstanceUrn("job1")])

    model_urn = MlModelUrn("mlflow", "complex_model", "PROD")
    mock_entities.get.return_value = complex_model

    result = mock_client.entities.get(model_urn)
    assert result.name == "Complex Model"
    assert result.version == "1.0.0"
    assert result.created == test_date
    assert result.training_metrics is not None
    assert result.model_group is not None
    mock_entities.get.assert_called_once_with(model_urn)
    mock_entities.get.reset_mock()

    # Not found case
    error_message = f"Entity {model_urn} not found"
    mock_entities.get.side_effect = ItemNotFoundError(error_message)
    with pytest.raises(ItemNotFoundError, match=re.escape(error_message)):
        mock_client.entities.get(model_urn)
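

# For reference, the non-mocked call pattern these tests mirror would look
# roughly like the sketch below. This is an illustrative sketch, assuming the
# DataHubClient exported from datahub.sdk; the server URL and token are
# placeholders:
#
#     from datahub.sdk import DataHubClient
#
#     client = DataHubClient(server="http://localhost:8080", token="<token>")
#     model = client.entities.get(MlModelUrn("mlflow", "test_model", "PROD"))
#     print(model.name, model.version)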