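"""Tests for UnityCatalogSource configuration handling.

Covers propagation of databricks_api_page_size to the Unity Catalog API proxy,
connection-test config handling, and ML model report/workunit generation.
"""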
from unittest.mock import patch

import pytest

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.report import EntityFilterReport
from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
from datahub.ingestion.source.unity.source import UnityCatalogSource


class TestUnityCatalogSource:
    @pytest.fixture
    def minimal_config(self):
        """Create a minimal config for testing."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
            }
        )

    @pytest.fixture
    def config_with_page_size(self):
        """Create a config with custom page size."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "databricks_api_page_size": 75,
            }
        )

    @pytest.fixture
    def config_with_ml_model_settings(self):
        """Create a config with ML model settings."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "include_ml_model_aliases": True,
                "ml_model_max_results": 500,
                "databricks_api_page_size": 100,
            }
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_constructor_passes_default_page_size_to_proxy(
        self, mock_hive_proxy, mock_unity_proxy, minimal_config
    ):
        """Test that UnityCatalogSource passes default databricks_api_page_size to proxy."""
        # Create a mock context
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(minimal_config, ctx)

        # Verify proxy was created with correct parameters including page size
        mock_unity_proxy.assert_called_once_with(
            minimal_config.workspace_url,
            minimal_config.token,
            minimal_config.warehouse_id,
            report=source.report,
            hive_metastore_proxy=source.hive_metastore_proxy,
            lineage_data_source=minimal_config.lineage_data_source,
            databricks_api_page_size=0,  # Default value
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_constructor_passes_custom_page_size_to_proxy(
        self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
    ):
        """Test that UnityCatalogSource passes custom databricks_api_page_size to proxy."""
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config_with_page_size, ctx)

        # Verify proxy was created with correct parameters including custom page size
        mock_unity_proxy.assert_called_once_with(
            config_with_page_size.workspace_url,
            config_with_page_size.token,
            config_with_page_size.warehouse_id,
            report=source.report,
            hive_metastore_proxy=source.hive_metastore_proxy,
            lineage_data_source=config_with_page_size.lineage_data_source,
            databricks_api_page_size=75,  # Custom value
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_config_page_size_available_to_source(
        self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
    ):
        """Test that UnityCatalogSource has access to databricks_api_page_size config."""
        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config_with_page_size, ctx)

        # Verify the source has access to the configuration value
        assert source.config.databricks_api_page_size == 75

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_with_hive_metastore_disabled(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Test that UnityCatalogSource works with hive metastore disabled."""
        config = UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "databricks_api_page_size": 200,
            }
        )

        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config, ctx)

        # Verify proxy was created with correct page size even when hive metastore is disabled
        mock_unity_proxy.assert_called_once_with(
            config.workspace_url,
            config.token,
            config.warehouse_id,
            report=source.report,
            hive_metastore_proxy=None,  # Should be None when disabled
            lineage_data_source=config.lineage_data_source,
            databricks_api_page_size=200,
        )

    def test_test_connection_with_page_size_config(self):
        """Test that test_connection properly handles databricks_api_page_size."""
        config_dict = {
            "token": "test_token",
            "workspace_url": "https://test.databricks.com",
            "warehouse_id": "test_warehouse",
            "databricks_api_page_size": 300,
        }

        with patch(
            "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
        ) as mock_connection_test:
            mock_connection_test.return_value.get_connection_test.return_value = (
                "test_report"
            )

            result = UnityCatalogSource.test_connection(config_dict)

            # Verify connection test was created with correct config
            assert result == "test_report"
            mock_connection_test.assert_called_once()

            # Get the config that was passed to UnityCatalogConnectionTest
            connection_test_config = mock_connection_test.call_args[0][0]
            assert connection_test_config.databricks_api_page_size == 300

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_report_includes_ml_model_stats(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Test that source report properly tracks ML model statistics."""
        from datahub.ingestion.api.common import PipelineContext

        # Setup mocks
        mock_unity_instance = mock_unity_proxy.return_value
        mock_unity_instance.catalogs.return_value = []
        mock_unity_instance.check_basic_connectivity.return_value = True

        config = UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
                "databricks_api_page_size": 200,
            }
        )

        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config, ctx)

        # Verify report has proper ML model tracking attributes
        assert hasattr(source.report, "ml_models")
        assert hasattr(source.report, "ml_model_versions")

        # Verify they are EntityFilterReport objects
        assert isinstance(source.report.ml_models, EntityFilterReport)
        assert isinstance(source.report.ml_model_versions, EntityFilterReport)

    def test_test_connection_with_ml_model_configs(self):
        """Test that test_connection properly handles ML model configs."""
        config_dict = {
            "token": "test_token",
            "workspace_url": "https://test.databricks.com",
            "warehouse_id": "test_warehouse",
            "include_ml_model_aliases": True,
            "ml_model_max_results": 750,
            "databricks_api_page_size": 200,
        }

        with patch(
            "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
        ) as mock_connection_test:
            mock_connection_test.return_value.get_connection_test.return_value = (
                "test_report"
            )

            result = UnityCatalogSource.test_connection(config_dict)

            # Verify connection test was created with correct config
            assert result == "test_report"
            mock_connection_test.assert_called_once()

            # Get the config that was passed to UnityCatalogConnectionTest
            connection_test_config = mock_connection_test.call_args[0][0]
            assert connection_test_config.include_ml_model_aliases is True
            assert connection_test_config.ml_model_max_results == 750
            assert connection_test_config.databricks_api_page_size == 200

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_process_ml_model_generates_workunits(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Test that process_ml_model generates proper workunits."""
        from datetime import datetime

        from datahub.ingestion.api.common import PipelineContext
        from datahub.ingestion.source.unity.proxy_types import (
            Catalog,
            Metastore,
            Model,
            ModelVersion,
            Schema,
        )

        config = UnityCatalogSourceConfig.parse_obj(
            {
                "token": "test_token",
                "workspace_url": "https://test.databricks.com",
                "warehouse_id": "test_warehouse",
                "include_hive_metastore": False,
            }
        )

        ctx = PipelineContext(run_id="test_run")
        source = UnityCatalogSource.create(config, ctx)

        # Create test schema
        metastore = Metastore(
            id="metastore",
            name="metastore",
            comment=None,
            global_metastore_id=None,
            metastore_id=None,
            owner=None,
            region=None,
            cloud=None,
        )
        catalog = Catalog(
            id="test_catalog",
            name="test_catalog",
            metastore=metastore,
            comment=None,
            owner=None,
            type=None,
        )
        schema = Schema(
            id="test_catalog.test_schema",
            name="test_schema",
            catalog=catalog,
            comment=None,
            owner=None,
        )

        # Create test model
        test_model = Model(
            id="test_catalog.test_schema.test_model",
            name="test_model",
            description="Test description",
            schema_name="test_schema",
            catalog_name="test_catalog",
            created_at=datetime(2023, 1, 1),
            updated_at=datetime(2023, 1, 2),
        )

        # Create test model version
        test_model_version = ModelVersion(
            id="test_catalog.test_schema.test_model_1",
            name="test_model_1",
            model=test_model,
            version="1",
            aliases=["prod"],
            description="Version 1",
            created_at=datetime(2023, 1, 3),
            updated_at=datetime(2023, 1, 4),
            created_by="test_user",
        )

        # Process the model
        ml_model_workunits = list(source.process_ml_model(test_model, schema))

        # Should generate workunits (MLModelGroup creation and container assignment)
        assert len(ml_model_workunits) > 0

        assert len(source.report.ml_models.processed_entities) == 1
        assert (
            source.report.ml_models.processed_entities[0][1]
            == "test_catalog.test_schema.test_model"
        )

        # Process the model version
        model_urn = source.gen_ml_model_urn(test_model.id)
        ml_model_version_workunits = list(
            source.process_ml_model_version(model_urn, test_model_version, schema)
        )

        # Should generate workunits (MLModel creation and container assignment)
        assert len(ml_model_version_workunits) > 0

        # Verify the report was updated
        assert len(source.report.ml_model_versions.processed_entities) == 1
        assert (
            source.report.ml_model_versions.processed_entities[0][1]
            == "test_catalog.test_schema.test_model_1"
        )