# datahub/metadata-ingestion/tests/unit/test_unity_catalog_source.py

from unittest.mock import patch
import pytest
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.report import EntityFilterReport
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
from datahub.ingestion.source.unity.source import UnityCatalogSource
class TestUnityCatalogSource:
    """Unit tests for UnityCatalogSource config plumbing and reporting."""

    @pytest.fixture
    def minimal_config(self):
        """Smallest config that parses: connection fields only, hive metastore off."""
        raw = {
            "token": "test_token",
            "workspace_url": "https://test.databricks.com",
            "warehouse_id": "test_warehouse",
            "include_hive_metastore": False,
        }
        return UnityCatalogSourceConfig.parse_obj(raw)
@pytest.fixture
def config_with_page_size(self):
"""Create a config with custom page size."""
return UnityCatalogSourceConfig.parse_obj(
{
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"include_hive_metastore": False,
"databricks_api_page_size": 75,
}
)
@pytest.fixture
def config_with_ml_model_settings(self):
"""Create a config with ML model settings."""
return UnityCatalogSourceConfig.parse_obj(
{
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"include_hive_metastore": False,
"include_ml_model_aliases": True,
"ml_model_max_results": 500,
"databricks_api_page_size": 100,
}
)
@patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
@patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
def test_source_constructor_passes_default_page_size_to_proxy(
self, mock_hive_proxy, mock_unity_proxy, minimal_config
):
"""Test that UnityCatalogSource passes default databricks_api_page_size to proxy."""
# Create a mock context
ctx = PipelineContext(run_id="test_run")
source = UnityCatalogSource.create(minimal_config, ctx)
# Verify proxy was created with correct parameters including page size
mock_unity_proxy.assert_called_once_with(
minimal_config.workspace_url,
minimal_config.token,
minimal_config.warehouse_id,
report=source.report,
hive_metastore_proxy=source.hive_metastore_proxy,
lineage_data_source=minimal_config.lineage_data_source,
databricks_api_page_size=0, # Default value
)
@patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
@patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
def test_source_constructor_passes_custom_page_size_to_proxy(
self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
):
"""Test that UnityCatalogSource passes custom databricks_api_page_size to proxy."""
ctx = PipelineContext(run_id="test_run")
source = UnityCatalogSource.create(config_with_page_size, ctx)
# Verify proxy was created with correct parameters including custom page size
mock_unity_proxy.assert_called_once_with(
config_with_page_size.workspace_url,
config_with_page_size.token,
config_with_page_size.warehouse_id,
report=source.report,
hive_metastore_proxy=source.hive_metastore_proxy,
lineage_data_source=config_with_page_size.lineage_data_source,
databricks_api_page_size=75, # Custom value
)
@patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
@patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
def test_source_config_page_size_available_to_source(
self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
):
"""Test that UnityCatalogSource has access to databricks_api_page_size config."""
ctx = PipelineContext(run_id="test_run")
source = UnityCatalogSource.create(config_with_page_size, ctx)
# Verify the source has access to the configuration value
assert source.config.databricks_api_page_size == 75
@patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
@patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
def test_source_with_hive_metastore_disabled(
self, mock_hive_proxy, mock_unity_proxy
):
"""Test that UnityCatalogSource works with hive metastore disabled."""
config = UnityCatalogSourceConfig.parse_obj(
{
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"include_hive_metastore": False,
"databricks_api_page_size": 200,
}
)
ctx = PipelineContext(run_id="test_run")
source = UnityCatalogSource.create(config, ctx)
# Verify proxy was created with correct page size even when hive metastore is disabled
mock_unity_proxy.assert_called_once_with(
config.workspace_url,
config.token,
config.warehouse_id,
report=source.report,
hive_metastore_proxy=None, # Should be None when disabled
lineage_data_source=config.lineage_data_source,
databricks_api_page_size=200,
)
def test_test_connection_with_page_size_config(self):
"""Test that test_connection properly handles databricks_api_page_size."""
config_dict = {
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"databricks_api_page_size": 300,
}
with patch(
"datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
) as mock_connection_test:
mock_connection_test.return_value.get_connection_test.return_value = (
"test_report"
)
result = UnityCatalogSource.test_connection(config_dict)
# Verify connection test was created with correct config
assert result == "test_report"
mock_connection_test.assert_called_once()
# Get the config that was passed to UnityCatalogConnectionTest
connection_test_config = mock_connection_test.call_args[0][0]
assert connection_test_config.databricks_api_page_size == 300
@patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
@patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
def test_source_report_includes_ml_model_stats(
self, mock_hive_proxy, mock_unity_proxy
):
"""Test that source report properly tracks ML model statistics."""
from datahub.ingestion.api.common import PipelineContext
# Setup mocks
mock_unity_instance = mock_unity_proxy.return_value
mock_unity_instance.catalogs.return_value = []
mock_unity_instance.check_basic_connectivity.return_value = True
config = UnityCatalogSourceConfig.parse_obj(
{
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"include_hive_metastore": False,
"databricks_api_page_size": 200,
}
)
ctx = PipelineContext(run_id="test_run")
source = UnityCatalogSource.create(config, ctx)
# Verify report has proper ML model tracking attributes
assert hasattr(source.report, "ml_models")
assert hasattr(source.report, "ml_model_versions")
# Verify they are EntityFilterReport objects
assert isinstance(source.report.ml_models, EntityFilterReport)
assert isinstance(source.report.ml_model_versions, EntityFilterReport)
def test_test_connection_with_ml_model_configs(self):
"""Test that test_connection properly handles ML model configs."""
config_dict = {
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"include_ml_model_aliases": True,
"ml_model_max_results": 750,
"databricks_api_page_size": 200,
}
with patch(
"datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
) as mock_connection_test:
mock_connection_test.return_value.get_connection_test.return_value = (
"test_report"
)
result = UnityCatalogSource.test_connection(config_dict)
# Verify connection test was created with correct config
assert result == "test_report"
mock_connection_test.assert_called_once()
# Get the config that was passed to UnityCatalogConnectionTest
connection_test_config = mock_connection_test.call_args[0][0]
assert connection_test_config.include_ml_model_aliases is True
assert connection_test_config.ml_model_max_results == 750
assert connection_test_config.databricks_api_page_size == 200
@patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
@patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
def test_process_ml_model_generates_workunits(
self, mock_hive_proxy, mock_unity_proxy
):
"""Test that process_ml_model generates proper workunits."""
from datetime import datetime
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.unity.proxy_types import (
Catalog,
Metastore,
Model,
ModelVersion,
Schema,
)
config = UnityCatalogSourceConfig.parse_obj(
{
"token": "test_token",
"workspace_url": "https://test.databricks.com",
"warehouse_id": "test_warehouse",
"include_hive_metastore": False,
}
)
ctx = PipelineContext(run_id="test_run")
source = UnityCatalogSource.create(config, ctx)
# Create test schema
metastore = Metastore(
id="metastore",
name="metastore",
comment=None,
global_metastore_id=None,
metastore_id=None,
owner=None,
region=None,
cloud=None,
)
catalog = Catalog(
id="test_catalog",
name="test_catalog",
metastore=metastore,
comment=None,
owner=None,
type=None,
)
schema = Schema(
id="test_catalog.test_schema",
name="test_schema",
catalog=catalog,
comment=None,
owner=None,
)
# Create test model
test_model = Model(
id="test_catalog.test_schema.test_model",
name="test_model",
description="Test description",
schema_name="test_schema",
catalog_name="test_catalog",
created_at=datetime(2023, 1, 1),
updated_at=datetime(2023, 1, 2),
)
# Create test model version
test_model_version = ModelVersion(
id="test_catalog.test_schema.test_model_1",
name="test_model_1",
model=test_model,
version="1",
aliases=["prod"],
description="Version 1",
created_at=datetime(2023, 1, 3),
updated_at=datetime(2023, 1, 4),
created_by="test_user",
)
# Process the model
ml_model_workunits = list(source.process_ml_model(test_model, schema))
# Should generate workunits (MLModelGroup creation and container assignment)
assert len(ml_model_workunits) > 0
assert len(source.report.ml_models.processed_entities) == 1
assert (
source.report.ml_models.processed_entities[0][1]
== "test_catalog.test_schema.test_model"
)
# Process the model version
model_urn = source.gen_ml_model_urn(test_model.id)
ml_model_version_workunits = list(
source.process_ml_model_version(model_urn, test_model_version, schema)
)
# Should generate workunits (MLModel creation and container assignment)
assert len(ml_model_version_workunits) > 0
# Verify the report was updated
assert len(source.report.ml_model_versions.processed_entities) == 1
assert (
source.report.ml_model_versions.processed_entities[0][1]
== "test_catalog.test_schema.test_model_1"
)