MINOR - generic profiler optimization for sampling and BQ (#14507)

* fix: limit sampling to specific column

* fix: handle bigquery struct columns

* fix: default partition to 1 DAY for BQ

* fix: default to __TABLES__ for BQ table metrics

* style: ran python linting

* style: fix linting

* fix: python style

* fix: set partition to DAY if not HOUR
Author: Teddy, 2023-12-27 19:13:44 +01:00, committed by GitHub
parent 3dc642989c
commit 61ef55290e
10 changed files with 125 additions and 44 deletions


@@ -32,6 +32,9 @@ class BigQueryProfilerInterface(SQAProfilerInterface):
         for key, value in columns.items():
             if not isinstance(value, STRUCT):
                 col = Column(f"{parent}.{key}", value)
+                # pylint: disable=protected-access
+                col._set_parent(self.table.__table__)
+                # pylint: enable=protected-access
                 columns_list.append(col)
             else:
                 col = self._get_struct_columns(
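The `_set_parent` call above attaches the freshly built struct sub-column to the table construct. A minimal sketch (not the project's code, assuming SQLAlchemy 1.4's internal `Column._set_parent(table)` API) of why a detached Column cannot otherwise be compiled against the table:

from sqlalchemy import Column, Integer, MetaData, Table, select

metadata = MetaData()
users = Table("users", metadata, Column("id", Integer))

# A Column created outside the Table definition is detached: it has no parent
# table, so a SELECT built from it cannot be qualified with a FROM clause.
loose = Column("profile.age", Integer)

# Mirroring col._set_parent(self.table.__table__) in the hunk above:
loose._set_parent(users)  # pylint: disable=protected-access
print(select(loose))      # now compiles as a SELECT against the users table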


@@ -413,6 +413,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
                 partition_details=self.partition_details,
                 profile_sample_query=self.profile_query,
             )
+            return thread_local.runner
+        thread_local.runner._sample = sample  # pylint: disable=protected-access
         return thread_local.runner
 
     def compute_metrics_in_thread(
@@ -431,7 +433,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             session,
             metric_func.table,
         )
-        sample = sampler.random_sample()
+        sample = sampler.random_sample(metric_func.column)
         runner = self._create_thread_safe_runner(
             session,
             metric_func.table,
@@ -565,7 +567,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             dictionary of results
         """
         sampler = self._get_sampler(table=kwargs.get("table"))
-        sample = sampler.random_sample()
+        sample = sampler.random_sample(column)
         try:
             return metric(column).fn(sample, column_results, self.session)
         except Exception as exc:
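The first hunk above also fixes a subtle caching bug: the runner is cached per thread, so when it is reused its sample must be swapped for the one just computed. A stripped-down sketch of the pattern (hypothetical names, not the project's code):

import threading

thread_local = threading.local()

class Runner:
    """Stand-in for the cached query runner."""
    def __init__(self, sample):
        self._sample = sample

def thread_safe_runner(sample):
    # First call in this thread: build and cache the runner.
    if not hasattr(thread_local, "runner"):
        thread_local.runner = Runner(sample)
        return thread_local.runner
    # Cached runner: refresh the sample, otherwise later metrics in the same
    # thread would run against a stale (possibly wrong-column) sample.
    thread_local.runner._sample = sample
    return thread_local.runner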


@@ -173,34 +173,64 @@ def bigquery_table_construct(runner: QueryRunner, **kwargs):
     Args:
         runner (QueryRunner): query runner object
     """
+    conn_config = kwargs.get("conn_config")
+    conn_config = cast(BigQueryConnection, conn_config)
     try:
         schema_name, table_name = _get_table_and_schema_name(runner.table)
+        project_id = conn_config.credentials.gcpConfig.projectId.__root__
     except AttributeError:
         raise AttributeError(ERROR_MSG)
-    conn_config = kwargs.get("conn_config")
-    conn_config = cast(BigQueryConnection, conn_config)
-    table_storage = _build_table(
-        "TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
-    )
     col_names, col_count = _get_col_names_and_count(runner.table)
-    columns = [
-        Column("total_rows").label("rowCount"),
-        Column("total_logical_bytes").label("sizeInBytes"),
-        Column("creation_time").label("createDateTime"),
-        col_names,
-        col_count,
-    ]
-    where_clause = [
-        Column("table_schema") == schema_name,
-        Column("table_name") == table_name,
-    ]
-    query = _build_query(columns, table_storage, where_clause)
-    return runner._session.execute(query).first()
+
+    def table_storage():
+        """Fall back method if retrieving table metadata from `__TABLES__` fails"""
+        table_storage = _build_table(
+            "TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
+        )
+        columns = [
+            Column("total_rows").label("rowCount"),
+            Column("total_logical_bytes").label("sizeInBytes"),
+            Column("creation_time").label("createDateTime"),
+            col_names,
+            col_count,
+        ]
+        where_clause = [
+            Column("project_id") == project_id,
+            Column("table_schema") == schema_name,
+            Column("table_name") == table_name,
+        ]
+        query = _build_query(columns, table_storage, where_clause)
+        return runner._session.execute(query).first()
+
+    def tables():
+        """retrieve table metadata from `__TABLES__`"""
+        table_meta = _build_table("__TABLES__", f"{project_id}.{schema_name}")
+        columns = [
+            Column("row_count").label("rowCount"),
+            Column("size_bytes").label("sizeInBytes"),
+            Column("creation_time").label("createDateTime"),
+            col_names,
+            col_count,
+        ]
+        where_clause = [
+            Column("project_id") == project_id,
+            Column("dataset_id") == schema_name,
+            Column("table_id") == table_name,
+        ]
+        query = _build_query(columns, table_meta, where_clause)
+        return runner._session.execute(query).first()
+
+    try:
+        return tables()
+    except Exception as exc:
+        logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
+        return table_storage()
 
 
 def clickhouse_table_construct(runner: QueryRunner, **kwargs):
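For reference, the two metadata sources roughly correspond to the following SQL (illustrative only, with placeholder identifiers): `__TABLES__` is BigQuery's legacy per-dataset meta-table, while `TABLE_STORAGE` is the region-scoped INFORMATION_SCHEMA view used as the fallback.

# Illustrative SQL only; project/dataset/table names are placeholders.
TABLES_QUERY = """
SELECT row_count AS rowCount,
       size_bytes AS sizeInBytes,
       creation_time AS createDateTime
FROM `my-project.my_dataset.__TABLES__`
WHERE project_id = 'my-project'
  AND dataset_id = 'my_dataset'
  AND table_id = 'my_table'
"""

# Fallback: requires knowing the dataset's region up front.
TABLE_STORAGE_QUERY = """
SELECT total_rows AS rowCount,
       total_logical_bytes AS sizeInBytes,
       creation_time AS createDateTime
FROM `region-us`.INFORMATION_SCHEMA.TABLE_STORAGE
WHERE project_id = 'my-project'
  AND table_schema = 'my_dataset'
  AND table_name = 'my_table'
"""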


@@ -112,8 +112,8 @@ class partition_filter_handler:
             )
             if self.build_sample:
                 return (
-                    _self.client.query(
-                        _self.table,
+                    _self._base_sample_query(
+                        kwargs.get("column"),
                         (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
                     )
                     .filter(partition_filter)
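In effect, the decorator now composes the partition filter on top of the column-restricted base query rather than on a query over the whole table. A condensed, hypothetical view of that composition (names invented for illustration):

# Hypothetical condensation of the decorator body above: the partition filter
# is layered on whatever _base_sample_query selects (one column or the table).
def partitioned_sample(sampler, column, random_label, partition_filter):
    return sampler._base_sample_query(column, random_label).filter(partition_filter)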


@@ -14,6 +14,7 @@ for the profiler
 """
 from typing import Dict, Optional
 
+from sqlalchemy import Column
 from sqlalchemy.orm import Query
 
 from metadata.generated.schema.entity.data.table import ProfileSampleType, TableType
@@ -50,8 +51,35 @@ class BigQuerySampler(SQASampler):
         )
         self.table_type: TableType = table_type
 
+    def _base_sample_query(self, column: Optional[Column], label=None):
+        """Base query for sampling
+
+        Args:
+            column (Optional[Column]): if computing a column metric, only sample that column
+            label (optional): extra labeled column to include in the query
+
+        Returns:
+        """
+        # pylint: disable=import-outside-toplevel
+        from sqlalchemy_bigquery import STRUCT
+
+        if column is not None:
+            column_parts = column.name.split(".")
+            if len(column_parts) > 1:
+                # for struct columns (e.g. `foo.bar`) we need to create a new column
+                # corresponding to the struct (e.g. `foo`) and use it in the sample
+                # query, as the column that will be queried is `foo.bar`,
+                # e.g. WITH sample AS (SELECT `foo` FROM table)
+                # SELECT `foo.bar` FROM sample TABLESAMPLE SYSTEM (n PERCENT)
+                column = Column(column_parts[0], STRUCT)
+                # pylint: disable=protected-access
+                column._set_parent(self.table.__table__)
+                # pylint: enable=protected-access
+
+        return super()._base_sample_query(column, label=label)
+
     @partition_filter_handler(build_sample=True)
-    def get_sample_query(self) -> Query:
+    def get_sample_query(self, *, column=None) -> Query:
         """get query for sample data"""
         # TABLESAMPLE SYSTEM is not supported for views
         if (
@@ -59,11 +87,11 @@
             and self.table_type != TableType.View
         ):
             return (
-                self._base_sample_query()
+                self._base_sample_query(column)
                 .suffix_with(
                     f"TABLESAMPLE SYSTEM ({self.profile_sample or 100} PERCENT)",
                 )
                 .cte(f"{self.table.__tablename__}_sample")
            )
-        return super().get_sample_query()
+        return super().get_sample_query(column=column)
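On a BigQuery base table the override therefore produces a sampling CTE along these lines (illustrative identifiers; BigQuery's TABLESAMPLE SYSTEM samples storage blocks, so only roughly the requested percentage of the table is scanned, and it is not supported on views, hence the TableType check):

# Illustrative shape of the generated query for a struct sub-column `foo.bar`
# at a 20 percent profile sample; identifiers are placeholders.
STRUCT_SAMPLE_SQL = """
WITH my_table_sample AS (
    SELECT `foo`  -- only the top-level struct is selected into the sample
    FROM `my-project.my_dataset.my_table` TABLESAMPLE SYSTEM (20 PERCENT)
)
SELECT `foo`.`bar`
FROM my_table_sample
"""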


@@ -68,17 +68,28 @@ class SQASampler(SamplerInterface):
     run the query in the whole table.
     """
 
-    def _base_sample_query(self, label=None):
+    def _base_sample_query(self, column: Optional[Column], label=None):
+        """Base query for sampling
+
+        Args:
+            column (Optional[Column]): if computing a column metric, only sample that column
+            label (optional): extra labeled column to include in the query
+
+        Returns:
+        """
+        # only sample the column if we are computing a column metric to limit the amount of data scanned
+        entity = self.table if column is None else column
         if label is not None:
-            return self.client.query(self.table, label)
-        return self.client.query(self.table)
+            return self.client.query(entity, label)
+        return self.client.query(entity)
 
     @partition_filter_handler(build_sample=True)
-    def get_sample_query(self) -> Query:
+    def get_sample_query(self, *, column=None) -> Query:
         """get query for sample data"""
         if self.profile_sample_type == ProfileSampleType.PERCENTAGE:
             rnd = (
                 self._base_sample_query(
+                    column,
                     (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
                 )
                 .suffix_with(
@@ -94,6 +105,7 @@ class SQASampler(SamplerInterface):
         table_query = self.client.query(self.table)
         session_query = self._base_sample_query(
+            column,
             (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL),
         )
         return (
@@ -102,7 +114,7 @@ class SQASampler(SamplerInterface):
             .cte(f"{self.table.__tablename__}_rnd")
         )
 
-    def random_sample(self) -> Union[DeclarativeMeta, AliasedClass]:
+    def random_sample(self, column=None) -> Union[DeclarativeMeta, AliasedClass]:
         """
         Either return a sampled CTE of table, or
         the full table if no sampling is required.
@@ -117,7 +129,7 @@ class SQASampler(SamplerInterface):
             return self.table
 
         # Add new RandomNumFn column
-        sampled = self.get_sample_query()
+        sampled = self.get_sample_query(column=column)
 
         # Assign as an alias
         return aliased(self.table, sampled)
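For percentage sampling, the base query tags every row with a pseudo-random bucket and only the buckets below the requested percentage are kept. In rough SQL terms (illustrative only; the random and modulo spellings are dialect-specific, and with a column restriction only that column is selected into the CTE):

# Illustrative SQL for ProfileSampleType.PERCENTAGE with a 20 percent sample.
PERCENTAGE_SAMPLE_SQL = """
WITH customers_rnd AS (
    SELECT *, MOD(ABS(RANDOM()), 100) AS random  -- the RANDOM_LABEL column
    FROM customers
)
SELECT * FROM customers_rnd
WHERE random <= 20
"""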


@@ -33,9 +33,10 @@ class TrinoSampler(SQASampler):
         super().__init__(*args, **kwargs)
 
-    def _base_sample_query(self, label=None):
+    def _base_sample_query(self, column, label=None):
         sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL]
-        return self.client.query(self.table, label).where(
+        entity = self.table if column is None else column
+        return self.client.query(entity, label).where(
             or_(
                 *[
                     text(f"is_nan({cols.name}) = False")


@@ -15,6 +15,7 @@ from typing import Optional
 
 from metadata.generated.schema.entity.data.table import (
     IntervalType,
+    PartitionIntervalUnit,
     PartitionProfilerConfig,
     Table,
 )
@@ -47,8 +48,10 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
         return PartitionProfilerConfig(
             enablePartitioning=True,
             partitionColumnName=entity.tablePartition.columns[0],
-            partitionIntervalUnit=entity.tablePartition.interval,
-            partitionInterval=30,
+            partitionIntervalUnit=PartitionIntervalUnit.DAY
+            if entity.tablePartition.interval != "HOUR"
+            else entity.tablePartition.interval,
+            partitionInterval=1,
             partitionIntervalType=entity.tablePartition.intervalType.value,
             partitionValues=None,
             partitionIntegerRangeStart=None,
@@ -60,8 +63,10 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
             partitionColumnName="_PARTITIONDATE"
             if entity.tablePartition.interval == "DAY"
             else "_PARTITIONTIME",
-            partitionIntervalUnit=entity.tablePartition.interval,
-            partitionInterval=30,
+            partitionIntervalUnit=PartitionIntervalUnit.DAY
+            if entity.tablePartition.interval != "HOUR"
+            else entity.tablePartition.interval,
+            partitionInterval=1,
             partitionIntervalType=entity.tablePartition.intervalType.value,
             partitionValues=None,
             partitionIntegerRangeStart=None,
@@ -72,7 +77,7 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
             enablePartitioning=True,
             partitionColumnName=entity.tablePartition.columns[0],
             partitionIntervalUnit=None,
-            partitionInterval=30,
+            partitionInterval=None,
             partitionIntervalType=entity.tablePartition.intervalType.value,
             partitionValues=None,
             partitionIntegerRangeStart=1,
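The new defaulting rule, isolated: any time-unit or ingestion-time partition that is not hourly is profiled over a 1 DAY window, while hourly partitions keep HOUR. A minimal sketch (hypothetical helper, not part of the diff) mirroring the conditional in the hunks above:

def default_partition_unit(interval: str) -> str:
    # BigQuery time-unit partitions: HOUR stays HOUR, everything else becomes DAY.
    return interval if interval == "HOUR" else "DAY"

assert default_partition_unit("HOUR") == "HOUR"
assert default_partition_unit("DAY") == "DAY"
assert default_partition_unit("MONTH") == "DAY"  # coarser units collapse to DAY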


@@ -152,7 +152,7 @@ class ProfilerPartitionUnitTest(TestCase):
         if resp:
             assert resp.partitionColumnName == "e"
-            assert resp.partitionInterval == 30
+            assert resp.partitionInterval == 1
             assert not resp.partitionValues
         else:
             assert False
@@ -187,7 +187,7 @@ class ProfilerPartitionUnitTest(TestCase):
         if resp:
             assert resp.partitionColumnName == "_PARTITIONDATE"
-            assert resp.partitionInterval == 30
+            assert resp.partitionInterval == 1
             assert not resp.partitionValues
         else:
             assert False
@@ -221,7 +221,7 @@ class ProfilerPartitionUnitTest(TestCase):
         if resp:
             assert resp.partitionColumnName == "_PARTITIONTIME"
-            assert resp.partitionInterval == 30
+            assert resp.partitionInterval == 1
             assert not resp.partitionValues
         else:
             assert False


@@ -82,7 +82,7 @@ def test_get_partition_details():
     assert partition.enablePartitioning == True
     assert partition.partitionColumnName == "_PARTITIONTIME"
     assert partition.partitionIntervalType == PartitionIntervalType.INGESTION_TIME
-    assert partition.partitionInterval == 30
+    assert partition.partitionInterval == 1
     assert partition.partitionIntervalUnit == PartitionIntervalUnit.HOUR
 
     table_entity = MockTable(
@@ -97,5 +97,5 @@ def test_get_partition_details():
     assert partition.enablePartitioning == True
     assert partition.partitionColumnName == "_PARTITIONDATE"
     assert partition.partitionIntervalType == PartitionIntervalType.INGESTION_TIME
-    assert partition.partitionInterval == 30
+    assert partition.partitionInterval == 1
     assert partition.partitionIntervalUnit == PartitionIntervalUnit.DAY