diff --git a/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py index e2481f3dd35..c05223acb7f 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/databricks/sampler.py @@ -11,9 +11,12 @@ """ Helper module to handle data sampling for the profiler """ +from sqlalchemy import Column, text + from metadata.ingestion.source.database.databricks.connection import ( get_connection as databricks_get_connection, ) +from metadata.profiler.orm.types.custom_array import CustomArray from metadata.sampler.sqlalchemy.sampler import SQASampler @@ -28,3 +31,27 @@ class DatabricksSamplerInterface(SQASampler): client = super().get_client() self.set_catalog(client) return client + + def _handle_array_column(self, column: Column) -> bool: + """Check if a column is an array type""" + return isinstance(column.type, CustomArray) + + def _get_slice_expression(self, column: Column): + """Generate SQL expression to slice array elements at query level + + Args: + column_name: Name of the column + max_elements: Maximum number of elements to extract + + Returns: + SQL expression string for array slicing + """ + max_elements = self._get_max_array_elements() + return text( + f""" + CASE + WHEN `{column.name}` IS NULL THEN NULL + ELSE slice(`{column.name}`, 1, {max_elements}) + END AS `{column._label}` + """ + ) diff --git a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py index 20107522f50..b16c02ecfaa 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py @@ -41,6 +41,9 @@ logger = profiler_interface_registry_logger() RANDOM_LABEL = "random" +# Default maximum number of elements to extract from array columns to prevent OOM +DEFAULT_MAX_ARRAY_ELEMENTS = 10 + def _object_value_for_elem(self, elem): """ @@ -103,6 +106,36 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): """ return selectable + def _get_max_array_elements(self) -> int: + """Get the maximum number of array elements from config or use default""" + if ( + self.sample_config + and hasattr(self.sample_config, "maxArrayElements") + and self.sample_config.maxArrayElements + ): + return self.sample_config.maxArrayElements + return DEFAULT_MAX_ARRAY_ELEMENTS + + def _handle_array_column(self, column: Column) -> bool: + """Check if a column is an array type""" + # Implement this method in the child classes + return False + + def _process_array_value(self, value): + """Process array values to convert numpy arrays to Python lists""" + import numpy as np # pylint: disable=import-outside-toplevel + + if isinstance(value, np.ndarray): + return value.tolist() + return value + + def _get_slice_expression(self, column: Column): + """Generate SQL expression to slice array elements at query level + By default, we return the column as is. + Child classes can override this method to return a different expression. + """ + return column + def _base_sample_query(self, column: Optional[Column], label=None): """Base query for sampling @@ -215,16 +248,49 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): ] with self.get_client_context() as client: + + # Handle array columns with special query modification + max_elements = self._get_max_array_elements() + select_columns = [] + has_array_columns = False + + for col in sqa_columns: + if self._handle_array_column(col): + slice_expression = self._get_slice_expression(col) + select_columns.append(slice_expression) + logger.debug( + f"Limiting array column {col.name} to {max_elements} elements to prevent OOM" + ) + has_array_columns = True + else: + select_columns.append(col) + + # Create query with modified columns sqa_sample = ( - client.query(*sqa_columns) + client.query(*select_columns) .select_from(ds) .limit(self.sample_limit) .all() ) + # Process array columns manually if we used text() expressions + processed_rows = [] + if has_array_columns: + for row in sqa_sample: + processed_row = [] + for i, col in enumerate(sqa_columns): + value = row[i] + if self._handle_array_column(col): + processed_value = self._process_array_value(value) + processed_row.append(processed_value) + else: + processed_row.append(value) + processed_rows.append(processed_row) + else: + processed_rows = [list(row) for row in sqa_sample] return TableData( columns=[column.name for column in sqa_columns], - rows=[list(row) for row in sqa_sample], + rows=processed_rows, ) def _fetch_sample_data_from_user_query(self) -> TableData: diff --git a/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py index ba1dd151e11..0ccb44c4ae4 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/unitycatalog/sampler.py @@ -13,9 +13,13 @@ Interfaces with database for all database engine supporting sqlalchemy abstraction layer """ + +from sqlalchemy import Column, text + from metadata.ingestion.source.database.databricks.connection import ( get_connection as databricks_get_connection, ) +from metadata.profiler.orm.types.custom_array import CustomArray from metadata.sampler.sqlalchemy.sampler import SQASampler @@ -30,3 +34,27 @@ class UnityCatalogSamplerInterface(SQASampler): client = super().get_client() self.set_catalog(client) return client + + def _handle_array_column(self, column: Column) -> bool: + """Check if a column is an array type""" + return isinstance(column.type, CustomArray) + + def _get_slice_expression(self, column: Column): + """Generate SQL expression to slice array elements at query level + + Args: + column_name: Name of the column + max_elements: Maximum number of elements to extract + + Returns: + SQL expression string for array slicing + """ + max_elements = self._get_max_array_elements() + return text( + f""" + CASE + WHEN `{column.name}` IS NULL THEN NULL + ELSE slice(`{column.name}`, 1, {max_elements}) + END AS `{column._label}` + """ + ) diff --git a/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py b/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py new file mode 100644 index 00000000000..68cd669ef52 --- /dev/null +++ b/ingestion/tests/unit/sampler/sqlalchemy/test_unitycatalog_sampler.py @@ -0,0 +1,124 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test Unity Catalog sampler functionality +""" +from unittest import TestCase +from unittest.mock import patch +from uuid import uuid4 + +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import declarative_base + +from metadata.generated.schema.entity.data.table import Column as EntityColumn +from metadata.generated.schema.entity.data.table import ColumnName, DataType, Table +from metadata.generated.schema.entity.services.connections.database.unityCatalogConnection import ( + UnityCatalogConnection, +) +from metadata.profiler.orm.types.custom_array import CustomArray +from metadata.sampler.models import SampleConfig +from metadata.sampler.sqlalchemy.sampler import DEFAULT_MAX_ARRAY_ELEMENTS +from metadata.sampler.sqlalchemy.unitycatalog.sampler import ( + UnityCatalogSamplerInterface, +) + +Base = declarative_base() + + +class _TestTableModel(Base): + __tablename__ = "test_table" + id = Column(Integer, primary_key=True) + name = Column(String(256)) + array_col = Column(CustomArray(String)) + + +class UnityCatalogSamplerTest(TestCase): + """Test Unity Catalog sampler functionality""" + + def setUp(self): + """Set up test fixtures""" + self.table_entity = Table( + id=uuid4(), + name="test_table", + columns=[ + EntityColumn( + name=ColumnName("id"), + dataType=DataType.INT, + ), + EntityColumn( + name=ColumnName("name"), + dataType=DataType.STRING, + ), + EntityColumn( + name=ColumnName("array_col"), + dataType=DataType.ARRAY, + ), + ], + ) + + self.unity_catalog_conn = UnityCatalogConnection( + hostPort="localhost:443", + token="test_token", + httpPath="/sql/1.0/warehouses/test", + catalog="test_catalog", + ) + + @patch( + "metadata.sampler.sqlalchemy.unitycatalog.sampler.SQASampler.build_table_orm" + ) + def test_handle_array_column(self, mock_build_table_orm): + """Test array column detection""" + mock_build_table_orm.return_value = _TestTableModel + + sampler = UnityCatalogSamplerInterface( + service_connection_config=self.unity_catalog_conn, + ometa_client=None, + entity=self.table_entity, + sample_config=SampleConfig(), + ) + + # Test with array column + array_col = _TestTableModel.__table__.c.array_col + self.assertTrue(sampler._handle_array_column(array_col)) + + # Test with non-array column + name_col = _TestTableModel.__table__.c.name + self.assertFalse(sampler._handle_array_column(name_col)) + + def test_get_max_array_elements_default(self): + """Test default max array elements from base class""" + # Create a minimal sampler instance to test the method + sampler = UnityCatalogSamplerInterface.__new__(UnityCatalogSamplerInterface) + sampler.sample_config = SampleConfig() + + # Test that it returns the default value from the base class + self.assertEqual(sampler._get_max_array_elements(), DEFAULT_MAX_ARRAY_ELEMENTS) + + def test_get_slice_expression(self): + """Test slice expression generation for array columns""" + # Create a minimal sampler instance to test the method + sampler = UnityCatalogSamplerInterface.__new__(UnityCatalogSamplerInterface) + sampler.sample_config = SampleConfig() + + # Create a mock column + array_col = _TestTableModel.__table__.c.array_col + + # Test the slice expression generation + expression = sampler._get_slice_expression(array_col) + + # Check that it returns a text() object with the expected SQL + self.assertIsNotNone(expression) + # The expression should contain the expected SQL pattern + sql_str = str(expression.compile(compile_kwargs={"literal_binds": True})) + self.assertIn("CASE", sql_str) + self.assertIn("slice", sql_str) + self.assertIn("array_col", sql_str)