From 1db72c2c1e526b2d8af6ce7362e0ef63cb89a768 Mon Sep 17 00:00:00 2001 From: Imri Paran Date: Tue, 25 Mar 2025 13:48:18 +0100 Subject: [PATCH] MINOR: fix: close client after query (#19711) * fix: close client after query use context clients in SQL sampler to close the connection once the query is complete * use self.context_client in all sql sampler implementations * use sqlalchemy's built-in session management * format * format * use get_client directly --- .../unity_catalog/sampler_interface.py | 7 +- .../src/metadata/sampler/nosql/sampler.py | 4 + .../src/metadata/sampler/pandas/sampler.py | 1 + .../src/metadata/sampler/sampler_interface.py | 1 - .../sampler/sqlalchemy/azuresql/sampler.py | 3 +- .../sampler/sqlalchemy/mssql/sampler.py | 3 +- .../metadata/sampler/sqlalchemy/sampler.py | 80 ++++++++++--------- .../sampler/sqlalchemy/snowflake/sampler.py | 3 +- .../sampler/sqlalchemy/trino/sampler.py | 17 ++-- 9 files changed, 65 insertions(+), 54 deletions(-) diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/sampler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/sampler_interface.py index 12a4ae3eaac..c078fc27451 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/sampler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/unity_catalog/sampler_interface.py @@ -23,7 +23,6 @@ class UnityCatalogSamplerInterface(SQASampler): def get_client(self): """client is the session for SQA""" self.connection = databricks_get_connection(self.service_connection_config) - self.client = super().get_client() - self.set_catalog(self.client) - - return self.client + client = super().get_client() + self.set_catalog(client) + return client diff --git a/ingestion/src/metadata/sampler/nosql/sampler.py b/ingestion/src/metadata/sampler/nosql/sampler.py index 16ed1229f62..3b1ea77d551 100644 --- a/ingestion/src/metadata/sampler/nosql/sampler.py +++ b/ingestion/src/metadata/sampler/nosql/sampler.py @@ -24,6 +24,10 @@ class NoSQLSampler(SamplerInterface): client: NoSQLAdaptor + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = self.get_client() + @property def raw_dataset(self): return self.entity diff --git a/ingestion/src/metadata/sampler/pandas/sampler.py b/ingestion/src/metadata/sampler/pandas/sampler.py index b221837daed..5c869a21e9e 100644 --- a/ingestion/src/metadata/sampler/pandas/sampler.py +++ b/ingestion/src/metadata/sampler/pandas/sampler.py @@ -46,6 +46,7 @@ class DatalakeSampler(SamplerInterface, PandasInterfaceMixin): super().__init__(*args, **kwargs) self.partition_details = cast(PartitionProfilerConfig, self.partition_details) self._table = None + self.client = self.get_client() @property def raw_dataset(self): diff --git a/ingestion/src/metadata/sampler/sampler_interface.py b/ingestion/src/metadata/sampler/sampler_interface.py index fe363816d01..cf7837a530e 100644 --- a/ingestion/src/metadata/sampler/sampler_interface.py +++ b/ingestion/src/metadata/sampler/sampler_interface.py @@ -86,7 +86,6 @@ class SamplerInterface(ABC): self.service_connection_config = service_connection_config self.connection = get_ssl_connection(self.service_connection_config) - self.client = self.get_client() # pylint: disable=too-many-arguments, too-many-locals @classmethod diff --git a/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py index 09974ff8f96..2332ad98612 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py @@ -54,7 +54,8 @@ class AzureSQLSampler(SQASampler): rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - query = self.client.query(rnd) + with self.get_client() as client: + query = client.query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") def fetch_sample_data(self, columns: Optional[List[Column]] = None) -> TableData: diff --git a/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py index 7864dacf374..d59c913ed76 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py @@ -48,5 +48,6 @@ class MssqlSampler(SQASampler): rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - query = self.client.query(rnd) + with self.get_client() as client: + query = client.query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") diff --git a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py index 5129ecc0158..294ee0bcd96 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py @@ -102,13 +102,14 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): selectable = self.set_tablesample(self.raw_dataset.__table__) entity = selectable if column is None else selectable.c.get(column.key) - if label is not None: - query = self.client.query(entity, label) - else: - query = self.client.query(entity) + with self.get_client() as client: + if label is not None: + query = client.query(entity, label) + else: + query = client.query(entity) - if self.partition_details: - query = self.get_partitioned_query(query) + if self.partition_details: + query = self.get_partitioned_query(query) return query def get_sampler_table_name(self) -> str: @@ -129,12 +130,14 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): column, (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL), ).cte(f"{self.get_sampler_table_name()}_rnd") - session_query = self.client.query(rnd) + with self.get_client() as client: + session_query = client.query(rnd) return session_query.where( rnd.c.random <= self.sample_config.profileSample ).cte(f"{self.get_sampler_table_name()}_sample") - table_query = self.client.query(self.raw_dataset) + with self.get_client() as client: + table_query = client.query(self.raw_dataset) session_query = self._base_sample_query( column, (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL) @@ -195,38 +198,38 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): for col in inspect(ds).c if col.name != RANDOM_LABEL and col.name in names ] - - try: - sqa_sample = ( - self.client.query(*sqa_columns) - .select_from(ds) - .limit(self.sample_limit) - .all() + with self.get_client() as client: + try: + sqa_sample = ( + client.query(*sqa_columns) + .select_from(ds) + .limit(self.sample_limit) + .all() + ) + except Exception: + logger.debug( + "Cannot fetch sample data with random sampling. Falling back to 100 rows." + ) + logger.debug(traceback.format_exc()) + sqa_columns = list(inspect(self.raw_dataset).c) + sqa_sample = ( + client.query(*sqa_columns) + .select_from(self.raw_dataset) + .limit(100) + .all() + ) + return TableData( + columns=[column.name for column in sqa_columns], + rows=[list(row) for row in sqa_sample], ) - except Exception: - logger.debug( - "Cannot fetch sample data with random sampling. Falling back to 100 rows." - ) - logger.debug(traceback.format_exc()) - sqa_columns = list(inspect(self.raw_dataset).c) - sqa_sample = ( - self.client.query(*sqa_columns) - .select_from(self.raw_dataset) - .limit(100) - .all() - ) - - return TableData( - columns=[column.name for column in sqa_columns], - rows=[list(row) for row in sqa_sample], - ) def _fetch_sample_data_from_user_query(self) -> TableData: """Returns a table data object using results from query execution""" if not is_safe_sql_query(self.sample_query): raise RuntimeError(f"SQL expression is not safe\n\n{self.sample_query}") - rnd = self.client.execute(f"{self.sample_query}") + with self.get_client() as client: + rnd = client.execute(f"{self.sample_query}") try: columns = [col.name for col in rnd.cursor.description] except AttributeError: @@ -243,10 +246,10 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): stmt = text(f"{self.sample_query}") stmt = stmt.columns(*list(inspect(self.raw_dataset).c)) - - return self.client.query(stmt.subquery()).cte( - f"{self.get_sampler_table_name()}_user_sampled" - ) + with self.get_client() as client: + return client.query(stmt.subquery()).cte( + f"{self.get_sampler_table_name()}_user_sampled" + ) def _partitioned_table(self) -> Query: """Return the Query object for partitioned tables""" @@ -263,7 +266,8 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): ) if query is not None: return query.filter(partition_filter) - return self.client.query(self.raw_dataset).filter(partition_filter) + with self.get_client() as client: + return client.query(self.raw_dataset).filter(partition_filter) def get_columns(self): """get columns from entity""" diff --git a/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py index f27e81e132e..958afd288d6 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py @@ -91,5 +91,6 @@ class SnowflakeSampler(SQASampler): rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - query = self.client.query(rnd) + with self.get_client() as client: + query = client.query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") diff --git a/ingestion/src/metadata/sampler/sqlalchemy/trino/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/trino/sampler.py index 4eb6cc40c51..51d8ce9ddf9 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/trino/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/trino/sampler.py @@ -38,12 +38,13 @@ class TrinoSampler(SQASampler): col for col in inspect(self.raw_dataset).c if col.name != RANDOM_LABEL ] entity = self.raw_dataset if column is None else column - return self.client.query(entity, label).where( - or_( - *[ - text(f'is_nan("{cols.name}") = False') - for cols in sqa_columns - if type(cols.type) in FLOAT_SET - ] + with self.get_client() as client: + return client.query(entity, label).where( + or_( + *[ + text(f'is_nan("{cols.name}") = False') + for cols in sqa_columns + if type(cols.type) in FLOAT_SET + ] + ) ) - )