MINOR: Improve Array Sampler for UC & DBX (#22155)

* MINOR: Improve Array Sampler for UC & DBX

* make log debug

* address comments
This commit is contained in:
Mayur Singal 2025-07-08 17:55:20 +05:30 committed by GitHub
parent 2cfabf6017
commit 573e3bfc21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 247 additions and 2 deletions

View File

@ -11,9 +11,12 @@
"""
Helper module to handle data sampling for the profiler
"""
from sqlalchemy import Column, text
from metadata.ingestion.source.database.databricks.connection import (
get_connection as databricks_get_connection,
)
from metadata.profiler.orm.types.custom_array import CustomArray
from metadata.sampler.sqlalchemy.sampler import SQASampler
@ -28,3 +31,27 @@ class DatabricksSamplerInterface(SQASampler):
client = super().get_client()
self.set_catalog(client)
return client
def _handle_array_column(self, column: Column) -> bool:
    """Report whether the given column's type is a ``CustomArray``.

    Args:
        column: SQLAlchemy column whose ``type`` attribute is inspected.

    Returns:
        True when ``column.type`` is a ``CustomArray`` instance, else False.
    """
    column_type = column.type
    return isinstance(column_type, CustomArray)
def _get_slice_expression(self, column: Column):
    """Build a SQL expression that truncates an array column at query level.

    The column is wrapped in a ``CASE`` expression that keeps ``NULL`` values
    as ``NULL`` and otherwise calls ``slice`` on the column with ``1`` and the
    value of ``self._get_max_array_elements()`` as arguments. The result is
    aliased with the column label.

    Args:
        column: SQLAlchemy column holding the array to slice.

    Returns:
        A ``text()`` clause containing the aliased slicing expression.
    """

    def _quote(identifier) -> str:
        # A backtick inside a backtick-delimited identifier has to be doubled,
        # otherwise such a name would close the identifier early.
        return "`" + str(identifier).replace("`", "``") + "`"

    # Force an integer so that only a number is interpolated into the SQL text.
    max_elements = int(self._get_max_array_elements())
    column_name = _quote(column.name)
    column_label = _quote(column._label)
    return text(
        f"""
        CASE
            WHEN {column_name} IS NULL THEN NULL
            ELSE slice({column_name}, 1, {max_elements})
        END AS {column_label}
        """
    )

View File

@ -41,6 +41,9 @@ logger = profiler_interface_registry_logger()
RANDOM_LABEL = "random"
# Default maximum number of elements to extract from array columns to prevent OOM.
# Used as the fallback when the sample config does not provide a truthy
# `maxArrayElements` value.
DEFAULT_MAX_ARRAY_ELEMENTS = 10
def _object_value_for_elem(self, elem):
"""
@ -103,6 +106,36 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
"""
return selectable
def _get_max_array_elements(self) -> int:
"""Get the maximum number of array elements from config or use default"""
if (
self.sample_config
and hasattr(self.sample_config, "maxArrayElements")
and self.sample_config.maxArrayElements
):
return self.sample_config.maxArrayElements
return DEFAULT_MAX_ARRAY_ELEMENTS
def _handle_array_column(self, column: Column) -> bool:
    """Tell whether ``column`` must be treated as an array column.

    The base sampler answers False for every column; child classes are
    expected to override this hook with their own detection.
    """
    # Implement this method in the child classes
    return False
def _process_array_value(self, value):
"""Process array values to convert numpy arrays to Python lists"""
import numpy as np # pylint: disable=import-outside-toplevel
if isinstance(value, np.ndarray):
return value.tolist()
return value
def _get_slice_expression(self, column: Column):
    """Return the expression used to select an array column in the sample query.

    The base implementation applies no slicing and hands the column back
    untouched. Child classes can override this method to return a different
    expression.
    """
    return column
def _base_sample_query(self, column: Optional[Column], label=None):
"""Base query for sampling
@ -215,16 +248,49 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
]
with self.get_client_context() as client:
# Handle array columns with special query modification
max_elements = self._get_max_array_elements()
select_columns = []
has_array_columns = False
for col in sqa_columns:
if self._handle_array_column(col):
slice_expression = self._get_slice_expression(col)
select_columns.append(slice_expression)
logger.debug(
f"Limiting array column {col.name} to {max_elements} elements to prevent OOM"
)
has_array_columns = True
else:
select_columns.append(col)
# Create query with modified columns
sqa_sample = (
client.query(*sqa_columns)
client.query(*select_columns)
.select_from(ds)
.limit(self.sample_limit)
.all()
)
# Process array columns manually if we used text() expressions
processed_rows = []
if has_array_columns:
for row in sqa_sample:
processed_row = []
for i, col in enumerate(sqa_columns):
value = row[i]
if self._handle_array_column(col):
processed_value = self._process_array_value(value)
processed_row.append(processed_value)
else:
processed_row.append(value)
processed_rows.append(processed_row)
else:
processed_rows = [list(row) for row in sqa_sample]
return TableData(
columns=[column.name for column in sqa_columns],
rows=[list(row) for row in sqa_sample],
rows=processed_rows,
)
def _fetch_sample_data_from_user_query(self) -> TableData:

View File

@ -13,9 +13,13 @@
Interfaces with database for all database engine
supporting sqlalchemy abstraction layer
"""
from sqlalchemy import Column, text
from metadata.ingestion.source.database.databricks.connection import (
get_connection as databricks_get_connection,
)
from metadata.profiler.orm.types.custom_array import CustomArray
from metadata.sampler.sqlalchemy.sampler import SQASampler
@ -30,3 +34,27 @@ class UnityCatalogSamplerInterface(SQASampler):
client = super().get_client()
self.set_catalog(client)
return client
def _handle_array_column(self, column: Column) -> bool:
    """Report whether the given column's type is a ``CustomArray``.

    Args:
        column: SQLAlchemy column whose ``type`` attribute is inspected.

    Returns:
        True when ``column.type`` is a ``CustomArray`` instance, else False.
    """
    column_type = column.type
    return isinstance(column_type, CustomArray)
def _get_slice_expression(self, column: Column):
    """Build a SQL expression that truncates an array column at query level.

    The column is wrapped in a ``CASE`` expression that keeps ``NULL`` values
    as ``NULL`` and otherwise calls ``slice`` on the column with ``1`` and the
    value of ``self._get_max_array_elements()`` as arguments. The result is
    aliased with the column label.

    Args:
        column: SQLAlchemy column holding the array to slice.

    Returns:
        A ``text()`` clause containing the aliased slicing expression.
    """

    def _quote(identifier) -> str:
        # A backtick inside a backtick-delimited identifier has to be doubled,
        # otherwise such a name would close the identifier early.
        return "`" + str(identifier).replace("`", "``") + "`"

    # Force an integer so that only a number is interpolated into the SQL text.
    max_elements = int(self._get_max_array_elements())
    column_name = _quote(column.name)
    column_label = _quote(column._label)
    return text(
        f"""
        CASE
            WHEN {column_name} IS NULL THEN NULL
            ELSE slice({column_name}, 1, {max_elements})
        END AS {column_label}
        """
    )

View File

@ -0,0 +1,124 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test Unity Catalog sampler functionality
"""
from unittest import TestCase
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
from metadata.generated.schema.entity.data.table import Column as EntityColumn
from metadata.generated.schema.entity.data.table import ColumnName, DataType, Table
from metadata.generated.schema.entity.services.connections.database.unityCatalogConnection import (
UnityCatalogConnection,
)
from metadata.profiler.orm.types.custom_array import CustomArray
from metadata.sampler.models import SampleConfig
from metadata.sampler.sqlalchemy.sampler import DEFAULT_MAX_ARRAY_ELEMENTS
from metadata.sampler.sqlalchemy.unitycatalog.sampler import (
UnityCatalogSamplerInterface,
)
Base = declarative_base()


class _TestTableModel(Base):
    """ORM model with one integer, one string and one array column.

    It is handed to the sampler in place of the ORM normally built from the
    table entity, and its columns are used as inputs for the array helpers.
    """

    __tablename__ = "test_table"

    # Scalar columns
    id = Column(Integer, primary_key=True)
    name = Column(String(256))

    # Array column, the only one typed as CustomArray
    array_col = Column(CustomArray(String))
class UnityCatalogSamplerTest(TestCase):
    """Unit tests for the array handling of the Unity Catalog sampler"""

    def setUp(self):
        """Create the table entity and the connection shared by the tests"""
        entity_columns = [
            EntityColumn(name=ColumnName(column_name), dataType=data_type)
            for column_name, data_type in (
                ("id", DataType.INT),
                ("name", DataType.STRING),
                ("array_col", DataType.ARRAY),
            )
        ]
        self.table_entity = Table(
            id=uuid4(),
            name="test_table",
            columns=entity_columns,
        )
        self.unity_catalog_conn = UnityCatalogConnection(
            hostPort="localhost:443",
            token="test_token",
            httpPath="/sql/1.0/warehouses/test",
            catalog="test_catalog",
        )

    @staticmethod
    def _sampler_without_init():
        """Return a sampler created via ``__new__`` with a default SampleConfig"""
        # __init__ is bypassed on purpose: only `sample_config` is set here.
        sampler = UnityCatalogSamplerInterface.__new__(UnityCatalogSamplerInterface)
        sampler.sample_config = SampleConfig()
        return sampler

    @patch(
        "metadata.sampler.sqlalchemy.unitycatalog.sampler.SQASampler.build_table_orm"
    )
    def test_handle_array_column(self, mock_build_table_orm):
        """Only the CustomArray column is detected as an array column"""
        mock_build_table_orm.return_value = _TestTableModel
        sampler = UnityCatalogSamplerInterface(
            service_connection_config=self.unity_catalog_conn,
            ometa_client=None,
            entity=self.table_entity,
            sample_config=SampleConfig(),
        )
        table_columns = _TestTableModel.__table__.c

        # Array column is detected
        self.assertTrue(sampler._handle_array_column(table_columns.array_col))

        # Plain string column is not
        self.assertFalse(sampler._handle_array_column(table_columns.name))

    def test_get_max_array_elements_default(self):
        """A default SampleConfig yields the base class default limit"""
        sampler = self._sampler_without_init()

        self.assertEqual(sampler._get_max_array_elements(), DEFAULT_MAX_ARRAY_ELEMENTS)

    def test_get_slice_expression(self):
        """The slice expression for an array column contains the expected SQL"""
        sampler = self._sampler_without_init()

        # Real column taken from the test model
        array_col = _TestTableModel.__table__.c.array_col
        expression = sampler._get_slice_expression(array_col)

        self.assertIsNotNone(expression)

        # Compile the clause to a string and look for the expected fragments
        compiled = expression.compile(compile_kwargs={"literal_binds": True})
        sql_str = str(compiled)
        for fragment in ("CASE", "slice", "array_col"):
            self.assertIn(fragment, sql_str)