mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-03 12:08:31 +00:00
MINOR: Improve Array Sampler for UC & DBX (#22155)
* MINOR: Improve Array Sampler for UC & DBX * make log debug * address comments
This commit is contained in:
parent
2cfabf6017
commit
573e3bfc21
@ -11,9 +11,12 @@
|
||||
"""
|
||||
Helper module to handle data sampling for the profiler
|
||||
"""
|
||||
from sqlalchemy import Column, text
|
||||
|
||||
from metadata.ingestion.source.database.databricks.connection import (
|
||||
get_connection as databricks_get_connection,
|
||||
)
|
||||
from metadata.profiler.orm.types.custom_array import CustomArray
|
||||
from metadata.sampler.sqlalchemy.sampler import SQASampler
|
||||
|
||||
|
||||
@ -28,3 +31,27 @@ class DatabricksSamplerInterface(SQASampler):
|
||||
client = super().get_client()
|
||||
self.set_catalog(client)
|
||||
return client
|
||||
|
||||
def _handle_array_column(self, column: Column) -> bool:
    """Tell whether ``column`` is declared with the ``CustomArray`` type."""
    column_type = column.type
    return isinstance(column_type, CustomArray)
|
||||
|
||||
def _get_slice_expression(self, column: Column):
    """Build the SELECT expression that truncates an array column in SQL.

    The array is cut down inside the query itself, so each row carries at
    most ``self._get_max_array_elements()`` elements.

    Args:
        column: SQLAlchemy column holding the array values.

    Returns:
        A ``text()`` clause of the form
        ``CASE WHEN `col` IS NULL THEN NULL ELSE slice(`col`, 1, N) END AS `label```.
    """
    # Coerce to int so nothing but a number can reach the SQL text.
    max_elements = int(self._get_max_array_elements())

    # Inside a backtick-delimited identifier a literal backtick is written
    # doubled; without this a column name containing one ends the quoting early.
    quoted_name = "`" + str(column.name).replace("`", "``") + "`"
    quoted_label = (
        "`"
        + str(column._label).replace("`", "``")  # pylint: disable=protected-access
        + "`"
    )

    return text(
        f"""
        CASE
            WHEN {quoted_name} IS NULL THEN NULL
            ELSE slice({quoted_name}, 1, {max_elements})
        END AS {quoted_label}
        """
    )
|
||||
|
||||
@ -41,6 +41,9 @@ logger = profiler_interface_registry_logger()
|
||||
|
||||
RANDOM_LABEL = "random"
|
||||
|
||||
# Default maximum number of elements to extract from array columns to prevent OOM
|
||||
DEFAULT_MAX_ARRAY_ELEMENTS = 10
|
||||
|
||||
|
||||
def _object_value_for_elem(self, elem):
|
||||
"""
|
||||
@ -103,6 +106,36 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
"""
|
||||
return selectable
|
||||
|
||||
def _get_max_array_elements(self) -> int:
|
||||
"""Get the maximum number of array elements from config or use default"""
|
||||
if (
|
||||
self.sample_config
|
||||
and hasattr(self.sample_config, "maxArrayElements")
|
||||
and self.sample_config.maxArrayElements
|
||||
):
|
||||
return self.sample_config.maxArrayElements
|
||||
return DEFAULT_MAX_ARRAY_ELEMENTS
|
||||
|
||||
def _handle_array_column(self, column: Column) -> bool:
    """Report whether ``column`` should be treated as an array column.

    The base sampler always answers ``False``; child classes are expected
    to override this method.
    """
    is_array = False
    return is_array
|
||||
|
||||
def _process_array_value(self, value):
|
||||
"""Process array values to convert numpy arrays to Python lists"""
|
||||
import numpy as np # pylint: disable=import-outside-toplevel
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
return value.tolist()
|
||||
return value
|
||||
|
||||
def _get_slice_expression(self, column: Column):
    """Return the expression used to select ``column`` in the sample query.

    The base implementation applies no slicing and hands the column back
    unchanged. Child classes can override this method to return a
    different expression.
    """
    unsliced = column
    return unsliced
|
||||
|
||||
def _base_sample_query(self, column: Optional[Column], label=None):
|
||||
"""Base query for sampling
|
||||
|
||||
@ -215,16 +248,49 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
]
|
||||
|
||||
with self.get_client_context() as client:
|
||||
|
||||
# Handle array columns with special query modification
|
||||
max_elements = self._get_max_array_elements()
|
||||
select_columns = []
|
||||
has_array_columns = False
|
||||
|
||||
for col in sqa_columns:
|
||||
if self._handle_array_column(col):
|
||||
slice_expression = self._get_slice_expression(col)
|
||||
select_columns.append(slice_expression)
|
||||
logger.debug(
|
||||
f"Limiting array column {col.name} to {max_elements} elements to prevent OOM"
|
||||
)
|
||||
has_array_columns = True
|
||||
else:
|
||||
select_columns.append(col)
|
||||
|
||||
# Create query with modified columns
|
||||
sqa_sample = (
|
||||
client.query(*sqa_columns)
|
||||
client.query(*select_columns)
|
||||
.select_from(ds)
|
||||
.limit(self.sample_limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Process array columns manually if we used text() expressions
|
||||
processed_rows = []
|
||||
if has_array_columns:
|
||||
for row in sqa_sample:
|
||||
processed_row = []
|
||||
for i, col in enumerate(sqa_columns):
|
||||
value = row[i]
|
||||
if self._handle_array_column(col):
|
||||
processed_value = self._process_array_value(value)
|
||||
processed_row.append(processed_value)
|
||||
else:
|
||||
processed_row.append(value)
|
||||
processed_rows.append(processed_row)
|
||||
else:
|
||||
processed_rows = [list(row) for row in sqa_sample]
|
||||
return TableData(
|
||||
columns=[column.name for column in sqa_columns],
|
||||
rows=[list(row) for row in sqa_sample],
|
||||
rows=processed_rows,
|
||||
)
|
||||
|
||||
def _fetch_sample_data_from_user_query(self) -> TableData:
|
||||
|
||||
@ -13,9 +13,13 @@
|
||||
Interfaces with database for all database engine
|
||||
supporting sqlalchemy abstraction layer
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, text
|
||||
|
||||
from metadata.ingestion.source.database.databricks.connection import (
|
||||
get_connection as databricks_get_connection,
|
||||
)
|
||||
from metadata.profiler.orm.types.custom_array import CustomArray
|
||||
from metadata.sampler.sqlalchemy.sampler import SQASampler
|
||||
|
||||
|
||||
@ -30,3 +34,27 @@ class UnityCatalogSamplerInterface(SQASampler):
|
||||
client = super().get_client()
|
||||
self.set_catalog(client)
|
||||
return client
|
||||
|
||||
def _handle_array_column(self, column: Column) -> bool:
    """Tell whether ``column`` is declared with the ``CustomArray`` type."""
    column_type = column.type
    return isinstance(column_type, CustomArray)
|
||||
|
||||
def _get_slice_expression(self, column: Column):
    """Build the SELECT expression that truncates an array column in SQL.

    The array is cut down inside the query itself, so each row carries at
    most ``self._get_max_array_elements()`` elements.

    Args:
        column: SQLAlchemy column holding the array values.

    Returns:
        A ``text()`` clause of the form
        ``CASE WHEN `col` IS NULL THEN NULL ELSE slice(`col`, 1, N) END AS `label```.
    """
    # Coerce to int so nothing but a number can reach the SQL text.
    max_elements = int(self._get_max_array_elements())

    # Inside a backtick-delimited identifier a literal backtick is written
    # doubled; without this a column name containing one ends the quoting early.
    quoted_name = "`" + str(column.name).replace("`", "``") + "`"
    quoted_label = (
        "`"
        + str(column._label).replace("`", "``")  # pylint: disable=protected-access
        + "`"
    )

    return text(
        f"""
        CASE
            WHEN {quoted_name} IS NULL THEN NULL
            ELSE slice({quoted_name}, 1, {max_elements})
        END AS {quoted_label}
        """
    )
|
||||
|
||||
@ -0,0 +1,124 @@
|
||||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Test Unity Catalog sampler functionality
|
||||
"""
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Column, Integer, String
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from metadata.generated.schema.entity.data.table import Column as EntityColumn
|
||||
from metadata.generated.schema.entity.data.table import ColumnName, DataType, Table
|
||||
from metadata.generated.schema.entity.services.connections.database.unityCatalogConnection import (
|
||||
UnityCatalogConnection,
|
||||
)
|
||||
from metadata.profiler.orm.types.custom_array import CustomArray
|
||||
from metadata.sampler.models import SampleConfig
|
||||
from metadata.sampler.sqlalchemy.sampler import DEFAULT_MAX_ARRAY_ELEMENTS
|
||||
from metadata.sampler.sqlalchemy.unitycatalog.sampler import (
|
||||
UnityCatalogSamplerInterface,
|
||||
)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class _TestTableModel(Base):
    """Minimal ORM model: an integer key, a string column and an array column."""

    __tablename__ = "test_table"

    # Plain scalar columns, used as the non-array cases in the tests.
    id = Column(Integer, primary_key=True)
    name = Column(String(256))

    # The only column declared with the CustomArray type.
    array_col = Column(CustomArray(String))
|
||||
|
||||
|
||||
class UnityCatalogSamplerTest(TestCase):
    """Test Unity Catalog sampler functionality"""

    def setUp(self):
        """Set up test fixtures"""
        self.table_entity = Table(
            id=uuid4(),
            name="test_table",
            columns=[
                EntityColumn(
                    name=ColumnName("id"),
                    dataType=DataType.INT,
                ),
                EntityColumn(
                    name=ColumnName("name"),
                    dataType=DataType.STRING,
                ),
                EntityColumn(
                    name=ColumnName("array_col"),
                    dataType=DataType.ARRAY,
                ),
            ],
        )

        self.unity_catalog_conn = UnityCatalogConnection(
            hostPort="localhost:443",
            token="test_token",
            httpPath="/sql/1.0/warehouses/test",
            catalog="test_catalog",
        )

    @patch(
        "metadata.sampler.sqlalchemy.unitycatalog.sampler.SQASampler.build_table_orm"
    )
    def test_handle_array_column(self, mock_build_table_orm):
        """Test array column detection"""
        mock_build_table_orm.return_value = _TestTableModel

        sampler = UnityCatalogSamplerInterface(
            service_connection_config=self.unity_catalog_conn,
            ometa_client=None,
            entity=self.table_entity,
            sample_config=SampleConfig(),
        )

        # Test with array column
        array_col = _TestTableModel.__table__.c.array_col
        self.assertTrue(sampler._handle_array_column(array_col))

        # Test with non-array column
        name_col = _TestTableModel.__table__.c.name
        self.assertFalse(sampler._handle_array_column(name_col))

    def test_get_max_array_elements_default(self):
        """Test default max array elements from base class"""
        # Bypass __init__ so no connection is needed; only sample_config is read
        sampler = UnityCatalogSamplerInterface.__new__(UnityCatalogSamplerInterface)
        sampler.sample_config = SampleConfig()

        # Test that it returns the default value from the base class
        self.assertEqual(sampler._get_max_array_elements(), DEFAULT_MAX_ARRAY_ELEMENTS)

    def test_get_slice_expression(self):
        """Test slice expression generation for array columns"""
        # Bypass __init__ so no connection is needed; only sample_config is read
        sampler = UnityCatalogSamplerInterface.__new__(UnityCatalogSamplerInterface)
        sampler.sample_config = SampleConfig()

        # Use the real array column of the test model (not a mock)
        array_col = _TestTableModel.__table__.c.array_col

        # Test the slice expression generation
        expression = sampler._get_slice_expression(array_col)

        # Check that it returns a text() object with the expected SQL
        self.assertIsNotNone(expression)
        sql_str = str(expression.compile(compile_kwargs={"literal_binds": True}))
        self.assertIn("CASE", sql_str)
        self.assertIn("slice", sql_str)
        self.assertIn("array_col", sql_str)

        # NULL arrays must be passed through rather than sliced
        self.assertIn("IS NULL THEN NULL", sql_str)
        # The slice must start at the first element and be capped at the
        # default limit, otherwise the sampler would not bound the array size
        self.assertIn(
            f"slice(`array_col`, 1, {DEFAULT_MAX_ARRAY_ELEMENTS})", sql_str
        )
|
||||
Loading…
x
Reference in New Issue
Block a user