Mirror of https://github.com/open-metadata/OpenMetadata.git
synced 2026-01-04 11:27:10 +00:00
MINOR - generic profiler optimization for sampling and BQ (#14507)
* fix: limit sampling to specific column
* fix: handle bigquery struct columns
* fix: default partition to 1 DAY for BQ
* fix: default to __TABLES__ for BQ table metrics
* style: ran python linting
* style: fix linting
* fix: python style
* fix: set partition to DAY if not HOUR
This commit is contained in:
parent 3dc642989c
commit 61ef55290e
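At its core, the optimization threads the metric's target column through the sampler so that column metrics select only that column instead of the whole table, which cuts the bytes scanned on BigQuery. A minimal sketch of the idea in plain SQLAlchemy (hypothetical `Users` model and `base_sample_query` helper, not OpenMetadata's actual API; `random() % 100` stands in for the dialect-aware random functions used in the real code):

    from sqlalchemy import Column, Integer, String, func, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Users(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String)
        age = Column(Integer)

    def base_sample_query(column=None):
        """Select only the target column when one is given."""
        entity = Users if column is None else column
        # label each row with a pseudo-random bucket in [0, 100)
        rnd = (func.random() % 100).label("random")
        return select(entity, rnd)

    # full-table sample (old behaviour) vs. single-column sample (new behaviour)
    print(base_sample_query())           # SELECT users.id, users.name, users.age, ...
    print(base_sample_query(Users.age))  # SELECT users.age, ...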
@@ -32,6 +32,9 @@ class BigQueryProfilerInterface(SQAProfilerInterface):
         for key, value in columns.items():
             if not isinstance(value, STRUCT):
                 col = Column(f"{parent}.{key}", value)
+                # pylint: disable=protected-access
+                col._set_parent(self.table.__table__)
+                # pylint: enable=protected-access
                 columns_list.append(col)
             else:
                 col = self._get_struct_columns(
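The hunk above parents each flattened struct leaf (`parent.key`) to the table so SQLAlchemy can compile it against the right FROM clause. A hedged sketch of the flattening idea as a stand-alone helper (hypothetical `flatten_struct`, using plain dicts in place of sqlalchemy-bigquery STRUCT types):

    from typing import Dict, List

    def flatten_struct(fields: Dict[str, object], parent: str) -> List[str]:
        """Return dotted leaf names such as 'address.city' for nested dicts."""
        leaves: List[str] = []
        for key, value in fields.items():
            name = f"{parent}.{key}"
            if isinstance(value, dict):  # nested STRUCT: recurse
                leaves.extend(flatten_struct(value, name))
            else:                        # leaf type: keep it
                leaves.append(name)
        return leaves

    print(flatten_struct({"city": "STRING", "geo": {"lat": "FLOAT"}}, "address"))
    # ['address.city', 'address.geo.lat']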
@@ -413,6 +413,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
                 partition_details=self.partition_details,
                 profile_sample_query=self.profile_query,
             )
+            return thread_local.runner
+        thread_local.runner._sample = sample  # pylint: disable=protected-access
         return thread_local.runner
 
     def compute_metrics_in_thread(
@@ -431,7 +433,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             session,
             metric_func.table,
         )
-        sample = sampler.random_sample()
+        sample = sampler.random_sample(metric_func.column)
         runner = self._create_thread_safe_runner(
             session,
             metric_func.table,
@@ -565,7 +567,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             dictionary of results
         """
         sampler = self._get_sampler(table=kwargs.get("table"))
-        sample = sampler.random_sample()
+        sample = sampler.random_sample(column)
         try:
             return metric(column).fn(sample, column_results, self.session)
         except Exception as exc:
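The three hunks above cache one QueryRunner per worker thread and, on reuse, only swap in the fresh sample instead of rebuilding the runner. A minimal sketch of that thread-local pattern (hypothetical `Runner` class, not the real QueryRunner):

    import threading

    thread_local = threading.local()

    class Runner:
        def __init__(self, sample):
            self._sample = sample

    def get_runner(sample):
        if not hasattr(thread_local, "runner"):
            thread_local.runner = Runner(sample)
            return thread_local.runner
        thread_local.runner._sample = sample  # reuse, just update the sample
        return thread_local.runner

    r1 = get_runner("sample-a")
    r2 = get_runner("sample-b")
    assert r1 is r2 and r2._sample == "sample-b"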
@@ -173,34 +173,64 @@ def bigquery_table_construct(runner: QueryRunner, **kwargs):
     Args:
         runner (QueryRunner): query runner object
     """
+    conn_config = kwargs.get("conn_config")
+    conn_config = cast(BigQueryConnection, conn_config)
+    try:
+        schema_name, table_name = _get_table_and_schema_name(runner.table)
+        project_id = conn_config.credentials.gcpConfig.projectId.__root__
+    except AttributeError:
+        raise AttributeError(ERROR_MSG)
+
-    conn_config = kwargs.get("conn_config")
-    conn_config = cast(BigQueryConnection, conn_config)
-
-    table_storage = _build_table(
-        "TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
-    )
     col_names, col_count = _get_col_names_and_count(runner.table)
-    columns = [
-        Column("total_rows").label("rowCount"),
-        Column("total_logical_bytes").label("sizeInBytes"),
-        Column("creation_time").label("createDateTime"),
-        col_names,
-        col_count,
-    ]
-
-    where_clause = [
-        Column("table_schema") == schema_name,
-        Column("table_name") == table_name,
-    ]
+    def table_storage():
+        """Fall back method if retrieving table metadata from `__TABLES__` fails"""
+        table_storage = _build_table(
+            "TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
+        )
 
-    query = _build_query(columns, table_storage, where_clause)
+        columns = [
+            Column("total_rows").label("rowCount"),
+            Column("total_logical_bytes").label("sizeInBytes"),
+            Column("creation_time").label("createDateTime"),
+            col_names,
+            col_count,
+        ]
 
-    return runner._session.execute(query).first()
+        where_clause = [
+            Column("project_id") == project_id,
+            Column("table_schema") == schema_name,
+            Column("table_name") == table_name,
+        ]
+
+        query = _build_query(columns, table_storage, where_clause)
+
+        return runner._session.execute(query).first()
+
+    def tables():
+        """Retrieve table metadata from `__TABLES__`"""
+        table_meta = _build_table("__TABLES__", f"{project_id}.{schema_name}")
+        columns = [
+            Column("row_count").label("rowCount"),
+            Column("size_bytes").label("sizeInBytes"),
+            Column("creation_time").label("createDateTime"),
+            col_names,
+            col_count,
+        ]
+        where_clause = [
+            Column("project_id") == project_id,
+            Column("dataset_id") == schema_name,
+            Column("table_id") == table_name,
+        ]
+
+        query = _build_query(columns, table_meta, where_clause)
+        return runner._session.execute(query).first()
+
+    try:
+        return tables()
+    except Exception as exc:
+        logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
+        return table_storage()
 
 
 def clickhouse_table_construct(runner: QueryRunner, **kwargs):
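The restructured function now tries `__TABLES__` first and only falls back to the region-level INFORMATION_SCHEMA view on failure. A sketch of the pattern under those assumptions (illustrative project/dataset/table names and a hypothetical `execute` callable that runs a SQL string and returns the first row):

    import logging

    logger = logging.getLogger(__name__)

    def fetch_table_metrics(execute):
        """Try the cheap per-dataset view first, fall back to TABLE_STORAGE."""
        tables_sql = (
            "SELECT row_count, size_bytes, creation_time "
            "FROM `my-project.my_dataset.__TABLES__` WHERE table_id = 'my_table'"
        )
        storage_sql = (
            "SELECT total_rows, total_logical_bytes, creation_time "
            "FROM `region-us`.INFORMATION_SCHEMA.TABLE_STORAGE "
            "WHERE table_schema = 'my_dataset' AND table_name = 'my_table'"
        )
        try:
            return execute(tables_sql)
        except Exception as exc:
            logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
            return execute(storage_sql)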
@@ -112,8 +112,8 @@ class partition_filter_handler:
                 )
             if self.build_sample:
                 return (
-                    _self.client.query(
-                        _self.table,
+                    _self._base_sample_query(
+                        kwargs.get("column"),
                         (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
                     )
                     .filter(partition_filter)
@@ -14,6 +14,7 @@ for the profiler
 """
 from typing import Dict, Optional
 
+from sqlalchemy import Column
 from sqlalchemy.orm import Query
 
 from metadata.generated.schema.entity.data.table import ProfileSampleType, TableType
@@ -50,8 +51,35 @@ class BigQuerySampler(SQASampler):
         )
         self.table_type: TableType = table_type
 
+    def _base_sample_query(self, column: Optional[Column], label=None):
+        """Base query for sampling
+
+        Args:
+            column (Optional[Column]): if computing a column metric only sample for the column
+            label (_type_, optional):
+
+        Returns:
+        """
+        # pylint: disable=import-outside-toplevel
+        from sqlalchemy_bigquery import STRUCT
+
+        if column is not None:
+            column_parts = column.name.split(".")
+            if len(column_parts) > 1:
+                # for struct columns (e.g. `foo.bar`) we need to create a new column
+                # corresponding to the struct (e.g. `foo`) and use that in the sample
+                # query; the column that will be queried is `foo.bar`, e.g.
+                # WITH sample AS (SELECT `foo` FROM table) SELECT `foo.bar`
+                # FROM sample TABLESAMPLE SYSTEM (n PERCENT)
+                column = Column(column_parts[0], STRUCT)
+                # pylint: disable=protected-access
+                column._set_parent(self.table.__table__)
+                # pylint: enable=protected-access
+
+        return super()._base_sample_query(column, label=label)
+
     @partition_filter_handler(build_sample=True)
-    def get_sample_query(self) -> Query:
+    def get_sample_query(self, *, column=None) -> Query:
         """get query for sample data"""
         # TABLESAMPLE SYSTEM is not supported for views
         if (
@@ -59,11 +87,11 @@ class BigQuerySampler(SQASampler):
             and self.table_type != TableType.View
         ):
             return (
-                self._base_sample_query()
+                self._base_sample_query(column)
                 .suffix_with(
                     f"TABLESAMPLE SYSTEM ({self.profile_sample or 100} PERCENT)",
                 )
                 .cte(f"{self.table.__tablename__}_sample")
             )
 
-        return super().get_sample_query()
+        return super().get_sample_query(column=column)
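The BigQuery sampler builds the TABLESAMPLE clause with suffix_with and wraps the result in a CTE, as the two hunks above show. Roughly, in generic SQLAlchemy Core (hypothetical table and hard-coded percentage; for a struct leaf, the CTE would select the parent `foo` and the outer query would read `foo`.`bar`):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("my_table", MetaData(), Column("foo", Integer))
    sampled = (
        select(t.c.foo)
        .suffix_with("TABLESAMPLE SYSTEM (10 PERCENT)")  # appended after FROM
        .cte("my_table_sample")
    )
    # metrics then run against the sampled CTE instead of the base table
    print(select(sampled.c.foo))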
@@ -68,17 +68,28 @@ class SQASampler(SamplerInterface):
     run the query in the whole table.
     """
 
-    def _base_sample_query(self, label=None):
+    def _base_sample_query(self, column: Optional[Column], label=None):
+        """Base query for sampling
+
+        Args:
+            column (Optional[Column]): if computing a column metric only sample for the column
+            label (_type_, optional):
+
+        Returns:
+        """
+        # only sample the column if we are computing a column metric to limit the amount of data scanned
+        entity = self.table if column is None else column
         if label is not None:
-            return self.client.query(self.table, label)
-        return self.client.query(self.table)
+            return self.client.query(entity, label)
+        return self.client.query(entity)
 
     @partition_filter_handler(build_sample=True)
-    def get_sample_query(self) -> Query:
+    def get_sample_query(self, *, column=None) -> Query:
         """get query for sample data"""
         if self.profile_sample_type == ProfileSampleType.PERCENTAGE:
             rnd = (
                 self._base_sample_query(
+                    column,
                     (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
                 )
                 .suffix_with(
@@ -94,6 +105,7 @@ class SQASampler(SamplerInterface):
 
         table_query = self.client.query(self.table)
         session_query = self._base_sample_query(
+            column,
             (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL),
         )
         return (
@@ -102,7 +114,7 @@ class SQASampler(SamplerInterface):
             .cte(f"{self.table.__tablename__}_rnd")
         )
 
-    def random_sample(self) -> Union[DeclarativeMeta, AliasedClass]:
+    def random_sample(self, column=None) -> Union[DeclarativeMeta, AliasedClass]:
         """
         Either return a sampled CTE of table, or
         the full table if no sampling is required.
@@ -117,7 +129,7 @@
             return self.table
 
         # Add new RandomNumFn column
-        sampled = self.get_sample_query()
+        sampled = self.get_sample_query(column=column)
 
         # Assign as an alias
         return aliased(self.table, sampled)
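SQASampler's PERCENTAGE strategy tags each row with a pseudo-random bucket and keeps the rows that fall under the configured percentage; with this change it does so on the single sampled column. A sketch in generic SQLAlchemy Core (hypothetical `events` table; the real code uses the dialect-aware ModuloFn/RandomNumFn wrappers and aliases the CTE back to the table):

    from sqlalchemy import Column, Integer, MetaData, Table, func, select

    t = Table("events", MetaData(), Column("id", Integer), Column("value", Integer))
    profile_sample = 20  # keep roughly 20 percent of rows

    # tag each row of the sampled column with a bucket in [0, 100)
    rnd = select(t.c.value, (func.random() % 100).label("random")).cte("events_rnd")
    # keep only rows whose bucket falls under the sample percentage
    sample = select(rnd).where(rnd.c.random <= profile_sample)
    print(sample)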
@@ -33,9 +33,10 @@ class TrinoSampler(SQASampler):
 
         super().__init__(*args, **kwargs)
 
-    def _base_sample_query(self, label=None):
+    def _base_sample_query(self, column, label=None):
         sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL]
-        return self.client.query(self.table, label).where(
+        entity = self.table if column is None else column
+        return self.client.query(entity, label).where(
             or_(
                 *[
                     text(f"is_nan({cols.name}) = False")
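The Trino override keeps its NaN guard while adopting the column-aware base query: double columns are filtered with `is_nan(...) = False` so NaN rows never enter the sample. A sketch of that clause on a hypothetical table (the real code builds it from the inspected table columns):

    from sqlalchemy import Column, Float, MetaData, Table, or_, select, text

    t = Table("metrics", MetaData(), Column("score", Float), Column("loss", Float))
    # Trino's is_nan() flags NaN doubles; keep rows where no column is NaN
    clause = or_(*[text(f"is_nan({col.name}) = False") for col in t.columns])
    print(select(t).where(clause))
    # SELECT ... FROM metrics WHERE is_nan(score) = False OR is_nan(loss) = False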
@@ -15,6 +15,7 @@ from typing import Optional
 
 from metadata.generated.schema.entity.data.table import (
     IntervalType,
+    PartitionIntervalUnit,
     PartitionProfilerConfig,
     Table,
 )
@@ -47,8 +48,10 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
         return PartitionProfilerConfig(
             enablePartitioning=True,
             partitionColumnName=entity.tablePartition.columns[0],
-            partitionIntervalUnit=entity.tablePartition.interval,
-            partitionInterval=30,
+            partitionIntervalUnit=PartitionIntervalUnit.DAY
+            if entity.tablePartition.interval != "HOUR"
+            else entity.tablePartition.interval,
+            partitionInterval=1,
             partitionIntervalType=entity.tablePartition.intervalType.value,
             partitionValues=None,
             partitionIntegerRangeStart=None,
@@ -60,8 +63,10 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
             partitionColumnName="_PARTITIONDATE"
             if entity.tablePartition.interval == "DAY"
             else "_PARTITIONTIME",
-            partitionIntervalUnit=entity.tablePartition.interval,
-            partitionInterval=30,
+            partitionIntervalUnit=PartitionIntervalUnit.DAY
+            if entity.tablePartition.interval != "HOUR"
+            else entity.tablePartition.interval,
+            partitionInterval=1,
             partitionIntervalType=entity.tablePartition.intervalType.value,
             partitionValues=None,
             partitionIntegerRangeStart=None,
@@ -72,7 +77,7 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
             enablePartitioning=True,
             partitionColumnName=entity.tablePartition.columns[0],
             partitionIntervalUnit=None,
-            partitionInterval=30,
+            partitionInterval=None,
             partitionIntervalType=entity.tablePartition.intervalType.value,
             partitionValues=None,
             partitionIntegerRangeStart=1,
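Net effect of the partitioning hunks: unless a BigQuery table is partitioned by HOUR, the profiler now defaults to a 1 DAY window instead of the previous 30-unit interval. A tiny stand-alone sketch of the rule (hypothetical helper returning plain dicts rather than PartitionProfilerConfig):

    def default_partition(interval: str) -> dict:
        """Default to a 1 DAY window unless the table is HOUR-partitioned."""
        unit = "DAY" if interval != "HOUR" else interval
        return {"partitionIntervalUnit": unit, "partitionInterval": 1}

    assert default_partition("MONTH") == {"partitionIntervalUnit": "DAY", "partitionInterval": 1}
    assert default_partition("HOUR") == {"partitionIntervalUnit": "HOUR", "partitionInterval": 1}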
@@ -152,7 +152,7 @@ class ProfilerPartitionUnitTest(TestCase):
 
         if resp:
             assert resp.partitionColumnName == "e"
-            assert resp.partitionInterval == 30
+            assert resp.partitionInterval == 1
             assert not resp.partitionValues
         else:
             assert False
@@ -187,7 +187,7 @@ class ProfilerPartitionUnitTest(TestCase):
 
         if resp:
            assert resp.partitionColumnName == "_PARTITIONDATE"
-            assert resp.partitionInterval == 30
+            assert resp.partitionInterval == 1
             assert not resp.partitionValues
         else:
             assert False
@@ -221,7 +221,7 @@ class ProfilerPartitionUnitTest(TestCase):
 
         if resp:
             assert resp.partitionColumnName == "_PARTITIONTIME"
-            assert resp.partitionInterval == 30
+            assert resp.partitionInterval == 1
             assert not resp.partitionValues
         else:
             assert False
@@ -82,7 +82,7 @@ def test_get_partition_details():
     assert partition.enablePartitioning == True
     assert partition.partitionColumnName == "_PARTITIONTIME"
     assert partition.partitionIntervalType == PartitionIntervalType.INGESTION_TIME
-    assert partition.partitionInterval == 30
+    assert partition.partitionInterval == 1
     assert partition.partitionIntervalUnit == PartitionIntervalUnit.HOUR
 
     table_entity = MockTable(
@@ -97,5 +97,5 @@ def test_get_partition_details():
     assert partition.enablePartitioning == True
     assert partition.partitionColumnName == "_PARTITIONDATE"
     assert partition.partitionIntervalType == PartitionIntervalType.INGESTION_TIME
-    assert partition.partitionInterval == 30
+    assert partition.partitionInterval == 1
     assert partition.partitionIntervalUnit == PartitionIntervalUnit.DAY