diff --git a/ingestion/setup.py b/ingestion/setup.py index 597ad688cf3..a7b7df1d18c 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -188,6 +188,8 @@ plugins: Dict[str, Set[str]] = { "ndg-httpsclient~=0.5.1", "pyOpenSSL~=24.1.0", "pyasn1~=0.6.0", + # databricks has a dependency on pyhive for metadata as well as profiler + VERSIONS["pyhive"], }, "datalake-azure": { VERSIONS["azure-storage-blob"], diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py index b27589ca904..90f2c98879a 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/databricks/profiler_interface.py @@ -17,6 +17,7 @@ from typing import List from pyhive.sqlalchemy_hive import HiveCompiler from sqlalchemy import Column, inspect +from sqlalchemy.sql import column from metadata.generated.schema.entity.data.table import Column as OMColumn from metadata.generated.schema.entity.data.table import ColumnName, DataType, TableData @@ -36,15 +37,17 @@ class DatabricksProfilerInterface(SQAProfilerInterface): result = super( # pylint: disable=bad-super-call HiveCompiler, self ).visit_column(*args, **kwargs) - dot_count = result.count(".") # Here the databricks uses HiveCompiler. # the `result` here would be `db.schema.table` or `db.schema.table.column` # for struct it will be `db.schema.table.column.nestedchild.nestedchild` etc # the logic is to add the backticks to nested children. - if dot_count > 2: - splitted_result = result.split(".", 2)[-1].split(".") - result = ".".join(result.split(".", 2)[:-1]) - result += "." + "`.`".join(splitted_result) + dot_count = result.count(".") + if dot_count > 1 and "." in result.split("`.`")[-1]: + splitted_result = result.split("`.")[-1].split(".") + result = "`.".join(result.split("`.")[:-1]) + if result: + result += "`." + result += "`.`".join(splitted_result) return result def __init__(self, service_connection_config, **kwargs): @@ -56,14 +59,17 @@ class DatabricksProfilerInterface(SQAProfilerInterface): """Get struct columns""" columns_list = [] - for idx, col in enumerate(columns): + for col in columns: if col.dataType != DataType.STRUCT: col.name = ColumnName(f"{parent}.{col.name.root}") - col = build_orm_col(idx, col, DatabaseServiceType.Databricks) + col = build_orm_col( + idx=1, col=col, table_service_type=DatabaseServiceType.Databricks + ) col._set_parent( # pylint: disable=protected-access self.table.__table__ ) - columns_list.append(col) + + columns_list.append(column(col.label(col.name.replace(".", "_")))) else: col = self._get_struct_columns( col.children, f"{parent}.{col.name.root}" @@ -74,10 +80,10 @@ class DatabricksProfilerInterface(SQAProfilerInterface): def get_columns(self) -> Column: """Get columns from table""" columns = [] - for idx, column in enumerate(self.table_entity.columns): - if column.dataType == DataType.STRUCT: + for idx, column_obj in enumerate(self.table_entity.columns): + if column_obj.dataType == DataType.STRUCT: columns.extend( - self._get_struct_columns(column.children, column.name.root) + self._get_struct_columns(column_obj.children, column_obj.name.root) ) else: col = build_orm_col(idx, column, DatabaseServiceType.Databricks) diff --git a/ingestion/tests/unit/profiler/sqlalchemy/databricks/test_visit_column.py b/ingestion/tests/unit/profiler/sqlalchemy/databricks/test_visit_column.py new file mode 100644 index 00000000000..776427a0bd4 --- /dev/null +++ b/ingestion/tests/unit/profiler/sqlalchemy/databricks/test_visit_column.py @@ -0,0 +1,74 @@ +import unittest +from unittest.mock import MagicMock, patch + +from pyhive.sqlalchemy_hive import HiveCompiler + +from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import ( + DatabricksProfilerInterface, +) + + +class FakeHiveCompiler( + DatabricksProfilerInterface, + HiveCompiler, +): + def __init__(self, service_connection_config): + self.service_connection_config = service_connection_config + + +class TestDatabricksProfilerInterface(unittest.TestCase): + @patch( + "metadata.profiler.interface.sqlalchemy.databricks.profiler_interface.DatabricksProfilerInterface.set_catalog", + return_value=None, + ) + @patch( + "metadata.profiler.interface.sqlalchemy.databricks.profiler_interface.DatabricksProfilerInterface.__init__", + return_value=None, + ) + @patch("pyhive.sqlalchemy_hive.HiveCompiler.visit_column") + def setUp( + self, + mock_visit_column, + mock_init, + mock_set_catalog, + ) -> None: + self.profiler = FakeHiveCompiler(service_connection_config={}) + + @patch("sqlalchemy.sql.compiler.SQLCompiler.visit_column") + def test_visit_column_no_nesting(self, mock_visit_column_super): + # Mock the response of the super class method + mock_visit_column_super.return_value = "`db`.`schema`.`table`" + assert self.profiler.visit_column(MagicMock()) == "`db`.`schema`.`table`" + + mock_visit_column_super.return_value = "`db`" + assert self.profiler.visit_column(MagicMock()) == "`db`" + + mock_visit_column_super.return_value = "`schema`" + assert self.profiler.visit_column(MagicMock()) == "`schema`" + + mock_visit_column_super.return_value = "`table`" + assert self.profiler.visit_column(MagicMock()) == "`table`" + + mock_visit_column_super.return_value = "table" + assert self.profiler.visit_column(MagicMock()) == "table" + + @patch("sqlalchemy.sql.compiler.SQLCompiler.visit_column") + def test_visit_column_nesting(self, mock_visit_column_super): + # Mock the response of the super class method + mock_visit_column_super.return_value = "`db`.`schema`.`table`.`col.u.m.n`" + assert ( + self.profiler.visit_column(MagicMock()) + == "`db`.`schema`.`table`.`col`.`u`.`m`.`n`" + ) + + mock_visit_column_super.return_value = "`db`.`schema`.`table`.`col.1`" + assert ( + self.profiler.visit_column(MagicMock()) == "`db`.`schema`.`table`.`col`.`1`" + ) + + mock_visit_column_super.return_value = "`table`.`1.2`" + assert self.profiler.visit_column(MagicMock()) == "`table`.`1`.`2`" + + # single dot in column name should not be split + mock_visit_column_super.return_value = "`col.1`" + assert self.profiler.visit_column(MagicMock()) == "`col.1`"