# datahub/metadata-ingestion/tests/unit/test_unity_catalog_source.py
from unittest.mock import patch
import pytest
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.report import EntityFilterReport
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
from datahub.ingestion.source.unity.source import UnityCatalogSource
class TestUnityCatalogSource:
    """Unit tests for UnityCatalogSource construction and config plumbing.

    Verifies that configuration values (notably ``databricks_api_page_size``
    and the ML-model settings) are forwarded to ``UnityCatalogApiProxy`` and
    to ``UnityCatalogConnectionTest``, and that ML model processing produces
    workunits and updates the source report.
    """

    @pytest.fixture
    def minimal_config(self):
        """Create a minimal config for testing."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
            }
        )

    @pytest.fixture
    def config_with_page_size(self):
        """Create a config with custom page size."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "databricks_api_page_size": 75,
            }
        )

    @pytest.fixture
    def config_with_ml_model_settings(self):
        """Create a config with ML model settings."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "include_ml_model_aliases": True,
                "ml_model_max_results": 500,
                "databricks_api_page_size": 100,
            }
        )

    # NOTE: the bottom-most @patch decorator supplies the first mock argument,
    # so mock_hive_proxy (HiveMetastoreProxy) precedes mock_unity_proxy.
    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_constructor_passes_default_page_size_to_proxy(
        self, mock_hive_proxy, mock_unity_proxy, minimal_config
    ):
        """Test that UnityCatalogSource passes default databricks_api_page_size to proxy."""
        # Create a mock context
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(minimal_config, ctx)
        # Verify proxy was created with correct parameters including page size
        mock_unity_proxy.assert_called_once_with(
            minimal_config.workspace_url,
            minimal_config.token,
            minimal_config.warehouse_id,
            report=source.report,
            hive_metastore_proxy=source.hive_metastore_proxy,
            lineage_data_source=minimal_config.lineage_data_source,
            databricks_api_page_size=0,  # Default value
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_constructor_passes_custom_page_size_to_proxy(
        self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
    ):
        """Test that UnityCatalogSource passes custom databricks_api_page_size to proxy."""
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config_with_page_size, ctx)
        # Verify proxy was created with correct parameters including custom page size
        mock_unity_proxy.assert_called_once_with(
            config_with_page_size.workspace_url,
            config_with_page_size.token,
            config_with_page_size.warehouse_id,
            report=source.report,
            hive_metastore_proxy=source.hive_metastore_proxy,
            lineage_data_source=config_with_page_size.lineage_data_source,
            databricks_api_page_size=75,  # Custom value
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_config_page_size_available_to_source(
        self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
    ):
        """Test that UnityCatalogSource has access to databricks_api_page_size config."""
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config_with_page_size, ctx)
        # Verify the source has access to the configuration value
        assert source.config.databricks_api_page_size == 75

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_with_hive_metastore_disabled(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Test that UnityCatalogSource works with hive metastore disabled."""
        config = UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "databricks_api_page_size": 200,
            }
        )
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config, ctx)
        # Verify proxy was created with correct page size even when hive metastore is disabled
        mock_unity_proxy.assert_called_once_with(
            config.workspace_url,
            config.token,
            config.warehouse_id,
            report=source.report,
            hive_metastore_proxy=None,  # Should be None when disabled
            lineage_data_source=config.lineage_data_source,
            databricks_api_page_size=200,
        )

    def test_test_connection_with_page_size_config(self):
        """Test that test_connection properly handles databricks_api_page_size."""
        config_dict = {
            "token": "test_token",
            "workspace_url": "https://test.databricks.com",
            "warehouse_id": "test_warehouse",
            "databricks_api_page_size": 300,
        }
        with patch(
            "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
        ) as mock_connection_test:
            mock_connection_test.return_value.get_connection_test.return_value = (
                "test_report"
            )
            result = UnityCatalogSource.test_connection(config_dict)
            # Verify connection test was created with correct config
            assert result == "test_report"
            mock_connection_test.assert_called_once()
            # Get the config that was passed to UnityCatalogConnectionTest
            connection_test_config = mock_connection_test.call_args[0][0]
            assert connection_test_config.databricks_api_page_size == 300

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_report_includes_ml_model_stats(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Test that source report properly tracks ML model statistics."""
        # Setup mocks
        mock_unity_instance = mock_unity_proxy.return_value
        mock_unity_instance.catalogs.return_value = []
        mock_unity_instance.check_basic_connectivity.return_value = True
        config = UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "databricks_api_page_size": 200,
            }
        )
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config, ctx)
        # Verify report has proper ML model tracking attributes
        assert hasattr(source.report, "ml_models")
        assert hasattr(source.report, "ml_model_versions")
        # Verify they are EntityFilterReport objects
        assert isinstance(source.report.ml_models, EntityFilterReport)
        assert isinstance(source.report.ml_model_versions, EntityFilterReport)

    def test_test_connection_with_ml_model_configs(self):
        """Test that test_connection properly handles ML model configs."""
        config_dict = {
            "token": "test_token",
            "workspace_url": "https://test.databricks.com",
            "warehouse_id": "test_warehouse",
            "include_ml_model_aliases": True,
            "ml_model_max_results": 750,
            "databricks_api_page_size": 200,
        }
        with patch(
            "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
        ) as mock_connection_test:
            mock_connection_test.return_value.get_connection_test.return_value = (
                "test_report"
            )
            result = UnityCatalogSource.test_connection(config_dict)
            # Verify connection test was created with correct config
            assert result == "test_report"
            mock_connection_test.assert_called_once()
            # Get the config that was passed to UnityCatalogConnectionTest
            connection_test_config = mock_connection_test.call_args[0][0]
            assert connection_test_config.include_ml_model_aliases is True
            assert connection_test_config.ml_model_max_results == 750
            assert connection_test_config.databricks_api_page_size == 200

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_process_ml_model_generates_workunits(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Test that process_ml_model generates proper workunits."""
        # Local imports: stdlib datetime plus the Unity proxy domain types
        # needed only by this test.
        from datetime import datetime

        from datahub.ingestion.source.unity.proxy_types import (
            Catalog,
            Metastore,
            Model,
            ModelVersion,
            Schema,
        )

        config = UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
            }
        )
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config, ctx)
        # Create test schema
        metastore = Metastore(
            id="metastore",
            name="metastore",
            comment=None,
            global_metastore_id=None,
            metastore_id=None,
            owner=None,
            region=None,
            cloud=None,
        )
        catalog = Catalog(
            id="test_catalog",
            name="test_catalog",
            metastore=metastore,
            comment=None,
            owner=None,
            type=None,
        )
        schema = Schema(
            id="test_catalog.test_schema",
            name="test_schema",
            catalog=catalog,
            comment=None,
            owner=None,
        )
        # Create test model
        test_model = Model(
            id="test_catalog.test_schema.test_model",
            name="test_model",
            description="Test description",
            schema_name="test_schema",
            catalog_name="test_catalog",
            created_at=datetime(2023, 1, 1),
            updated_at=datetime(2023, 1, 2),
        )
        # Create test model version
        test_model_version = ModelVersion(
            id="test_catalog.test_schema.test_model_1",
            name="test_model_1",
            model=test_model,
            version="1",
            aliases=["prod"],
            description="Version 1",
            created_at=datetime(2023, 1, 3),
            updated_at=datetime(2023, 1, 4),
            created_by="test_user",
        )
        # Process the model
        ml_model_workunits = list(source.process_ml_model(test_model, schema))
        # Should generate workunits (MLModelGroup creation and container assignment)
        assert len(ml_model_workunits) > 0
        assert len(source.report.ml_models.processed_entities) == 1
        assert (
            source.report.ml_models.processed_entities[0][1]
            == "test_catalog.test_schema.test_model"
        )
        # Process the model version
        model_urn = source.gen_ml_model_urn(test_model.id)
        ml_model_version_workunits = list(
            source.process_ml_model_version(model_urn, test_model_version, schema)
        )
        # Should generate workunits (MLModel creation and container assignment)
        assert len(ml_model_version_workunits) > 0
        # Verify the report was updated
        assert len(source.report.ml_model_versions.processed_entities) == 1
        assert (
            source.report.ml_model_versions.processed_entities[0][1]
            == "test_catalog.test_schema.test_model_1"
        )