Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2026-01-06 04:26:57 +00:00)
Fixes 16562: Modify HiveCompiler to compile column names properly (#16954)
* Modify HiveCompiler to compile column names properly
parent: 49876b9cd6
commit: 421d191bae
@@ -188,6 +188,8 @@ plugins: Dict[str, Set[str]] = {
        "ndg-httpsclient~=0.5.1",
        "pyOpenSSL~=24.1.0",
        "pyasn1~=0.6.0",
        # databricks has a dependency on pyhive for metadata as well as profiler
        VERSIONS["pyhive"],
    },
    "datalake-azure": {
        VERSIONS["azure-storage-blob"],
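For readers unfamiliar with this part of the build, the hunk above edits an extras mapping: `plugins` is a dictionary from extra name to a set of requirement strings, and `VERSIONS` is a shared table of pinned requirements so the same pin can be reused across extras. The sketch below only illustrates that pattern; the `databricks` key and the concrete pins are assumptions for illustration, not the repository's actual values.

from typing import Dict, Set

# Assumed pins, purely illustrative -- the real values live in the repo's setup module.
VERSIONS: Dict[str, str] = {
    "pyhive": "pyhive~=0.7",
    "azure-storage-blob": "azure-storage-blob~=12.14",
}

# Each key is an installable extra; each value is the set of requirements it pulls in.
plugins: Dict[str, Set[str]] = {
    "databricks": {  # assumed extra name for this hunk
        "pyOpenSSL~=24.1.0",
        # databricks has a dependency on pyhive for metadata as well as profiler
        VERSIONS["pyhive"],
    },
    "datalake-azure": {
        VERSIONS["azure-storage-blob"],
    },
}

# In setup(), such a mapping typically becomes:
# extras_require={k: list(v) for k, v in plugins.items()}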
@@ -17,6 +17,7 @@ from typing import List

from pyhive.sqlalchemy_hive import HiveCompiler
from sqlalchemy import Column, inspect
from sqlalchemy.sql import column

from metadata.generated.schema.entity.data.table import Column as OMColumn
from metadata.generated.schema.entity.data.table import ColumnName, DataType, TableData
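The newly imported `HiveCompiler` is what the `super(HiveCompiler, self)` call in the next hunk pivots on: passing an explicit class to `super()` starts the method lookup after that class in the instance's MRO, so `HiveCompiler.visit_column` itself is skipped and SQLAlchemy's base `SQLCompiler.visit_column` runs instead. A minimal, self-contained sketch of that mechanism with toy classes (none of these names are OpenMetadata API):

class Base:
    def render(self) -> str:
        return "base"


class Hive(Base):
    def render(self) -> str:
        # Dialect-specific behaviour we want to bypass in the subclass below.
        return "hive(" + super().render() + ")"


class Databricks(Hive):
    def render(self) -> str:
        # super(Hive, self) starts the MRO lookup *after* Hive, so
        # Hive.render is skipped and Base.render is called directly.
        return super(Hive, self).render()


assert Databricks().render() == "base"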
@@ -36,15 +37,17 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
        result = super(  # pylint: disable=bad-super-call
            HiveCompiler, self
        ).visit_column(*args, **kwargs)
        dot_count = result.count(".")
        # Here the databricks uses HiveCompiler.
        # the `result` here would be `db.schema.table` or `db.schema.table.column`
        # for struct it will be `db.schema.table.column.nestedchild.nestedchild` etc
        # the logic is to add the backticks to nested children.
        if dot_count > 2:
            splitted_result = result.split(".", 2)[-1].split(".")
            result = ".".join(result.split(".", 2)[:-1])
            result += "." + "`.`".join(splitted_result)
        dot_count = result.count(".")
        if dot_count > 1 and "." in result.split("`.`")[-1]:
            splitted_result = result.split("`.")[-1].split(".")
            result = "`.".join(result.split("`.")[:-1])
            if result:
                result += "`."
            result += "`.`".join(splitted_result)
        return result

    def __init__(self, service_connection_config, **kwargs):
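To make the string manipulation in this hunk concrete, here is a standalone trace of the `dot_count > 1` branch. The helper name is mine; the body copies the splits from the hunk, and the inputs and expected outputs mirror the unit tests added later in this same commit.

def add_backticks_to_nested(result: str) -> str:
    # Same transformation as the `dot_count > 1` branch in the hunk above.
    dot_count = result.count(".")
    if dot_count > 1 and "." in result.split("`.`")[-1]:
        splitted_result = result.split("`.")[-1].split(".")
        result = "`.".join(result.split("`.")[:-1])
        if result:
            result += "`."
        result += "`.`".join(splitted_result)
    return result


# Nested struct child: only the trailing quoted segment is re-split; the
# db/schema/table prefix is left untouched.
assert (
    add_backticks_to_nested("`db`.`schema`.`table`.`col.u.m.n`")
    == "`db`.`schema`.`table`.`col`.`u`.`m`.`n`"
)

# A single dot inside one quoted identifier is not split.
assert add_backticks_to_nested("`col.1`") == "`col.1`"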
@@ -56,14 +59,17 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
        """Get struct columns"""

        columns_list = []
        for idx, col in enumerate(columns):
        for col in columns:
            if col.dataType != DataType.STRUCT:
                col.name = ColumnName(f"{parent}.{col.name.root}")
                col = build_orm_col(idx, col, DatabaseServiceType.Databricks)
                col = build_orm_col(
                    idx=1, col=col, table_service_type=DatabaseServiceType.Databricks
                )
                col._set_parent(  # pylint: disable=protected-access
                    self.table.__table__
                )
                columns_list.append(col)
                columns_list.append(column(col.label(col.name.replace(".", "_"))))
            else:
                col = self._get_struct_columns(
                    col.children, f"{parent}.{col.name.root}"
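The struct handling above builds dotted names of the form `parent.child.grandchild` for every non-struct leaf, then registers each leaf as an ORM column labelled with dots replaced by underscores. A rough sketch of just the name-flattening part on a toy column type (the dataclass and field names are illustrative, not the OpenMetadata models):

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyColumn:
    # Simplified stand-in for the table column model used above.
    name: str
    data_type: str = "STRING"
    children: List["ToyColumn"] = field(default_factory=list)


def flatten_struct(columns: List[ToyColumn], parent: str) -> List[str]:
    """Collect dotted names for every leaf under a STRUCT column."""
    names: List[str] = []
    for col in columns:
        if col.data_type != "STRUCT":
            names.append(f"{parent}.{col.name}")
        else:
            names.extend(flatten_struct(col.children, f"{parent}.{col.name}"))
    return names


leaves = flatten_struct(
    [ToyColumn("city"), ToyColumn("geo", "STRUCT", [ToyColumn("lat"), ToyColumn("lon")])],
    parent="address",
)
assert leaves == ["address.city", "address.geo.lat", "address.geo.lon"]

# The SELECT labels then replace dots with underscores, as in the hunk above.
assert [n.replace(".", "_") for n in leaves] == [
    "address_city",
    "address_geo_lat",
    "address_geo_lon",
]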
@@ -74,10 +80,10 @@ class DatabricksProfilerInterface(SQAProfilerInterface):
    def get_columns(self) -> Column:
        """Get columns from table"""
        columns = []
        for idx, column in enumerate(self.table_entity.columns):
            if column.dataType == DataType.STRUCT:
        for idx, column_obj in enumerate(self.table_entity.columns):
            if column_obj.dataType == DataType.STRUCT:
                columns.extend(
                    self._get_struct_columns(column.children, column.name.root)
                    self._get_struct_columns(column_obj.children, column_obj.name.root)
                )
            else:
                col = build_orm_col(idx, column, DatabaseServiceType.Databricks)
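The rename from `column` to `column_obj` in `get_columns` appears to avoid shadowing the `column` construct imported from `sqlalchemy.sql` in the import hunk above, which the struct handling now uses to wrap labelled leaves. A small illustration of that shadowing problem (toy names, not the repository's code):

from sqlalchemy.sql import column

names = ["id", "amount"]

# If the loop variable is called `column`, it hides the imported
# sqlalchemy construct for the rest of the enclosing scope.
for column in names:
    pass
assert column == "amount"  # no longer the sqlalchemy callable

from sqlalchemy.sql import column  # rebind the construct

# With a distinct loop variable, column(...) stays usable inside the loop.
for column_obj in names:
    expr = column(column_obj)  # a ColumnClause for each name
assert str(expr) == "amount"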
@@ -0,0 +1,74 @@
import unittest
from unittest.mock import MagicMock, patch

from pyhive.sqlalchemy_hive import HiveCompiler

from metadata.profiler.interface.sqlalchemy.databricks.profiler_interface import (
    DatabricksProfilerInterface,
)


class FakeHiveCompiler(
    DatabricksProfilerInterface,
    HiveCompiler,
):
    def __init__(self, service_connection_config):
        self.service_connection_config = service_connection_config


class TestDatabricksProfilerInterface(unittest.TestCase):
    @patch(
        "metadata.profiler.interface.sqlalchemy.databricks.profiler_interface.DatabricksProfilerInterface.set_catalog",
        return_value=None,
    )
    @patch(
        "metadata.profiler.interface.sqlalchemy.databricks.profiler_interface.DatabricksProfilerInterface.__init__",
        return_value=None,
    )
    @patch("pyhive.sqlalchemy_hive.HiveCompiler.visit_column")
    def setUp(
        self,
        mock_visit_column,
        mock_init,
        mock_set_catalog,
    ) -> None:
        self.profiler = FakeHiveCompiler(service_connection_config={})

    @patch("sqlalchemy.sql.compiler.SQLCompiler.visit_column")
    def test_visit_column_no_nesting(self, mock_visit_column_super):
        # Mock the response of the super class method
        mock_visit_column_super.return_value = "`db`.`schema`.`table`"
        assert self.profiler.visit_column(MagicMock()) == "`db`.`schema`.`table`"

        mock_visit_column_super.return_value = "`db`"
        assert self.profiler.visit_column(MagicMock()) == "`db`"

        mock_visit_column_super.return_value = "`schema`"
        assert self.profiler.visit_column(MagicMock()) == "`schema`"

        mock_visit_column_super.return_value = "`table`"
        assert self.profiler.visit_column(MagicMock()) == "`table`"

        mock_visit_column_super.return_value = "table"
        assert self.profiler.visit_column(MagicMock()) == "table"

    @patch("sqlalchemy.sql.compiler.SQLCompiler.visit_column")
    def test_visit_column_nesting(self, mock_visit_column_super):
        # Mock the response of the super class method
        mock_visit_column_super.return_value = "`db`.`schema`.`table`.`col.u.m.n`"
        assert (
            self.profiler.visit_column(MagicMock())
            == "`db`.`schema`.`table`.`col`.`u`.`m`.`n`"
        )

        mock_visit_column_super.return_value = "`db`.`schema`.`table`.`col.1`"
        assert (
            self.profiler.visit_column(MagicMock()) == "`db`.`schema`.`table`.`col`.`1`"
        )

        mock_visit_column_super.return_value = "`table`.`1.2`"
        assert self.profiler.visit_column(MagicMock()) == "`table`.`1`.`2`"

        # single dot in column name should not be split
        mock_visit_column_super.return_value = "`col.1`"
        assert self.profiler.visit_column(MagicMock()) == "`col.1`"