diff --git a/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py index 70e3ee14486..83a3c4062e9 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/azuresql/sampler.py @@ -15,7 +15,7 @@ for the profiler from typing import List, Optional from sqlalchemy import Column, Table, text -from sqlalchemy.orm import Query +from sqlalchemy.sql.selectable import CTE from metadata.generated.schema.entity.data.table import TableData, TableType from metadata.sampler.sqlalchemy.sampler import ProfileSampleType, SQASampler @@ -49,13 +49,12 @@ class AzureSQLSampler(SQASampler): return selectable - def get_sample_query(self, *, column=None) -> Query: - """get query for sample data""" + def get_sample_query(self, *, column=None) -> CTE: + """Override the base method as ROWS or PERCENT sampling handled through the tablesample clause""" rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - with self.get_client() as client: - query = client.query(rnd) + query = self.get_client().query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") def fetch_sample_data(self, columns: Optional[List[Column]] = None) -> TableData: diff --git a/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py index a1e2fb93609..349ce36fee8 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/mssql/sampler.py @@ -44,10 +44,9 @@ class MssqlSampler(SQASampler): return selectable def get_sample_query(self, *, column=None) -> CTE: - """get query for sample data""" + """Override the base method as ROWS or PERCENT sampling handled through the tablesample clause""" rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - with self.get_client() as client: - query = client.query(rnd) + query = self.get_client().query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample") diff --git a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py index 27d51398598..69dbd2a75b1 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py @@ -136,6 +136,8 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): ).cte(f"{self.get_sampler_table_name()}_sample") table_query = client.query(self.raw_dataset) + if self.partition_details: + table_query = self.get_partitioned_query(table_query) session_query = self._base_sample_query( column, (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL) diff --git a/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py index 274d0c0ce4f..ea64bb9e3f8 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/snowflake/sampler.py @@ -87,10 +87,9 @@ class SnowflakeSampler(SQASampler): ) def get_sample_query(self, *, column=None) -> CTE: - """get query for sample data""" + """Override the base method as ROWS or PERCENT sampling handled through the tablesample clause""" rnd = self._base_sample_query(column).cte( f"{self.get_sampler_table_name()}_rnd" ) - with self.get_client() as client: - query = client.query(rnd) + query = self.get_client().query(rnd) return query.cte(f"{self.get_sampler_table_name()}_sample")