mirror of
				https://github.com/datahub-project/datahub.git
				synced 2025-10-31 10:49:00 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			332 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			332 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from unittest.mock import patch
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from datahub.ingestion.api.common import PipelineContext
 | |
| from datahub.ingestion.api.report import EntityFilterReport
 | |
| from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig
 | |
| from datahub.ingestion.source.unity.source import UnityCatalogSource
 | |
| 
 | |
| 
 | |
class TestUnityCatalogSource:
    """Tests for how ``UnityCatalogSource`` wires its configuration.

    Three areas are covered, each with the relevant collaborator patched:

    * ``databricks_api_page_size`` being forwarded to ``UnityCatalogApiProxy``
      when the source is created,
    * ``UnityCatalogSource.test_connection`` handing the parsed config to
      ``UnityCatalogConnectionTest``,
    * ML model / ML model version report tracking and workunit generation.

    Note on the ``@patch`` stacks below: decorators are applied bottom-up, so
    the ``HiveMetastoreProxy`` mock is injected first and the
    ``UnityCatalogApiProxy`` mock second.
    """

    # Connection settings shared by every configuration built in this class.
    # Only ever read and spread into fresh dicts, never mutated.
    _CONNECTION_SETTINGS = {
        "token": "test_token",
        "workspace_url": "https://test.databricks.com",
        "warehouse_id": "test_warehouse",
    }

    @classmethod
    def _parse_config(cls, **overrides):
        """Parse a source config with hive metastore disabled plus ``overrides``."""
        return UnityCatalogSourceConfig.parse_obj(
            {
                **cls._CONNECTION_SETTINGS,
                "include_hive_metastore": False,
                **overrides,
            }
        )

    @staticmethod
    def _create_source(config):
        """Create the source under test with a throwaway pipeline context."""
        return UnityCatalogSource.create(config, PipelineContext(run_id="test_run"))

    @pytest.fixture
    def minimal_config(self):
        """Config with only connection settings and hive metastore disabled."""
        return self._parse_config()

    @pytest.fixture
    def config_with_page_size(self):
        """Config that overrides ``databricks_api_page_size`` with 75."""
        return self._parse_config(databricks_api_page_size=75)

    @pytest.fixture
    def config_with_ml_model_settings(self):
        """Config with the ML model options and a page size of 100 set."""
        return self._parse_config(
            include_ml_model_aliases=True,
            ml_model_max_results=500,
            databricks_api_page_size=100,
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_constructor_passes_default_page_size_to_proxy(
        self, mock_hive_proxy, mock_unity_proxy, minimal_config
    ):
        """The proxy receives the default page size when none is configured."""
        source = self._create_source(minimal_config)

        # Exactly one proxy must be built, with the full expected argument set.
        mock_unity_proxy.assert_called_once_with(
            minimal_config.workspace_url,
            minimal_config.token,
            minimal_config.warehouse_id,
            report=source.report,
            hive_metastore_proxy=source.hive_metastore_proxy,
            lineage_data_source=minimal_config.lineage_data_source,
            databricks_api_page_size=0,  # Default value
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_constructor_passes_custom_page_size_to_proxy(
        self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
    ):
        """The proxy receives the configured page size when one is set."""
        source = self._create_source(config_with_page_size)

        mock_unity_proxy.assert_called_once_with(
            config_with_page_size.workspace_url,
            config_with_page_size.token,
            config_with_page_size.warehouse_id,
            report=source.report,
            hive_metastore_proxy=source.hive_metastore_proxy,
            lineage_data_source=config_with_page_size.lineage_data_source,
            databricks_api_page_size=75,  # Custom value
        )

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_config_page_size_available_to_source(
        self, mock_hive_proxy, mock_unity_proxy, config_with_page_size
    ):
        """The created source exposes the configured page size on its config."""
        source = self._create_source(config_with_page_size)

        assert source.config.databricks_api_page_size == 75

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_with_hive_metastore_disabled(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """The page size is still forwarded when hive metastore is disabled."""
        config = self._parse_config(databricks_api_page_size=200)
        source = self._create_source(config)

        mock_unity_proxy.assert_called_once_with(
            config.workspace_url,
            config.token,
            config.warehouse_id,
            report=source.report,
            hive_metastore_proxy=None,  # Should be None when disabled
            lineage_data_source=config.lineage_data_source,
            databricks_api_page_size=200,
        )

    def test_test_connection_with_page_size_config(self):
        """``test_connection`` passes the page size through to the connection test."""
        # Raw dict on purpose: test_connection receives unparsed config.
        config_dict = {**self._CONNECTION_SETTINGS, "databricks_api_page_size": 300}

        with patch(
            "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
        ) as mock_connection_test:
            mock_connection_test.return_value.get_connection_test.return_value = (
                "test_report"
            )

            result = UnityCatalogSource.test_connection(config_dict)

            # The report produced by the connection test is returned as-is.
            assert result == "test_report"
            mock_connection_test.assert_called_once()

            # First positional argument is the config handed to the connection test.
            connection_test_config = mock_connection_test.call_args[0][0]
            assert connection_test_config.databricks_api_page_size == 300

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_source_report_includes_ml_model_stats(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """The source report exposes ML model trackers as ``EntityFilterReport``."""
        # Give the mocked proxy instance benign answers for basic calls.
        mock_unity_instance = mock_unity_proxy.return_value
        mock_unity_instance.catalogs.return_value = []
        mock_unity_instance.check_basic_connectivity.return_value = True

        source = self._create_source(
            self._parse_config(databricks_api_page_size=200)
        )

        # Both tracking attributes must exist on the report...
        assert hasattr(source.report, "ml_models")
        assert hasattr(source.report, "ml_model_versions")

        # ...and be EntityFilterReport objects.
        assert isinstance(source.report.ml_models, EntityFilterReport)
        assert isinstance(source.report.ml_model_versions, EntityFilterReport)

    def test_test_connection_with_ml_model_configs(self):
        """``test_connection`` passes the ML model options through unchanged."""
        # Raw dict on purpose: test_connection receives unparsed config.
        config_dict = {
            **self._CONNECTION_SETTINGS,
            "include_ml_model_aliases": True,
            "ml_model_max_results": 750,
            "databricks_api_page_size": 200,
        }

        with patch(
            "datahub.ingestion.source.unity.source.UnityCatalogConnectionTest"
        ) as mock_connection_test:
            mock_connection_test.return_value.get_connection_test.return_value = (
                "test_report"
            )

            result = UnityCatalogSource.test_connection(config_dict)

            # The report produced by the connection test is returned as-is.
            assert result == "test_report"
            mock_connection_test.assert_called_once()

            # First positional argument is the config handed to the connection test.
            connection_test_config = mock_connection_test.call_args[0][0]
            assert connection_test_config.include_ml_model_aliases is True
            assert connection_test_config.ml_model_max_results == 750
            assert connection_test_config.databricks_api_page_size == 200

    @patch("datahub.ingestion.source.unity.source.UnityCatalogApiProxy")
    @patch("datahub.ingestion.source.unity.source.HiveMetastoreProxy")
    def test_process_ml_model_generates_workunits(
        self, mock_hive_proxy, mock_unity_proxy
    ):
        """Processing a model and one of its versions emits workunits and updates the report."""
        from datetime import datetime

        from datahub.ingestion.source.unity.proxy_types import (
            Catalog,
            Metastore,
            Model,
            ModelVersion,
            Schema,
        )

        source = self._create_source(self._parse_config())

        # Build the metastore -> catalog -> schema chain the model lives in.
        metastore = Metastore(
            id="metastore",
            name="metastore",
            comment=None,
            global_metastore_id=None,
            metastore_id=None,
            owner=None,
            region=None,
            cloud=None,
        )
        catalog = Catalog(
            id="test_catalog",
            name="test_catalog",
            metastore=metastore,
            comment=None,
            owner=None,
            type=None,
        )
        schema = Schema(
            id="test_catalog.test_schema",
            name="test_schema",
            catalog=catalog,
            comment=None,
            owner=None,
        )

        # Model under test.
        test_model = Model(
            id="test_catalog.test_schema.test_model",
            name="test_model",
            description="Test description",
            schema_name="test_schema",
            catalog_name="test_catalog",
            created_at=datetime(2023, 1, 1),
            updated_at=datetime(2023, 1, 2),
        )

        # A single version of that model.
        test_model_version = ModelVersion(
            id="test_catalog.test_schema.test_model_1",
            name="test_model_1",
            model=test_model,
            version="1",
            aliases=["prod"],
            description="Version 1",
            created_at=datetime(2023, 1, 3),
            updated_at=datetime(2023, 1, 4),
            created_by="test_user",
        )

        # --- Model ---------------------------------------------------------
        # list() drains the generator so the report side effects happen.
        ml_model_workunits = list(source.process_ml_model(test_model, schema))

        # Should generate workunits (MLModelGroup creation and container assignment)
        assert len(ml_model_workunits) > 0

        processed_models = source.report.ml_models.processed_entities
        assert len(processed_models) == 1
        assert processed_models[0][1] == "test_catalog.test_schema.test_model"

        # --- Model version -------------------------------------------------
        model_urn = source.gen_ml_model_urn(test_model.id)
        ml_model_version_workunits = list(
            source.process_ml_model_version(model_urn, test_model_version, schema)
        )

        # Should generate workunits (MLModel creation and container assignment)
        assert len(ml_model_version_workunits) > 0

        # The version must be recorded in the report as well.
        processed_versions = source.report.ml_model_versions.processed_entities
        assert len(processed_versions) == 1
        assert processed_versions[0][1] == "test_catalog.test_schema.test_model_1"
 | 
