Fixes 16562: Modify HiveCompiler to compile column names properly (#16954)

* Modify HiveCompiler to compile column names properly
This commit is contained in:
Ayush Shah 2024-07-09 12:59:23 +05:30 committed by GitHub
parent 49876b9cd6
commit 421d191bae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 11 deletions

View File

@@ -188,6 +188,8 @@ plugins: Dict[str, Set[str]] = {
"ndg-httpsclient~=0.5.1",
"pyOpenSSL~=24.1.0",
"pyasn1~=0.6.0",
# databricks has a dependency on pyhive for metadata as well as profiler
VERSIONS["pyhive"],
},
"datalake-azure": {
VERSIONS["azure-storage-blob"],

View File

@@ -17,6 +17,7 @@ from typing import List
from pyhive.sqlalchemy_hive import HiveCompiler
from sqlalchemy import Column, inspect
from sqlalchemy.sql import column
from metadata.generated.schema.entity.data.table import Column as OMColumn
from metadata.generated.schema.entity.data.table import ColumnName, DataType, TableData
@@ -36,15 +37,17 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
result = super( # pylint: disable=bad-super-call
HiveCompiler, self
).visit_column(*args, **kwargs)
dot_count = result.count(".")
# Here the databricks uses HiveCompiler.
# the `result` here would be `db.schema.table` or `db.schema.table.column`
# for struct it will be `db.schema.table.column.nestedchild.nestedchild` etc
# the logic is to add the backticks to nested children.
if dot_count > 2:
splitted_result = result.split(".", 2)[-1].split(".")
result = ".".join(result.split(".", 2)[:-1])
result += "." + "`.`".join(splitted_result)
dot_count = result.count(".")
if dot_count > 1 and "." in result.split("`.`")[-1]:
splitted_result = result.split("`.")[-1].split(".")
result = "`.".join(result.split("`.")[:-1])
if result:
result += "`."
result += "`.`".join(splitted_result)
return result
def __init__(self, service_connection_config, **kwargs):
@@ -56,14 +59,17 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
"""Get struct columns"""
columns_list = []
for idx, col in enumerate(columns):
for col in columns:
if col.dataType != DataType.STRUCT:
col.name = ColumnName(f"{parent}.{col.name.root}")
col = build_orm_col(idx, col, DatabaseServiceType.Databricks)
col = build_orm_col(
idx=1, col=col, table_service_type=DatabaseServiceType.Databricks
)
col._set_parent( # pylint: disable=protected-access
self.table.__table__
)
columns_list.append(col)
columns_list.append(column(col.label(col.name.replace(".", "_"))))
else:
col = self._get_struct_columns(
col.children, f"{parent}.{col.name.root}"
@@ -74,10 +80,10 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
def get_columns(self) -> Column:
"""Get columns from table"""
columns = []
for idx, column in enumerate(self.table_entity.columns):
if column.dataType == DataType.STRUCT:
for idx, column_obj in enumerate(self.table_entity.columns):
if column_obj.dataType == DataType.STRUCT:
columns.extend(
self._get_struct_columns(column.children, column.name.root)
self._get_struct_columns(column_obj.children, column_obj.name.root)
)
else:
col = build_orm_col(idx, column, DatabaseServiceType.Databricks)

View File

@@ -0,0 +1,74 @@
import unittest
from unittest.mock import MagicMock, patch
from pyhive.sqlalchemy_hive import HiveCompiler
from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import (
DatabricksProfilerInterface,
)
class FakeHiveCompiler(DatabricksProfilerInterface, HiveCompiler):
    """Lightweight test double for exercising ``visit_column``.

    Mixes ``DatabricksProfilerInterface`` into a real ``HiveCompiler`` MRO
    while deliberately bypassing both parents' constructors, so no engine,
    connection, or catalog work happens when the fixture is built.
    """

    def __init__(self, service_connection_config):
        # Intentionally no super().__init__() call — the tests only need
        # the connection config attribute to be present on the instance.
        self.service_connection_config = service_connection_config
class TestDatabricksProfilerInterface(unittest.TestCase):
    """Tests for the backtick quoting done by ``visit_column``.

    NOTE(review): ``@patch`` decorators are applied bottom-up, so the mock
    arguments to ``setUp`` arrive in the order
    (visit_column, __init__, set_catalog). These patches are active only
    while ``setUp`` itself runs; they exist so ``FakeHiveCompiler`` can be
    constructed without touching a real Databricks connection.
    """

    @patch(
        "metadata.profiler.interface.sqlalchemy.databricks.profiler_interface.DatabricksProfilerInterface.set_catalog",
        return_value=None,
    )
    @patch(
        "metadata.profiler.interface.sqlalchemy.databricks.profiler_interface.DatabricksProfilerInterface.__init__",
        return_value=None,
    )
    @patch("pyhive.sqlalchemy_hive.HiveCompiler.visit_column")
    def setUp(
        self,
        mock_visit_column,
        mock_init,
        mock_set_catalog,
    ) -> None:
        # The mock parameters are unused here — presumably kept to neutralize
        # side effects during construction (FakeHiveCompiler defines its own
        # __init__, so the __init__ patch may be redundant; verify).
        self.profiler = FakeHiveCompiler(service_connection_config={})

    @patch("sqlalchemy.sql.compiler.SQLCompiler.visit_column")
    def test_visit_column_no_nesting(self, mock_visit_column_super):
        """Names without a dotted (struct) tail must pass through unchanged."""
        # Mock the response of the super class method.
        # visit_column calls super(HiveCompiler, self).visit_column, which
        # resolves to SQLCompiler.visit_column — hence this patch target.
        mock_visit_column_super.return_value = "`db`.`schema`.`table`"
        assert self.profiler.visit_column(MagicMock()) == "`db`.`schema`.`table`"
        mock_visit_column_super.return_value = "`db`"
        assert self.profiler.visit_column(MagicMock()) == "`db`"
        mock_visit_column_super.return_value = "`schema`"
        assert self.profiler.visit_column(MagicMock()) == "`schema`"
        mock_visit_column_super.return_value = "`table`"
        assert self.profiler.visit_column(MagicMock()) == "`table`"
        # An unquoted bare name must also survive untouched.
        mock_visit_column_super.return_value = "table"
        assert self.profiler.visit_column(MagicMock()) == "table"

    @patch("sqlalchemy.sql.compiler.SQLCompiler.visit_column")
    def test_visit_column_nesting(self, mock_visit_column_super):
        """Dots inside the final quoted segment mark nested struct children
        and must be rewritten as separate backtick-quoted identifiers."""
        # Mock the response of the super class method
        mock_visit_column_super.return_value = "`db`.`schema`.`table`.`col.u.m.n`"
        assert (
            self.profiler.visit_column(MagicMock())
            == "`db`.`schema`.`table`.`col`.`u`.`m`.`n`"
        )
        mock_visit_column_super.return_value = "`db`.`schema`.`table`.`col.1`"
        assert (
            self.profiler.visit_column(MagicMock()) == "`db`.`schema`.`table`.`col`.`1`"
        )
        mock_visit_column_super.return_value = "`table`.`1.2`"
        assert self.profiler.visit_column(MagicMock()) == "`table`.`1`.`2`"
        # single dot in column name should not be split
        mock_visit_column_super.return_value = "`col.1`"
        assert self.profiler.visit_column(MagicMock()) == "`col.1`"