MINOR: Improve Databricks Profiler & Test Connection (#22732)

Mayur Singal 2025-08-06 00:41:11 +05:30 committed by GitHub
parent 55c82ec8ca
commit 00b6da5b84
5 changed files with 117 additions and 35 deletions

View File

@@ -56,11 +56,17 @@ class DatabricksEngineWrapper:
self.inspector = inspect(engine)
self.schemas = None
self.first_schema = None
self.first_catalog = None
def get_schemas(self):
def get_schemas(self, schema_name: Optional[str] = None):
"""Get schemas and cache them"""
if schema_name is not None:
with self.engine.connect() as connection:
connection.execute(f"USE CATALOG `{self.first_catalog}`")
self.first_schema = schema_name
return [schema_name]
if self.schemas is None:
self.schemas = self.inspector.get_schema_names()
self.schemas = self.inspector.get_schema_names(database=self.first_catalog)
if self.schemas:
# Find the first schema that's not a system schema
for schema in self.schemas:
@@ -81,7 +87,11 @@ class DatabricksEngineWrapper:
if self.first_schema is None:
self.get_schemas() # This will set first_schema
if self.first_schema:
return self.inspector.get_table_names(self.first_schema)
with self.engine.connect() as connection:
tables = connection.execute(
f"SHOW TABLES IN `{self.first_catalog}`.`{self.first_schema}`"
)
return tables
return []
def get_views(self):
@@ -89,9 +99,27 @@ class DatabricksEngineWrapper:
if self.first_schema is None:
self.get_schemas() # This will set first_schema
if self.first_schema:
return self.inspector.get_view_names(self.first_schema)
with self.engine.connect() as connection:
views = connection.execute(
f"SHOW VIEWS IN `{self.first_catalog}`.`{self.first_schema}`"
)
return views
return []
def get_catalogs(self, catalog_name: Optional[str] = None):
"""Get catalogs"""
catalogs = []
if catalog_name is not None:
self.first_catalog = catalog_name
return [catalog_name]
with self.engine.connect() as connection:
catalogs = connection.execute(DATABRICKS_GET_CATALOGS).fetchall()
for catalog in catalogs:
if catalog[0] != "__databricks_internal":
self.first_catalog = catalog[0]
break
return catalogs
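
For orientation, the wrapper above now lets the connection test pin a specific catalog and schema instead of always probing the first non-system one, and it caches whatever it settles on. Below is a minimal, self-contained sketch of that override-and-cache pattern; FakeEngine, FakeConnection, the SHOW CATALOGS text, and the CatalogAwareWrapper name are stand-ins invented for illustration, not the project's real objects.

# Minimal sketch of the catalog/schema override pattern, runnable on its own.
from typing import List, Optional


class FakeConnection:
    """Stand-in for a SQLAlchemy connection; prints instead of querying."""

    def execute(self, statement: str) -> List[tuple]:
        print(f"would run: {statement}")
        return [("__databricks_internal",), ("main",)]

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return None


class FakeEngine:
    """Stand-in for a SQLAlchemy engine."""

    def connect(self) -> FakeConnection:
        return FakeConnection()


class CatalogAwareWrapper:
    """Caches the first usable catalog/schema, honoring user overrides."""

    def __init__(self, engine: FakeEngine):
        self.engine = engine
        self.first_catalog: Optional[str] = None
        self.first_schema: Optional[str] = None

    def get_catalogs(self, catalog_name: Optional[str] = None) -> List:
        # A pinned catalog skips discovery entirely.
        if catalog_name is not None:
            self.first_catalog = catalog_name
            return [catalog_name]
        with self.engine.connect() as connection:
            catalogs = connection.execute("SHOW CATALOGS")
        # Otherwise cache the first catalog that is not the internal one.
        for catalog in catalogs:
            if catalog[0] != "__databricks_internal":
                self.first_catalog = catalog[0]
                break
        return catalogs

    def get_schemas(self, schema_name: Optional[str] = None) -> List[str]:
        # A pinned schema is used as-is after switching to the cached catalog.
        if schema_name is not None:
            with self.engine.connect() as connection:
                connection.execute(f"USE CATALOG `{self.first_catalog}`")
            self.first_schema = schema_name
            return [schema_name]
        return []  # discovery path omitted in this sketch


wrapper = CatalogAwareWrapper(FakeEngine())
print(wrapper.get_catalogs(catalog_name="sales"))  # ['sales']
print(wrapper.get_schemas(schema_name="gold"))     # ['gold']

Calling get_catalogs(catalog_name=...) or get_schemas(schema_name=...) with a user-supplied value simply records it and returns early, which mirrors how the test-connection steps can respect values configured on the service connection.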
def get_connection_url(connection: DatabricksConnection) -> str:
url = f"{connection.scheme.value}://token:{connection.token.get_secret_value()}@{connection.hostPort}"
@@ -144,13 +172,13 @@ def test_connection(
test_fn = {
"CheckAccess": partial(test_connection_engine_step, connection),
"GetSchemas": engine_wrapper.get_schemas,
"GetSchemas": partial(
engine_wrapper.get_schemas, schema_name=service_connection.databaseSchema
),
"GetTables": engine_wrapper.get_tables,
"GetViews": engine_wrapper.get_views,
"GetDatabases": partial(
test_database_query,
engine=connection,
statement=DATABRICKS_GET_CATALOGS,
engine_wrapper.get_catalogs, catalog_name=service_connection.catalog
),
"GetQueries": partial(
test_database_query,

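For context, the step map now binds the optional catalog and schema overrides into the wrapper methods with functools.partial, so GetDatabases exercises the same wrapper as GetSchemas rather than running a raw query. A rough sketch of that binding pattern, with FakeWrapper and the configured_* variables standing in for the real engine wrapper and service-connection fields:

# Sketch: binding optional user overrides into zero-argument test steps.
from functools import partial
from typing import Optional


class FakeWrapper:
    def get_schemas(self, schema_name: Optional[str] = None):
        return [schema_name or "default"]

    def get_catalogs(self, catalog_name: Optional[str] = None):
        return [catalog_name or "hive_metastore"]


wrapper = FakeWrapper()
configured_schema = "gold"     # e.g. service_connection.databaseSchema
configured_catalog = "sales"   # e.g. service_connection.catalog

test_fn = {
    # Each step must be callable with no arguments; partial pre-binds the overrides.
    "GetSchemas": partial(wrapper.get_schemas, schema_name=configured_schema),
    "GetDatabases": partial(wrapper.get_catalogs, catalog_name=configured_catalog),
}

for step_name, step in test_fn.items():
    print(step_name, "->", step())
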
View File

@@ -111,16 +111,15 @@ def test_connection(
def get_catalogs(connection: WorkspaceClient, table_obj: DatabricksTable):
for catalog in connection.catalogs.list():
table_obj.catalog_name = catalog.name
break
if catalog.name != "__databricks_internal":
table_obj.catalog_name = catalog.name
return
def get_schemas(connection: WorkspaceClient, table_obj: DatabricksTable):
for catalog in connection.catalogs.list():
for schema in connection.schemas.list(catalog_name=catalog.name):
if schema.name:
table_obj.schema_name = schema.name
table_obj.catalog_name = catalog.name
return
for schema in connection.schemas.list(catalog_name=table_obj.catalog_name):
if schema.name:
table_obj.schema_name = schema.name
return
def get_tables(connection: WorkspaceClient, table_obj: DatabricksTable):
if table_obj.catalog_name and table_obj.schema_name:

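The SDK-based checks above now skip the __databricks_internal catalog and reuse the catalog cached by get_catalogs when listing schemas, instead of re-listing every catalog. A hedged, runnable sketch of that flow; FakeWorkspaceClient and TableHolder are in-memory stand-ins for the Databricks WorkspaceClient and the DatabricksTable holder, so only the control flow mirrors the diff:

# Sketch: pick the first non-internal catalog, then list schemas inside it.
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional


@dataclass
class TableHolder:
    catalog_name: Optional[str] = None
    schema_name: Optional[str] = None


class FakeWorkspaceClient:
    def __init__(self):
        self.catalogs = SimpleNamespace(
            list=lambda: [
                SimpleNamespace(name="__databricks_internal"),
                SimpleNamespace(name="main"),
            ]
        )
        self.schemas = SimpleNamespace(
            list=lambda catalog_name: [SimpleNamespace(name="bronze")]
        )


def get_catalogs(client, table_obj):
    # Cache the first catalog that is not the Databricks-internal one.
    for catalog in client.catalogs.list():
        if catalog.name != "__databricks_internal":
            table_obj.catalog_name = catalog.name
            return


def get_schemas(client, table_obj):
    # Reuse the catalog cached above instead of re-listing every catalog.
    for schema in client.schemas.list(catalog_name=table_obj.catalog_name):
        if schema.name:
            table_obj.schema_name = schema.name
            return


holder = TableHolder()
get_catalogs(FakeWorkspaceClient(), holder)  # holder.catalog_name == "main"
get_schemas(FakeWorkspaceClient(), holder)   # holder.schema_name == "bronze"
print(holder)
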
View File

@@ -92,7 +92,7 @@ logger = ingestion_logger()
UNITY_CATALOG_TAG = "UNITY CATALOG TAG"
UNITY_CATALOG_TAG_CLASSIFICATION = "UNITY CATALOG TAG CLASSIFICATION"
# pylint: disable=protected-access
class UnitycatalogSource(
ExternalTableLineageMixin, DatabaseServiceSource, MultiDBSource
):
@@ -306,6 +306,8 @@ class UnitycatalogSource(
if table.table_type:
if table.table_type.value.lower() == TableType.View.value.lower():
table_type: TableType = TableType.View
if table.table_type.value.lower() == "materialized_view":
table_type: TableType = TableType.MaterializedView
elif (
table.table_type.value.lower()
== TableType.External.value.lower()
@@ -427,7 +429,6 @@ class UnitycatalogSource(
)
if referred_table_fqn:
for parent_column in column.parent_columns:
# pylint: disable=protected-access
col_fqn = fqn._build(referred_table_fqn, parent_column, quote=False)
if col_fqn:
referred_column_fqns.append(FullyQualifiedEntityName(col_fqn))
@@ -563,7 +564,7 @@ class UnitycatalogSource(
yield from get_ometa_tag_and_classification(
tag_fqn=FullyQualifiedEntityName(
fqn._build(*tag_fqn_builder(tag))
), # pylint: disable=protected-access
),
tags=[tag.tag_value],
classification_name=tag.tag_name,
tag_description=UNITY_CATALOG_TAG,
@@ -617,7 +618,7 @@ class UnitycatalogSource(
yield from get_ometa_tag_and_classification(
tag_fqn=FullyQualifiedEntityName(
fqn._build(*tag_fqn_builder(tag))
), # pylint: disable=protected-access
),
tags=[tag.tag_value],
classification_name=tag.tag_name,
tag_description=UNITY_CATALOG_TAG,

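Earlier in this file, the new branch maps Unity Catalog's MATERIALIZED_VIEW table type onto TableType.MaterializedView. A compact sketch of the same string-based mapping, written as a lookup table rather than an if/elif chain; the TableType enum below is a local stand-in for the generated OpenMetadata enum, and the fallback to Regular is an assumption made for the example:

# Sketch: map Unity Catalog table-type strings to an OpenMetadata-style enum.
from enum import Enum


class TableType(Enum):
    Regular = "Regular"
    View = "View"
    MaterializedView = "MaterializedView"
    External = "External"


# Lower-cased source value -> table type.
_TABLE_TYPE_MAP = {
    "view": TableType.View,
    "materialized_view": TableType.MaterializedView,
    "external": TableType.External,
}


def map_table_type(raw_value: str) -> TableType:
    """Fall back to Regular when the source reports an unknown type."""
    return _TABLE_TYPE_MAP.get(raw_value.lower(), TableType.Regular)


print(map_table_type("MATERIALIZED_VIEW"))  # TableType.MaterializedView
print(map_table_type("MANAGED"))            # TableType.Regular
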
View File

@@ -23,6 +23,7 @@ from metadata.generated.schema.entity.data.table import (
ColumnName,
DataType,
SystemProfile,
TableType,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
@@ -51,6 +52,11 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
*args,
**kwargs,
) -> List[SystemProfile]:
if self.table_entity.tableType in (TableType.View, TableType.MaterializedView):
logger.debug(
f"Skipping {metrics.name()} metric for view {runner.table_name}"
)
return []
logger.debug(f"Computing {metrics.name()} metric for {runner.table_name}")
self.system_metrics_class = cast(
Type[DatabricksSystemMetricsComputer], self.system_metrics_class

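The profiler change above short-circuits system-metric computation when the profiled entity is a view or materialized view, presumably because view-like tables have no DML history to read. A self-contained sketch of that guard; TableEntity, Runner, and the returned dictionaries are simplified stand-ins, and only the early-return shape mirrors the diff:

# Sketch: skip system-profile computation for view-like tables.
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List

logger = logging.getLogger(__name__)


class TableType(Enum):
    Regular = "Regular"
    View = "View"
    MaterializedView = "MaterializedView"


@dataclass
class TableEntity:
    tableType: TableType


@dataclass
class Runner:
    table_name: str


def compute_system_metrics(entity: TableEntity, runner: Runner) -> List[Dict]:
    # Views have no underlying operation history, so return an empty profile early.
    if entity.tableType in (TableType.View, TableType.MaterializedView):
        logger.debug("Skipping system metrics for view %s", runner.table_name)
        return []
    logger.debug("Computing system metrics for %s", runner.table_name)
    return [{"table": runner.table_name, "operation": "INSERT", "rows": 0}]


print(compute_system_metrics(TableEntity(TableType.View), Runner("dim_users")))
print(compute_system_metrics(TableEntity(TableType.Regular), Runner("fct_sales")))
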
View File

@@ -12,9 +12,9 @@
"""
Test databricks using the topology
"""
# pylint: disable=invalid-name,import-outside-toplevel
from unittest import TestCase
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
@@ -36,7 +36,6 @@ from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.utils import model_str
from metadata.ingestion.source.database.databricks.metadata import DatabricksSource
# pylint: disable=line-too-long
mock_databricks_config = {
"source": {
"type": "databricks",
@@ -514,12 +513,22 @@ class DatabricksConnectionTest(TestCase):
def test_databricks_engine_wrapper_get_tables_with_cached_schema(self):
"""Test get_tables with cached schema"""
mock_engine = Mock()
mock_connection = Mock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_connection
mock_context_manager.__exit__.return_value = None
mock_engine.connect.return_value = mock_context_manager
# Mock the connection execute method to return a result
mock_result = Mock()
mock_result.fetchall.return_value = [("table1",), ("table2",)]
mock_connection.execute.return_value = mock_result
mock_inspector = Mock()
mock_inspector.get_schema_names.return_value = [
"test_schema",
"information_schema",
]
mock_inspector.get_table_names.return_value = ["table1", "table2"]
with patch(
"metadata.ingestion.source.database.databricks.connection.inspect"
@@ -527,23 +536,35 @@ class DatabricksConnectionTest(TestCase):
mock_inspect.return_value = mock_inspector
wrapper = self.DatabricksEngineWrapper(mock_engine)
# Set the first_catalog for the wrapper
wrapper.first_catalog = "test_catalog"
# First call to get_schemas to set first_schema
wrapper.get_schemas()
# Then call get_tables
tables = wrapper.get_tables()
self.assertEqual(tables, ["table1", "table2"])
mock_inspector.get_table_names.assert_called_once_with("test_schema")
self.assertEqual(tables, mock_result)
mock_connection.execute.assert_called_once()
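
These tests switch to MagicMock because engine.connect() is now used as a context manager (with engine.connect() as connection:), and a plain Mock does not support the dunder methods the with statement needs. A small runnable sketch of the pattern, independent of the Databricks wrapper; fetch_tables is a hypothetical helper that only mimics the shape of the wrapper code:

# Sketch: mocking an object used as a context manager (engine.connect()).
from unittest.mock import MagicMock, Mock


def fetch_tables(engine, catalog, schema):
    # Mirrors the wrapper's shape: run a SHOW TABLES statement inside a
    # `with engine.connect()` block and return the raw result object.
    with engine.connect() as connection:
        return connection.execute(f"SHOW TABLES IN `{catalog}`.`{schema}`")


mock_engine = Mock()
mock_connection = Mock()
mock_result = Mock()
mock_result.fetchall.return_value = [("table1",), ("table2",)]
mock_connection.execute.return_value = mock_result

# The context manager itself must be a MagicMock so __enter__/__exit__ exist.
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_connection
mock_context.__exit__.return_value = None
mock_engine.connect.return_value = mock_context

result = fetch_tables(mock_engine, "test_catalog", "test_schema")
assert result is mock_result
mock_connection.execute.assert_called_once()
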
def test_databricks_engine_wrapper_get_tables_without_cached_schema(self):
"""Test get_tables without cached schema"""
mock_engine = Mock()
mock_connection = Mock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_connection
mock_context_manager.__exit__.return_value = None
mock_engine.connect.return_value = mock_context_manager
# Mock the connection execute method to return a result
mock_result = Mock()
mock_result.fetchall.return_value = [("table1",), ("table2",)]
mock_connection.execute.return_value = mock_result
mock_inspector = Mock()
mock_inspector.get_schema_names.return_value = [
"test_schema",
"information_schema",
]
mock_inspector.get_table_names.return_value = ["table1", "table2"]
with patch(
"metadata.ingestion.source.database.databricks.connection.inspect"
@@ -551,13 +572,15 @@ class DatabricksConnectionTest(TestCase):
mock_inspect.return_value = mock_inspector
wrapper = self.DatabricksEngineWrapper(mock_engine)
# Set the first_catalog for the wrapper
wrapper.first_catalog = "test_catalog"
# Call get_tables directly without calling get_schemas first
tables = wrapper.get_tables()
self.assertEqual(tables, ["table1", "table2"])
self.assertEqual(tables, mock_result)
# Should have called get_schemas internally
mock_inspector.get_schema_names.assert_called_once()
mock_inspector.get_table_names.assert_called_once_with("test_schema")
mock_connection.execute.assert_called_once()
def test_databricks_engine_wrapper_get_tables_no_schemas(self):
"""Test get_tables when no schemas are available"""
@@ -579,12 +602,22 @@ class DatabricksConnectionTest(TestCase):
def test_databricks_engine_wrapper_get_views_with_cached_schema(self):
"""Test get_views with cached schema"""
mock_engine = Mock()
mock_connection = Mock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_connection
mock_context_manager.__exit__.return_value = None
mock_engine.connect.return_value = mock_context_manager
# Mock the connection execute method to return a result
mock_result = Mock()
mock_result.fetchall.return_value = [("view1",), ("view2",)]
mock_connection.execute.return_value = mock_result
mock_inspector = Mock()
mock_inspector.get_schema_names.return_value = [
"test_schema",
"information_schema",
]
mock_inspector.get_view_names.return_value = ["view1", "view2"]
with patch(
"metadata.ingestion.source.database.databricks.connection.inspect"
@@ -592,23 +625,35 @@ class DatabricksConnectionTest(TestCase):
mock_inspect.return_value = mock_inspector
wrapper = self.DatabricksEngineWrapper(mock_engine)
# Set the first_catalog for the wrapper
wrapper.first_catalog = "test_catalog"
# First call to get_schemas to set first_schema
wrapper.get_schemas()
# Then call get_views
views = wrapper.get_views()
self.assertEqual(views, ["view1", "view2"])
mock_inspector.get_view_names.assert_called_once_with("test_schema")
self.assertEqual(views, mock_result)
mock_connection.execute.assert_called_once()
def test_databricks_engine_wrapper_get_views_without_cached_schema(self):
"""Test get_views without cached schema"""
mock_engine = Mock()
mock_connection = Mock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_connection
mock_context_manager.__exit__.return_value = None
mock_engine.connect.return_value = mock_context_manager
# Mock the connection execute method to return a result
mock_result = Mock()
mock_result.fetchall.return_value = [("view1",), ("view2",)]
mock_connection.execute.return_value = mock_result
mock_inspector = Mock()
mock_inspector.get_schema_names.return_value = [
"test_schema",
"information_schema",
]
mock_inspector.get_view_names.return_value = ["view1", "view2"]
with patch(
"metadata.ingestion.source.database.databricks.connection.inspect"
@@ -616,13 +661,15 @@ class DatabricksConnectionTest(TestCase):
mock_inspect.return_value = mock_inspector
wrapper = self.DatabricksEngineWrapper(mock_engine)
# Set the first_catalog for the wrapper
wrapper.first_catalog = "test_catalog"
# Call get_views directly without calling get_schemas first
views = wrapper.get_views()
self.assertEqual(views, ["view1", "view2"])
self.assertEqual(views, mock_result)
# Should have called get_schemas internally
mock_inspector.get_schema_names.assert_called_once()
mock_inspector.get_view_names.assert_called_once_with("test_schema")
mock_connection.execute.assert_called_once()
def test_databricks_engine_wrapper_get_views_no_schemas(self):
"""Test get_views when no schemas are available"""
@@ -669,6 +716,7 @@ class DatabricksConnectionTest(TestCase):
# get_schema_names should only be called once due to caching
mock_inspector.get_schema_names.assert_called_once()
# pylint: disable=too-many-locals
@patch(
"metadata.ingestion.source.database.databricks.connection.test_connection_steps"
)