From 68b0eb34b76a0316c44dcf72b91ffceaefe846b6 Mon Sep 17 00:00:00 2001 From: Teddy Date: Mon, 2 Jun 2025 09:02:17 +0200 Subject: [PATCH] MINOR: row sampling error (#21488) * fix: row sampling error * fix: return sample query (cherry picked from commit 859f24aba7c2b0bcbaf85b149b2a02d60fa4e201) --- .../src/metadata/sampler/sqlalchemy/azuresql/sampler.py | 9 ++++----- .../src/metadata/sampler/sqlalchemy/mssql/sampler.py | 5 ++--- ingestion/src/metadata/sampler/sqlalchemy/sampler.py | 2 ++ .../src/metadata/sampler/sqlalchemy/snowflake/sampler.py | 5 ++--- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py index 70e3ee14486..83a3c4062e9 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py @@ -15,7 +15,7 @@ for the profiler from typing import List, Optional from sqlalchemy import Column, Table, text -from sqlalchemy.orm import Query +from sqlalchemy.sql.selectable import CTE from metadata.generated.schema.entity.data.table import TableData, TableType from metadata.sampler.sqlalchemy.sampler import ProfileSampleType, SQASampler @@ -49,13 +49,12 @@ class AzureSQLSampler(SQASampler): return selectable - def get_sample_query(self, *, column=None) -> Query: - """get query for sample data""" + def get_sample_query(self, *, column=None) -> CTE: + """Override the base method as ROWS or PERCENT sampling handled through the tablesample clause""" rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - with self.get_client() as client: - query = client.query(rnd) + query = self.get_client().query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") def fetch_sample_data(self, columns: Optional[List[Column]] = None) -> TableData: diff --git a/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py index a1e2fb93609..349ce36fee8 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py @@ -44,10 +44,9 @@ class MssqlSampler(SQASampler): return selectable def get_sample_query(self, *, column=None) -> CTE: - """get query for sample data""" + """Override the base method as ROWS or PERCENT sampling handled through the tablesample clause""" rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - with self.get_client() as client: - query = client.query(rnd) + query = self.get_client().query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") diff --git a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py index 27d51398598..69dbd2a75b1 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py @@ -136,6 +136,8 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): ).cte(f"{self.get_sampler_table_name()}_sample") table_query = client.query(self.raw_dataset) + if self.partition_details: + table_query = self.get_partitioned_query(table_query) session_query = self._base_sample_query( column, (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL) diff --git a/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py index 274d0c0ce4f..ea64bb9e3f8 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py @@ -87,10 +87,9 @@ class SnowflakeSampler(SQASampler): ) def get_sample_query(self, *, column=None) -> CTE: - """get query for sample data""" + """Override the base method as ROWS or PERCENT sampling handled through the tablesample clause""" rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - with self.get_client() as client: - query = client.query(rnd) + query = self.get_client().query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample")