diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py
index eed50f51b1e..46563788240 100644
--- a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py
+++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py
@@ -85,10 +85,28 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
         result += "`.`".join(splitted_result)
         return result
 
+    def visit_table(self, *args, **kwargs):
+        result = super(  # pylint: disable=bad-super-call
+            HiveCompiler, self
+        ).visit_table(*args, **kwargs)
+        # Handle table references with hyphens in database/schema names
+        # Format: `database`.`schema`.`table` for Unity Catalog/Databricks
+        if "." in result and not result.startswith("`"):
+            parts = result.split(".")
+            quoted_parts = []
+            for part in parts:
+                if "-" in part and not (part.startswith("`") and part.endswith("`")):
+                    quoted_parts.append(f"`{part}`")
+                else:
+                    quoted_parts.append(part)
+            result = ".".join(quoted_parts)
+        return result
+
     def __init__(self, service_connection_config, **kwargs):
         super().__init__(service_connection_config=service_connection_config, **kwargs)
         self.set_catalog(self.session)
         HiveCompiler.visit_column = DatabricksProfilerInterface.visit_column
+        HiveCompiler.visit_table = DatabricksProfilerInterface.visit_table
 
     def _get_struct_columns(self, columns: List[OMColumn], parent: str):
         """Get struct columns"""
diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py
index d0b1c51951b..c43cb1cf622 100644
--- a/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py
+++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/profiler_interface.py
@@ -14,6 +14,8 @@ Interfaces with database for all database engine supporting
 sqlalchemy abstraction layer
 """
+from sqlalchemy import event
+from sqlalchemy.orm import scoped_session, sessionmaker
 
 from metadata.ingestion.source.database.databricks.connection import (
     get_connection as databricks_get_connection,
 )
@@ -26,5 +28,17 @@ from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import
 class UnityCatalogProfilerInterface(DatabricksProfilerInterface):
     def create_session(self):
         self.connection = databricks_get_connection(self.service_connection_config)
-        super().create_session()
-        self.set_catalog(self.session)
+
+        # Create custom session factory with after_begin event to set catalog
+        session_maker = sessionmaker(bind=self.connection)
+
+        @event.listens_for(session_maker, "after_begin")
+        def set_catalog(session, transaction, connection):
+            # Safely quote the catalog name to prevent SQL injection
+            quoted_catalog = connection.dialect.identifier_preparer.quote(
+                self.service_connection_config.catalog
+            )
+            connection.execute(f"USE CATALOG {quoted_catalog};")
+
+        self.session_factory = scoped_session(session_maker)
+        self.session = self.session_factory()
diff --git a/ingestion/src/metadata/profiler/metrics/system/databricks/system.py b/ingestion/src/metadata/profiler/metrics/system/databricks/system.py
index 50149ca0317..d3c8250bf14 100644
--- a/ingestion/src/metadata/profiler/metrics/system/databricks/system.py
+++ b/ingestion/src/metadata/profiler/metrics/system/databricks/system.py
@@ -24,7 +24,7 @@ SYSTEM_QUERY = textwrap.dedent(
         '{database}' AS database,
         '{schema}' AS schema,
         '{table}' AS table
-    FROM (DESCRIBE HISTORY {database}.{schema}.{table})
+    FROM (DESCRIBE HISTORY `{database}`.`{schema}`.`{table}`)
     WHERE operation IN ({operations})
     AND timestamp > DATEADD(day, -1, CURRENT_TIMESTAMP())
     """
 )
diff --git a/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py
index c05223acb7f..362eba2ccba 100644
--- a/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py
+++ b/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py
@@ -11,7 +11,8 @@
 """
 Helper module to handle data sampling for the profiler
 """
-from sqlalchemy import Column, text
+from sqlalchemy import Column, event, text
+from sqlalchemy.orm import scoped_session, sessionmaker
 
 from metadata.ingestion.source.database.databricks.connection import (
     get_connection as databricks_get_connection,
 )
@@ -25,6 +26,17 @@ class DatabricksSamplerInterface(SQASampler):
         """Initialize with a single Databricks connection"""
         super().__init__(*args, **kwargs)
         self.connection = databricks_get_connection(self.service_connection_config)
+        session_maker = sessionmaker(bind=self.connection)
+
+        @event.listens_for(session_maker, "after_begin")
+        def set_catalog(session, transaction, connection):
+            # Safely quote the catalog name to prevent SQL injection
+            quoted_catalog = connection.dialect.identifier_preparer.quote(
+                self.service_connection_config.catalog
+            )
+            connection.execute(f"USE CATALOG {quoted_catalog};")
+
+        self.session_factory = scoped_session(session_maker)
 
     def get_client(self):
         """client is the session for SQA"""
diff --git a/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py
index 0ccb44c4ae4..eca11842280 100644
--- a/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py
+++ b/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py
@@ -14,47 +14,14 @@ Interfaces with database for all database engine supporting
 sqlalchemy abstraction layer
 """
-from sqlalchemy import Column, text
-
-from metadata.ingestion.source.database.databricks.connection import (
-    get_connection as databricks_get_connection,
-)
-from metadata.profiler.orm.types.custom_array import CustomArray
-from metadata.sampler.sqlalchemy.sampler import SQASampler
+from metadata.sampler.sqlalchemy.databricks.sampler import DatabricksSamplerInterface
 
 
-class UnityCatalogSamplerInterface(SQASampler):
+class UnityCatalogSamplerInterface(DatabricksSamplerInterface):
+    """
+    Unity Catalog Sampler Interface
+    """
+
     def __init__(self, *args, **kwargs):
-        """Initialize with a single Databricks connection"""
         super().__init__(*args, **kwargs)
-        self.connection = databricks_get_connection(self.service_connection_config)
-
-    def get_client(self):
-        """client is the session for SQA"""
-        client = super().get_client()
-        self.set_catalog(client)
-        return client
-
-    def _handle_array_column(self, column: Column) -> bool:
-        """Check if a column is an array type"""
-        return isinstance(column.type, CustomArray)
-
-    def _get_slice_expression(self, column: Column):
-        """Generate SQL expression to slice array elements at query level
-
-        Args:
-            column_name: Name of the column
-            max_elements: Maximum number of elements to extract
-
-        Returns:
-            SQL expression string for array slicing
-        """
-        max_elements = self._get_max_array_elements()
-        return text(
-            f"""
-            CASE
-                WHEN `{column.name}` IS NULL THEN NULL
-                ELSE slice(`{column.name}`, 1, {max_elements})
-            END AS `{column._label}`
-            """
-        )
diff --git a/ingestion/src/metadata/utils/logger.py b/ingestion/src/metadata/utils/logger.py
index 8871700a82a..ad038060a88 100644
--- a/ingestion/src/metadata/utils/logger.py
+++ b/ingestion/src/metadata/utils/logger.py
@@ -212,6 +212,8 @@ def get_log_name(record: Entity) -> Optional[str]:
     try:
         if hasattr(record, "name"):
            return f"{type(record).__name__} [{getattr(record, 'name').root}]"
+        if hasattr(record, "table") and hasattr(record.table, "name"):
+            return f"{type(record).__name__} [{record.table.name.root}]"
         return f"{type(record).__name__} [{record.entity.name.root}]"
     except Exception:
         return str(record)
diff --git a/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py b/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py
index 68cd669ef52..209f9c56f28 100644
--- a/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py
+++ b/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py
@@ -73,7 +73,7 @@ class UnityCatalogSamplerTest(TestCase):
         )
 
     @patch(
-        "metadata.sampler.sqlalchemy.unitycatalog.sampler.SQASampler.build_table_orm"
+        "metadata.sampler.sqlalchemy.unitycatalog.sampler.UnityCatalogSamplerInterface.build_table_orm"
     )
     def test_handle_array_column(self, mock_build_table_orm):
         """Test array column detection"""
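Illustrative note (not part of the patch): the sketch below shows the sessionmaker-level `after_begin` listener pattern used in the profiler and sampler changes above. It assumes a local SQLite engine as a stand-in for the Databricks connection and a placeholder catalog name, so the listener only prints the statement it would issue; on a real Databricks connection the listener would run the `USE CATALOG` statement as in the diff.

```python
# Minimal sketch of the after_begin pattern, assuming SQLite as a stand-in
# engine and "my-catalog" as a placeholder catalog name.
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")  # stand-in for the Databricks engine
session_maker = sessionmaker(bind=engine)


@event.listens_for(session_maker, "after_begin")
def set_catalog(session, transaction, connection):
    # Quote the catalog name with the dialect's identifier preparer
    # before interpolating it into the statement.
    quoted_catalog = connection.dialect.identifier_preparer.quote("my-catalog")
    statement = f"USE CATALOG {quoted_catalog};"
    # On a real Databricks connection this would be: connection.execute(statement)
    print(f"after_begin fired, would run: {statement}")


# Each new transaction on a session from this factory triggers the listener,
# so the catalog is pinned without issuing USE CATALOG manually per query.
session = scoped_session(session_maker)()
session.execute(text("SELECT 1"))  # starting the transaction fires set_catalog
```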