Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-09-29 02:45:25 +00:00)

MINOR: Improve Databricks Profiler & Test Connection (#22732)

This commit is contained in:
parent 55c82ec8ca
commit 00b6da5b84
@@ -56,11 +56,17 @@ class DatabricksEngineWrapper:
         self.inspector = inspect(engine)
         self.schemas = None
         self.first_schema = None
+        self.first_catalog = None

-    def get_schemas(self):
+    def get_schemas(self, schema_name: Optional[str] = None):
         """Get schemas and cache them"""
+        if schema_name is not None:
+            with self.engine.connect() as connection:
+                connection.execute(f"USE CATALOG `{self.first_catalog}`")
+            self.first_schema = schema_name
+            return [schema_name]
         if self.schemas is None:
-            self.schemas = self.inspector.get_schema_names()
+            self.schemas = self.inspector.get_schema_names(database=self.first_catalog)
         if self.schemas:
             # Find the first schema that's not a system schema
             for schema in self.schemas:
@@ -81,7 +87,11 @@ class DatabricksEngineWrapper:
         if self.first_schema is None:
             self.get_schemas()  # This will set first_schema
         if self.first_schema:
-            return self.inspector.get_table_names(self.first_schema)
+            with self.engine.connect() as connection:
+                tables = connection.execute(
+                    f"SHOW TABLES IN `{self.first_catalog}`.`{self.first_schema}`"
+                )
+                return tables
         return []

     def get_views(self):
@@ -89,9 +99,27 @@ class DatabricksEngineWrapper:
         if self.first_schema is None:
             self.get_schemas()  # This will set first_schema
         if self.first_schema:
-            return self.inspector.get_view_names(self.first_schema)
+            with self.engine.connect() as connection:
+                views = connection.execute(
+                    f"SHOW VIEWS IN `{self.first_catalog}`.`{self.first_schema}`"
+                )
+                return views
         return []

+    def get_catalogs(self, catalog_name: Optional[str] = None):
+        """Get catalogs"""
+        catalogs = []
+        if catalog_name is not None:
+            self.first_catalog = catalog_name
+            return [catalog_name]
+        with self.engine.connect() as connection:
+            catalogs = connection.execute(DATABRICKS_GET_CATALOGS).fetchall()
+            for catalog in catalogs:
+                if catalog[0] != "__databricks_internal":
+                    self.first_catalog = catalog[0]
+                    break
+        return catalogs
+

 def get_connection_url(connection: DatabricksConnection) -> str:
     url = f"{connection.scheme.value}://token:{connection.token.get_secret_value()}@{connection.hostPort}"
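A minimal, runnable sketch (not part of the commit) of the behaviour the reworked get_schemas aims for: when a schema name is supplied, the wrapper switches the session to the cached catalog and returns just that schema instead of inspecting the engine. DummyEngine and DummyConnection are hypothetical stand-ins for a SQLAlchemy engine and connection.

# Hypothetical stand-ins used only to illustrate the short-circuit path above.
from contextlib import contextmanager
from typing import Optional


class DummyConnection:
    def execute(self, statement: str) -> None:
        print(f"executed: {statement}")


class DummyEngine:
    @contextmanager
    def connect(self):
        yield DummyConnection()


class WrapperSketch:
    def __init__(self, engine) -> None:
        self.engine = engine
        self.first_schema: Optional[str] = None
        self.first_catalog = "my_catalog"  # assumed to be set by get_catalogs()

    def get_schemas(self, schema_name: Optional[str] = None):
        if schema_name is not None:
            with self.engine.connect() as connection:
                connection.execute(f"USE CATALOG `{self.first_catalog}`")
            self.first_schema = schema_name
            return [schema_name]
        return []  # full inspector-based discovery omitted in this sketch


print(WrapperSketch(DummyEngine()).get_schemas(schema_name="sales"))  # ['sales']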
@@ -144,13 +172,13 @@ def test_connection(

     test_fn = {
         "CheckAccess": partial(test_connection_engine_step, connection),
-        "GetSchemas": engine_wrapper.get_schemas,
+        "GetSchemas": partial(
+            engine_wrapper.get_schemas, schema_name=service_connection.databaseSchema
+        ),
         "GetTables": engine_wrapper.get_tables,
         "GetViews": engine_wrapper.get_views,
         "GetDatabases": partial(
-            test_database_query,
-            engine=connection,
-            statement=DATABRICKS_GET_CATALOGS,
+            engine_wrapper.get_catalogs, catalog_name=service_connection.catalog
         ),
         "GetQueries": partial(
             test_database_query,
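The test-connection map above now defers the wrapper calls through functools.partial, so the user-configured databaseSchema and catalog are bound up front but only evaluated when the step actually runs. A minimal illustration of that pattern (the get_schemas stub below is hypothetical, not the real wrapper):

from functools import partial


def get_schemas(schema_name=None):  # hypothetical stand-in for the wrapper method
    return [schema_name] if schema_name else ["default"]


# Bind the configured value now, execute later when the step runs:
step = partial(get_schemas, schema_name="analytics")
print(step())  # ['analytics']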
@@ -111,16 +111,15 @@ def test_connection(

 def get_catalogs(connection: WorkspaceClient, table_obj: DatabricksTable):
     for catalog in connection.catalogs.list():
-        table_obj.catalog_name = catalog.name
-        break
+        if catalog.name != "__databricks_internal":
+            table_obj.catalog_name = catalog.name
+            return

 def get_schemas(connection: WorkspaceClient, table_obj: DatabricksTable):
-    for catalog in connection.catalogs.list():
-        for schema in connection.schemas.list(catalog_name=catalog.name):
-            if schema.name:
-                table_obj.schema_name = schema.name
-                table_obj.catalog_name = catalog.name
-                return
+    for schema in connection.schemas.list(catalog_name=table_obj.catalog_name):
+        if schema.name:
+            table_obj.schema_name = schema.name
+            return

 def get_tables(connection: WorkspaceClient, table_obj: DatabricksTable):
     if table_obj.catalog_name and table_obj.schema_name:
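A hedged sketch of the revised helper flow above: get_catalogs now skips the internal __databricks_internal catalog, and get_schemas reuses the catalog already stored on the table object instead of re-listing every catalog. FakeWorkspaceClient is a hypothetical stub; only the attributes the diff itself uses (catalogs.list(), schemas.list(catalog_name=...), .name) are assumed.

from types import SimpleNamespace


class FakeWorkspaceClient:
    """Hypothetical stub exposing only the two SDK calls used above."""

    catalogs = SimpleNamespace(
        list=lambda: [
            SimpleNamespace(name="__databricks_internal"),
            SimpleNamespace(name="main"),
        ]
    )
    schemas = SimpleNamespace(
        list=lambda catalog_name: [SimpleNamespace(name="bronze")]
    )


table_obj = SimpleNamespace(catalog_name=None, schema_name=None)
client = FakeWorkspaceClient()

# Loosely mirrors get_catalogs: the first non-internal catalog wins
for catalog in client.catalogs.list():
    if catalog.name != "__databricks_internal":
        table_obj.catalog_name = catalog.name
        break

# Loosely mirrors get_schemas: reuse the cached catalog, no full rescan
for schema in client.schemas.list(catalog_name=table_obj.catalog_name):
    if schema.name:
        table_obj.schema_name = schema.name
        break

print(table_obj.catalog_name, table_obj.schema_name)  # main bronze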
@@ -92,7 +92,7 @@ logger = ingestion_logger()
 UNITY_CATALOG_TAG = "UNITY CATALOG TAG"
 UNITY_CATALOG_TAG_CLASSIFICATION = "UNITY CATALOG TAG CLASSIFICATION"


+# pylint: disable=protected-access
 class UnitycatalogSource(
     ExternalTableLineageMixin, DatabaseServiceSource, MultiDBSource
 ):
@@ -306,6 +306,8 @@ class UnitycatalogSource(
             if table.table_type:
                 if table.table_type.value.lower() == TableType.View.value.lower():
                     table_type: TableType = TableType.View
+                if table.table_type.value.lower() == "materialized_view":
+                    table_type: TableType = TableType.MaterializedView
                 elif (
                     table.table_type.value.lower()
                     == TableType.External.value.lower()
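The hunk above maps Unity Catalog's materialized-view table type onto OpenMetadata's TableType.MaterializedView. A small illustrative mapping, assuming the openmetadata-ingestion package is installed; map_table_type is a hypothetical helper, not the source's actual method:

# Illustrative only: the real logic lives inline in UnitycatalogSource above.
from metadata.generated.schema.entity.data.table import TableType


def map_table_type(table_type_value: str) -> TableType:  # hypothetical helper
    value = table_type_value.lower()
    if value == TableType.View.value.lower():
        return TableType.View
    if value == "materialized_view":
        return TableType.MaterializedView
    if value == TableType.External.value.lower():
        return TableType.External
    return TableType.Regular


print(map_table_type("MATERIALIZED_VIEW"))  # TableType.MaterializedView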
@@ -427,7 +429,6 @@ class UnitycatalogSource(
             )
             if referred_table_fqn:
                 for parent_column in column.parent_columns:
-                    # pylint: disable=protected-access
                     col_fqn = fqn._build(referred_table_fqn, parent_column, quote=False)
                     if col_fqn:
                         referred_column_fqns.append(FullyQualifiedEntityName(col_fqn))
@@ -563,7 +564,7 @@ class UnitycatalogSource(
             yield from get_ometa_tag_and_classification(
                 tag_fqn=FullyQualifiedEntityName(
                     fqn._build(*tag_fqn_builder(tag))
-                ),  # pylint: disable=protected-access
+                ),
                 tags=[tag.tag_value],
                 classification_name=tag.tag_name,
                 tag_description=UNITY_CATALOG_TAG,
@@ -617,7 +618,7 @@ class UnitycatalogSource(
             yield from get_ometa_tag_and_classification(
                 tag_fqn=FullyQualifiedEntityName(
                     fqn._build(*tag_fqn_builder(tag))
-                ),  # pylint: disable=protected-access
+                ),
                 tags=[tag.tag_value],
                 classification_name=tag.tag_name,
                 tag_description=UNITY_CATALOG_TAG,
@@ -23,6 +23,7 @@ from metadata.generated.schema.entity.data.table import (
     ColumnName,
     DataType,
     SystemProfile,
+    TableType,
 )
 from metadata.generated.schema.entity.services.databaseService import (
     DatabaseServiceType,
@@ -51,6 +52,11 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
         *args,
         **kwargs,
     ) -> List[SystemProfile]:
+        if self.table_entity.tableType in (TableType.View, TableType.MaterializedView):
+            logger.debug(
+                f"Skipping {metrics.name()} metric for view {runner.table_name}"
+            )
+            return []
         logger.debug(f"Computing {metrics.name()} metric for {runner.table_name}")
         self.system_metrics_class = cast(
             Type[DatabricksSystemMetricsComputer], self.system_metrics_class
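The profiler hunk above short-circuits system (DML) metrics for views and materialized views, presumably because insert/update/delete history is not meaningful for view-like tables. A self-contained sketch of that guard; TableType here is a local stand-in for the generated enum and the returned profile is a placeholder:

from enum import Enum
from typing import List


class TableType(Enum):
    Regular = "Regular"
    View = "View"
    MaterializedView = "MaterializedView"


def system_profile(table_type: TableType) -> List[dict]:
    if table_type in (TableType.View, TableType.MaterializedView):
        return []  # nothing to compute for view-like tables
    return [{"operation": "INSERT", "rowsAffected": 42}]  # placeholder result


print(system_profile(TableType.MaterializedView))  # []
print(system_profile(TableType.Regular))           # placeholder profile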
@@ -12,9 +12,9 @@
 """
 Test databricks using the topology
 """

 # pylint: disable=invalid-name,import-outside-toplevel
 from unittest import TestCase
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 from metadata.generated.schema.api.data.createDatabaseSchema import (
     CreateDatabaseSchemaRequest,
@@ -36,7 +36,6 @@ from metadata.generated.schema.type.entityReference import EntityReference
 from metadata.ingestion.ometa.utils import model_str
 from metadata.ingestion.source.database.databricks.metadata import DatabricksSource

-# pylint: disable=line-too-long
 mock_databricks_config = {
     "source": {
         "type": "databricks",
@@ -514,12 +513,22 @@ class DatabricksConnectionTest(TestCase):
     def test_databricks_engine_wrapper_get_tables_with_cached_schema(self):
         """Test get_tables with cached schema"""
         mock_engine = Mock()
+        mock_connection = Mock()
+        mock_context_manager = MagicMock()
+        mock_context_manager.__enter__.return_value = mock_connection
+        mock_context_manager.__exit__.return_value = None
+        mock_engine.connect.return_value = mock_context_manager
+
+        # Mock the connection execute method to return a result
+        mock_result = Mock()
+        mock_result.fetchall.return_value = [("table1",), ("table2",)]
+        mock_connection.execute.return_value = mock_result

         mock_inspector = Mock()
         mock_inspector.get_schema_names.return_value = [
             "test_schema",
             "information_schema",
         ]
         mock_inspector.get_table_names.return_value = ["table1", "table2"]

         with patch(
             "metadata.ingestion.source.database.databricks.connection.inspect"
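The updated tests stub engine.connect() as a context manager, which is why MagicMock is now imported: unlike a plain Mock, it supports the __enter__/__exit__ dunder protocol. The same stubbing pattern in isolation:

from unittest.mock import MagicMock, Mock

mock_engine = Mock()
mock_connection = Mock()

# MagicMock can stand in for the object returned by engine.connect()
# when it is used in a "with" statement.
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_connection
mock_context_manager.__exit__.return_value = None
mock_engine.connect.return_value = mock_context_manager

with mock_engine.connect() as connection:
    connection.execute("SHOW TABLES IN `cat`.`schema`")

mock_connection.execute.assert_called_once()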
@@ -527,23 +536,35 @@ class DatabricksConnectionTest(TestCase):
             mock_inspect.return_value = mock_inspector

             wrapper = self.DatabricksEngineWrapper(mock_engine)
+            # Set the first_catalog for the wrapper
+            wrapper.first_catalog = "test_catalog"
             # First call to get_schemas to set first_schema
             wrapper.get_schemas()
             # Then call get_tables
             tables = wrapper.get_tables()

-            self.assertEqual(tables, ["table1", "table2"])
-            mock_inspector.get_table_names.assert_called_once_with("test_schema")
+            self.assertEqual(tables, mock_result)
+            mock_connection.execute.assert_called_once()

     def test_databricks_engine_wrapper_get_tables_without_cached_schema(self):
         """Test get_tables without cached schema"""
         mock_engine = Mock()
+        mock_connection = Mock()
+        mock_context_manager = MagicMock()
+        mock_context_manager.__enter__.return_value = mock_connection
+        mock_context_manager.__exit__.return_value = None
+        mock_engine.connect.return_value = mock_context_manager
+
+        # Mock the connection execute method to return a result
+        mock_result = Mock()
+        mock_result.fetchall.return_value = [("table1",), ("table2",)]
+        mock_connection.execute.return_value = mock_result

         mock_inspector = Mock()
         mock_inspector.get_schema_names.return_value = [
             "test_schema",
             "information_schema",
         ]
         mock_inspector.get_table_names.return_value = ["table1", "table2"]

         with patch(
             "metadata.ingestion.source.database.databricks.connection.inspect"
@@ -551,13 +572,15 @@ class DatabricksConnectionTest(TestCase):
             mock_inspect.return_value = mock_inspector

             wrapper = self.DatabricksEngineWrapper(mock_engine)
+            # Set the first_catalog for the wrapper
+            wrapper.first_catalog = "test_catalog"
             # Call get_tables directly without calling get_schemas first
             tables = wrapper.get_tables()

-            self.assertEqual(tables, ["table1", "table2"])
+            self.assertEqual(tables, mock_result)
             # Should have called get_schemas internally
             mock_inspector.get_schema_names.assert_called_once()
-            mock_inspector.get_table_names.assert_called_once_with("test_schema")
+            mock_connection.execute.assert_called_once()

     def test_databricks_engine_wrapper_get_tables_no_schemas(self):
         """Test get_tables when no schemas are available"""
@@ -579,12 +602,22 @@ class DatabricksConnectionTest(TestCase):
     def test_databricks_engine_wrapper_get_views_with_cached_schema(self):
         """Test get_views with cached schema"""
         mock_engine = Mock()
+        mock_connection = Mock()
+        mock_context_manager = MagicMock()
+        mock_context_manager.__enter__.return_value = mock_connection
+        mock_context_manager.__exit__.return_value = None
+        mock_engine.connect.return_value = mock_context_manager
+
+        # Mock the connection execute method to return a result
+        mock_result = Mock()
+        mock_result.fetchall.return_value = [("view1",), ("view2",)]
+        mock_connection.execute.return_value = mock_result

         mock_inspector = Mock()
         mock_inspector.get_schema_names.return_value = [
             "test_schema",
             "information_schema",
         ]
         mock_inspector.get_view_names.return_value = ["view1", "view2"]

         with patch(
             "metadata.ingestion.source.database.databricks.connection.inspect"
@@ -592,23 +625,35 @@ class DatabricksConnectionTest(TestCase):
             mock_inspect.return_value = mock_inspector

             wrapper = self.DatabricksEngineWrapper(mock_engine)
+            # Set the first_catalog for the wrapper
+            wrapper.first_catalog = "test_catalog"
             # First call to get_schemas to set first_schema
             wrapper.get_schemas()
             # Then call get_views
             views = wrapper.get_views()

-            self.assertEqual(views, ["view1", "view2"])
-            mock_inspector.get_view_names.assert_called_once_with("test_schema")
+            self.assertEqual(views, mock_result)
+            mock_connection.execute.assert_called_once()

     def test_databricks_engine_wrapper_get_views_without_cached_schema(self):
         """Test get_views without cached schema"""
         mock_engine = Mock()
+        mock_connection = Mock()
+        mock_context_manager = MagicMock()
+        mock_context_manager.__enter__.return_value = mock_connection
+        mock_context_manager.__exit__.return_value = None
+        mock_engine.connect.return_value = mock_context_manager
+
+        # Mock the connection execute method to return a result
+        mock_result = Mock()
+        mock_result.fetchall.return_value = [("view1",), ("view2",)]
+        mock_connection.execute.return_value = mock_result

         mock_inspector = Mock()
         mock_inspector.get_schema_names.return_value = [
             "test_schema",
             "information_schema",
         ]
         mock_inspector.get_view_names.return_value = ["view1", "view2"]

         with patch(
             "metadata.ingestion.source.database.databricks.connection.inspect"
@@ -616,13 +661,15 @@ class DatabricksConnectionTest(TestCase):
             mock_inspect.return_value = mock_inspector

             wrapper = self.DatabricksEngineWrapper(mock_engine)
+            # Set the first_catalog for the wrapper
+            wrapper.first_catalog = "test_catalog"
             # Call get_views directly without calling get_schemas first
             views = wrapper.get_views()

-            self.assertEqual(views, ["view1", "view2"])
+            self.assertEqual(views, mock_result)
             # Should have called get_schemas internally
             mock_inspector.get_schema_names.assert_called_once()
-            mock_inspector.get_view_names.assert_called_once_with("test_schema")
+            mock_connection.execute.assert_called_once()

     def test_databricks_engine_wrapper_get_views_no_schemas(self):
         """Test get_views when no schemas are available"""
@@ -669,6 +716,7 @@ class DatabricksConnectionTest(TestCase):
             # get_schema_names should only be called once due to caching
             mock_inspector.get_schema_names.assert_called_once()

+    # pylint: disable=too-many-locals
     @patch(
         "metadata.ingestion.source.database.databricks.connection.test_connection_steps"
     )