MINOR: row sampling error (#21488)

* fix: row sampling error

* fix: return sample query

(cherry picked from commit 859f24aba7c2b0bcbaf85b149b2a02d60fa4e201)
This commit is contained in:
Teddy 2025-06-02 09:02:17 +02:00 committed by Teddy Crepineau
parent 38724bf2fe
commit 68b0eb34b7
4 changed files with 10 additions and 11 deletions

View File

@@ -15,7 +15,7 @@ for the profiler
from typing import List, Optional
from sqlalchemy import Column, Table, text
from sqlalchemy.orm import Query
from sqlalchemy.sql.selectable import CTE
from metadata.generated.schema.entity.data.table import TableData, TableType
from metadata.sampler.sqlalchemy.sampler import ProfileSampleType, SQASampler
@@ -49,13 +49,12 @@ class AzureSQLSampler(SQASampler):
return selectable
def get_sample_query(self, *, column=None) -> Query:
"""get query for sample data"""
def get_sample_query(self, *, column=None) -> CTE:
"""Override the base method as ROWS or PERCENT sampling handled through the tablesample clause"""
rnd = self._base_sample_query(column).cte(
f"{self.get_sampler_table_name()}_rnd"
)
with self.get_client() as client:
query = client.query(rnd)
query = self.get_client().query(rnd)
return query.cte(f"{self.get_sampler_table_name()}_sample")
def fetch_sample_data(self, columns: Optional[List[Column]] = None) -> TableData:

View File

@@ -44,10 +44,9 @@ class MssqlSampler(SQASampler):
return selectable
def get_sample_query(self, *, column=None) -> CTE:
"""get query for sample data"""
"""Override the base method as ROWS or PERCENT sampling handled through the tablesample clause"""
rnd = self._base_sample_query(column).cte(
f"{self.get_sampler_table_name()}_rnd"
)
with self.get_client() as client:
query = client.query(rnd)
query = self.get_client().query(rnd)
return query.cte(f"{self.get_sampler_table_name()}_sample")

View File

@@ -136,6 +136,8 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
).cte(f"{self.get_sampler_table_name()}_sample")
table_query = client.query(self.raw_dataset)
if self.partition_details:
table_query = self.get_partitioned_query(table_query)
session_query = self._base_sample_query(
column,
(ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL)

View File

@@ -87,10 +87,9 @@ class SnowflakeSampler(SQASampler):
)
def get_sample_query(self, *, column=None) -> CTE:
"""get query for sample data"""
"""Override the base method as ROWS or PERCENT sampling handled through the tablesample clause"""
rnd = self._base_sample_query(column).cte(
f"{self.get_sampler_table_name()}_rnd"
)
with self.get_client() as client:
query = client.query(rnd)
query = self.get_client().query(rnd)
return query.cte(f"{self.get_sampler_table_name()}_sample")