From 00b6da5b8405b17807cada30a285884abbf40f05 Mon Sep 17 00:00:00 2001 From: Mayur Singal <39544459+ulixius9@users.noreply.github.com> Date: Wed, 6 Aug 2025 00:41:11 +0530 Subject: [PATCH] MINOR: Improve Databricks Profiler & Test Connection (#22732) --- .../source/database/databricks/connection.py | 44 +++++++++-- .../database/unitycatalog/connection.py | 15 ++-- .../source/database/unitycatalog/metadata.py | 9 ++- .../databricks/profiler_interface.py | 6 ++ .../unit/topology/database/test_databricks.py | 78 +++++++++++++++---- 5 files changed, 117 insertions(+), 35 deletions(-) diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/connection.py b/ingestion/src/metadata/ingestion/source/database/databricks/connection.py index a6c3fa837ae..171c4a74ab2 100644 --- a/ingestion/src/metadata/ingestion/source/database/databricks/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/databricks/connection.py @@ -56,11 +56,17 @@ class DatabricksEngineWrapper: self.inspector = inspect(engine) self.schemas = None self.first_schema = None + self.first_catalog = None - def get_schemas(self): + def get_schemas(self, schema_name: Optional[str] = None): """Get schemas and cache them""" + if schema_name is not None: + with self.engine.connect() as connection: + connection.execute(f"USE CATALOG `{self.first_catalog}`") + self.first_schema = schema_name + return [schema_name] if self.schemas is None: - self.schemas = self.inspector.get_schema_names() + self.schemas = self.inspector.get_schema_names(database=self.first_catalog) if self.schemas: # Find the first schema that's not a system schema for schema in self.schemas: @@ -81,7 +87,11 @@ class DatabricksEngineWrapper: if self.first_schema is None: self.get_schemas() # This will set first_schema if self.first_schema: - return self.inspector.get_table_names(self.first_schema) + with self.engine.connect() as connection: + tables = connection.execute( + f"SHOW TABLES IN `{self.first_catalog}`.`{self.first_schema}`" + ) + return tables return [] def get_views(self): @@ -89,9 +99,27 @@ class DatabricksEngineWrapper: if self.first_schema is None: self.get_schemas() # This will set first_schema if self.first_schema: - return self.inspector.get_view_names(self.first_schema) + with self.engine.connect() as connection: + views = connection.execute( + f"SHOW VIEWS IN `{self.first_catalog}`.`{self.first_schema}`" + ) + return views return [] + def get_catalogs(self, catalog_name: Optional[str] = None): + """Get catalogs""" + catalogs = [] + if catalog_name is not None: + self.first_catalog = catalog_name + return [catalog_name] + with self.engine.connect() as connection: + catalogs = connection.execute(DATABRICKS_GET_CATALOGS).fetchall() + for catalog in catalogs: + if catalog[0] != "__databricks_internal": + self.first_catalog = catalog[0] + break + return catalogs + def get_connection_url(connection: DatabricksConnection) -> str: url = f"{connection.scheme.value}://token:{connection.token.get_secret_value()}@{connection.hostPort}" @@ -144,13 +172,13 @@ def test_connection( test_fn = { "CheckAccess": partial(test_connection_engine_step, connection), - "GetSchemas": engine_wrapper.get_schemas, + "GetSchemas": partial( + engine_wrapper.get_schemas, schema_name=service_connection.databaseSchema + ), "GetTables": engine_wrapper.get_tables, "GetViews": engine_wrapper.get_views, "GetDatabases": partial( - test_database_query, - engine=connection, - statement=DATABRICKS_GET_CATALOGS, + engine_wrapper.get_catalogs, catalog_name=service_connection.catalog ), "GetQueries": partial( test_database_query, diff --git a/ingestion/src/metadata/ingestion/source/database/unitycatalog/connection.py b/ingestion/src/metadata/ingestion/source/database/unitycatalog/connection.py index 594bebb5fe0..a5e82bc9c45 100644 --- a/ingestion/src/metadata/ingestion/source/database/unitycatalog/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/unitycatalog/connection.py @@ -111,16 +111,15 @@ def test_connection( def get_catalogs(connection: WorkspaceClient, table_obj: DatabricksTable): for catalog in connection.catalogs.list(): - table_obj.catalog_name = catalog.name - break + if catalog.name != "__databricks_internal": + table_obj.catalog_name = catalog.name + return def get_schemas(connection: WorkspaceClient, table_obj: DatabricksTable): - for catalog in connection.catalogs.list(): - for schema in connection.schemas.list(catalog_name=catalog.name): - if schema.name: - table_obj.schema_name = schema.name - table_obj.catalog_name = catalog.name - return + for schema in connection.schemas.list(catalog_name=table_obj.catalog_name): + if schema.name: + table_obj.schema_name = schema.name + return def get_tables(connection: WorkspaceClient, table_obj: DatabricksTable): if table_obj.catalog_name and table_obj.schema_name: diff --git a/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py b/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py index 3166fb42ad0..3e8fb12b819 100644 --- a/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py @@ -92,7 +92,7 @@ logger = ingestion_logger() UNITY_CATALOG_TAG = "UNITY CATALOG TAG" UNITY_CATALOG_TAG_CLASSIFICATION = "UNITY CATALOG TAG CLASSIFICATION" - +# pylint: disable=protected-access class UnitycatalogSource( ExternalTableLineageMixin, DatabaseServiceSource, MultiDBSource ): @@ -306,6 +306,8 @@ class UnitycatalogSource( if table.table_type: if table.table_type.value.lower() == TableType.View.value.lower(): table_type: TableType = TableType.View + if table.table_type.value.lower() == "materialized_view": + table_type: TableType = TableType.MaterializedView elif ( table.table_type.value.lower() == TableType.External.value.lower() @@ -427,7 +429,6 @@ class UnitycatalogSource( ) if referred_table_fqn: for parent_column in column.parent_columns: - # pylint: disable=protected-access col_fqn = fqn._build(referred_table_fqn, parent_column, quote=False) if col_fqn: referred_column_fqns.append(FullyQualifiedEntityName(col_fqn)) @@ -563,7 +564,7 @@ class UnitycatalogSource( yield from get_ometa_tag_and_classification( tag_fqn=FullyQualifiedEntityName( fqn._build(*tag_fqn_builder(tag)) - ), # pylint: disable=protected-access + ), tags=[tag.tag_value], classification_name=tag.tag_name, tag_description=UNITY_CATALOG_TAG, @@ -617,7 +618,7 @@ class UnitycatalogSource( yield from get_ometa_tag_and_classification( tag_fqn=FullyQualifiedEntityName( fqn._build(*tag_fqn_builder(tag)) - ), # pylint: disable=protected-access + ), tags=[tag.tag_value], classification_name=tag.tag_name, tag_description=UNITY_CATALOG_TAG, diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py index 403a1e3265a..eed50f51b1e 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py @@ -23,6 +23,7 @@ from metadata.generated.schema.entity.data.table import ( ColumnName, DataType, SystemProfile, + TableType, ) from metadata.generated.schema.entity.services.databaseService import ( DatabaseServiceType, @@ -51,6 +52,11 @@ class DatabricksProfilerInterface(SQAProfilerInterface): *args, **kwargs, ) -> List[SystemProfile]: + if self.table_entity.tableType in (TableType.View, TableType.MaterializedView): + logger.debug( + f"Skipping {metrics.name()} metric for view {runner.table_name}" + ) + return [] logger.debug(f"Computing {metrics.name()} metric for {runner.table_name}") self.system_metrics_class = cast( Type[DatabricksSystemMetricsComputer], self.system_metrics_class diff --git a/ingestion/tests/unit/topology/database/test_databricks.py b/ingestion/tests/unit/topology/database/test_databricks.py index 419309c3dc3..37b74eee599 100644 --- a/ingestion/tests/unit/topology/database/test_databricks.py +++ b/ingestion/tests/unit/topology/database/test_databricks.py @@ -12,9 +12,9 @@ """ Test databricks using the topology """ - +# pylint: disable=invalid-name,import-outside-toplevel from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch from metadata.generated.schema.api.data.createDatabaseSchema import ( CreateDatabaseSchemaRequest, @@ -36,7 +36,6 @@ from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.ometa.utils import model_str from metadata.ingestion.source.database.databricks.metadata import DatabricksSource -# pylint: disable=line-too-long mock_databricks_config = { "source": { "type": "databricks", @@ -514,12 +513,22 @@ class DatabricksConnectionTest(TestCase): def test_databricks_engine_wrapper_get_tables_with_cached_schema(self): """Test get_tables with cached schema""" mock_engine = Mock() + mock_connection = Mock() + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_connection + mock_context_manager.__exit__.return_value = None + mock_engine.connect.return_value = mock_context_manager + + # Mock the connection execute method to return a result + mock_result = Mock() + mock_result.fetchall.return_value = [("table1",), ("table2",)] + mock_connection.execute.return_value = mock_result + mock_inspector = Mock() mock_inspector.get_schema_names.return_value = [ "test_schema", "information_schema", ] - mock_inspector.get_table_names.return_value = ["table1", "table2"] with patch( "metadata.ingestion.source.database.databricks.connection.inspect" @@ -527,23 +536,35 @@ class DatabricksConnectionTest(TestCase): mock_inspect.return_value = mock_inspector wrapper = self.DatabricksEngineWrapper(mock_engine) + # Set the first_catalog for the wrapper + wrapper.first_catalog = "test_catalog" # First call to get_schemas to set first_schema wrapper.get_schemas() # Then call get_tables tables = wrapper.get_tables() - self.assertEqual(tables, ["table1", "table2"]) - mock_inspector.get_table_names.assert_called_once_with("test_schema") + self.assertEqual(tables, mock_result) + mock_connection.execute.assert_called_once() def test_databricks_engine_wrapper_get_tables_without_cached_schema(self): """Test get_tables without cached schema""" mock_engine = Mock() + mock_connection = Mock() + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_connection + mock_context_manager.__exit__.return_value = None + mock_engine.connect.return_value = mock_context_manager + + # Mock the connection execute method to return a result + mock_result = Mock() + mock_result.fetchall.return_value = [("table1",), ("table2",)] + mock_connection.execute.return_value = mock_result + mock_inspector = Mock() mock_inspector.get_schema_names.return_value = [ "test_schema", "information_schema", ] - mock_inspector.get_table_names.return_value = ["table1", "table2"] with patch( "metadata.ingestion.source.database.databricks.connection.inspect" @@ -551,13 +572,15 @@ class DatabricksConnectionTest(TestCase): mock_inspect.return_value = mock_inspector wrapper = self.DatabricksEngineWrapper(mock_engine) + # Set the first_catalog for the wrapper + wrapper.first_catalog = "test_catalog" # Call get_tables directly without calling get_schemas first tables = wrapper.get_tables() - self.assertEqual(tables, ["table1", "table2"]) + self.assertEqual(tables, mock_result) # Should have called get_schemas internally mock_inspector.get_schema_names.assert_called_once() - mock_inspector.get_table_names.assert_called_once_with("test_schema") + mock_connection.execute.assert_called_once() def test_databricks_engine_wrapper_get_tables_no_schemas(self): """Test get_tables when no schemas are available""" @@ -579,12 +602,22 @@ class DatabricksConnectionTest(TestCase): def test_databricks_engine_wrapper_get_views_with_cached_schema(self): """Test get_views with cached schema""" mock_engine = Mock() + mock_connection = Mock() + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_connection + mock_context_manager.__exit__.return_value = None + mock_engine.connect.return_value = mock_context_manager + + # Mock the connection execute method to return a result + mock_result = Mock() + mock_result.fetchall.return_value = [("view1",), ("view2",)] + mock_connection.execute.return_value = mock_result + mock_inspector = Mock() mock_inspector.get_schema_names.return_value = [ "test_schema", "information_schema", ] - mock_inspector.get_view_names.return_value = ["view1", "view2"] with patch( "metadata.ingestion.source.database.databricks.connection.inspect" @@ -592,23 +625,35 @@ class DatabricksConnectionTest(TestCase): mock_inspect.return_value = mock_inspector wrapper = self.DatabricksEngineWrapper(mock_engine) + # Set the first_catalog for the wrapper + wrapper.first_catalog = "test_catalog" # First call to get_schemas to set first_schema wrapper.get_schemas() # Then call get_views views = wrapper.get_views() - self.assertEqual(views, ["view1", "view2"]) - mock_inspector.get_view_names.assert_called_once_with("test_schema") + self.assertEqual(views, mock_result) + mock_connection.execute.assert_called_once() def test_databricks_engine_wrapper_get_views_without_cached_schema(self): """Test get_views without cached schema""" mock_engine = Mock() + mock_connection = Mock() + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_connection + mock_context_manager.__exit__.return_value = None + mock_engine.connect.return_value = mock_context_manager + + # Mock the connection execute method to return a result + mock_result = Mock() + mock_result.fetchall.return_value = [("view1",), ("view2",)] + mock_connection.execute.return_value = mock_result + mock_inspector = Mock() mock_inspector.get_schema_names.return_value = [ "test_schema", "information_schema", ] - mock_inspector.get_view_names.return_value = ["view1", "view2"] with patch( "metadata.ingestion.source.database.databricks.connection.inspect" @@ -616,13 +661,15 @@ class DatabricksConnectionTest(TestCase): mock_inspect.return_value = mock_inspector wrapper = self.DatabricksEngineWrapper(mock_engine) + # Set the first_catalog for the wrapper + wrapper.first_catalog = "test_catalog" # Call get_views directly without calling get_schemas first views = wrapper.get_views() - self.assertEqual(views, ["view1", "view2"]) + self.assertEqual(views, mock_result) # Should have called get_schemas internally mock_inspector.get_schema_names.assert_called_once() - mock_inspector.get_view_names.assert_called_once_with("test_schema") + mock_connection.execute.assert_called_once() def test_databricks_engine_wrapper_get_views_no_schemas(self): """Test get_views when no schemas are available""" @@ -669,6 +716,7 @@ class DatabricksConnectionTest(TestCase): # get_schema_names should only be called once due to caching mock_inspector.get_schema_names.assert_called_once() + # pylint: disable=too-many-locals @patch( "metadata.ingestion.source.database.databricks.connection.test_connection_steps" )