fix: add fallback to table metric computation (#14771)

Teddy 2024-01-18 20:03:57 +01:00 committed by GitHub
parent 5ff74b5108
commit 9c4ec6c760
3 changed files with 466 additions and 399 deletions


@@ -37,9 +37,7 @@ from metadata.profiler.metrics.registry import Metrics
 from metadata.profiler.metrics.static.mean import Mean
 from metadata.profiler.metrics.static.stddev import StdDev
 from metadata.profiler.metrics.static.sum import Sum
-from metadata.profiler.orm.functions.table_metric_construct import (
-    table_metric_construct_factory,
-)
+from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer
 from metadata.profiler.orm.registry import Dialects
 from metadata.profiler.processor.runner import QueryRunner
 from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT
@@ -186,12 +184,13 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
         # pylint: disable=protected-access
         try:
             dialect = runner._session.get_bind().dialect.name
-            row = table_metric_construct_factory.construct(
+            table_metric_computer: TableMetricComputer = TableMetricComputer(
                 dialect,
                 runner=runner,
                 metrics=metrics,
                 conn_config=self.service_connection_config,
             )
+            row = table_metric_computer.compute()
             if row:
                 return dict(row)
             return None
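
The hunk above wires in a two-stage fallback: `TableMetricComputer` delegates to a dialect-specific computer that reads pre-aggregated stats from the engine's system tables, and each dialect computer in turn falls back to `BaseTableMetricComputer.compute()` (a plain `SELECT` of the registered metric functions) when those stats yield no row count. A minimal, self-contained sketch of that pattern — the class names and values below are illustrative stand-ins, not the shipped code:

from typing import Optional


class BaseComputer:
    """Fallback: compute metrics with a query against the table itself."""

    def compute(self) -> dict:
        # stand-in for runner.select_first_from_table(*metric_fns)
        return {"rowCount": 42, "source": "table query"}


class SystemTableComputer(BaseComputer):
    """Preferred: read pre-aggregated stats from the dialect's system tables."""

    def _query_system_tables(self) -> Optional[int]:
        # stand-in for SELECT ... FROM information_schema.tables WHERE ...
        return None  # simulate an engine with no stats for this table

    def compute(self) -> dict:
        row_count = self._query_system_tables()
        if row_count is None:
            # no row count available: fall back to the base logic
            return super().compute()
        return {"rowCount": row_count, "source": "system tables"}


print(SystemTableComputer().compute())  # {'rowCount': 42, 'source': 'table query'}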


@@ -0,0 +1,463 @@
# pylint: disable=protected-access,attribute-defined-outside-init
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run profiler metrics on the table
"""
import traceback
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple
from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select
from sqlalchemy.sql.expression import ColumnOperators, and_, cte
from sqlalchemy.types import String
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger()
COLUMN_COUNT = "columnCount"
COLUMN_NAMES = "columnNames"
ROW_COUNT = "rowCount"
SIZE_IN_BYTES = "sizeInBytes"
CREATE_DATETIME = "createDateTime"
ERROR_MSG = (
"Schema/Table name not found in table args. Falling back to default computation"
)
class AbstractTableMetricComputer(ABC):
"""Base table computer"""
def __init__(self, runner: QueryRunner, metrics: List[Metrics], conn_config):
"""Instantiate base table computer"""
self._runner = runner
self._metrics = metrics
self._conn_config = conn_config
self._database = self._runner._session.get_bind().url.database
self._table = self._runner.table
@property
def database(self):
return self._database
@property
def table(self):
return self._table
@property
def runner(self):
return self._runner
@property
def metrics(self):
return self._metrics
@property
def table_name(self):
return self._table_name
@property
def schema_name(self):
return self._schema_name
@property
def conn_config(self):
return self._conn_config
def _set_table_and_schema_name(self):
"""get table and schema name from table args
Args:
table (DeclarativeMeta): _description_
"""
try:
self._schema_name = self.table.__table_args__.get("schema")
self._table_name = self.table.__tablename__
except AttributeError:
raise AttributeError(ERROR_MSG)
def _build_table(self, table, schema) -> Table:
"""build table object from table name and schema name
Args:
table_name (str): table name
schema_name (str): schema name
Returns:
Table
"""
if schema:
return Table(table, MetaData(), schema=schema)
return Table(table, MetaData())
def _get_col_names_and_count(self) -> Tuple[str, int]:
"""get column names and count from table
Args:
table (DeclarativeMeta): table object
Returns:
Tuple[str, int]
"""
col_names = literal(",".join(inspect(self.table).c.keys()), type_=String).label(
COLUMN_NAMES
)
col_count = literal(len(inspect(self.table).c)).label(COLUMN_COUNT)
return col_names, col_count
def _build_query(
self,
columns: List[Column],
table: Table,
where_clause: Optional[List[ColumnOperators]] = None,
):
query = select(*columns).select_from(table)
if where_clause:
query = query.where(*where_clause)
return query
@abstractmethod
def compute(self):
"""Default compute behavior for table metrics"""
raise NotImplementedError
class BaseTableMetricComputer(AbstractTableMetricComputer):
"""Base table computer"""
    def _check_and_return(self, res):
        """Return the result, or fall back to the default computation
        when no row count could be retrieved.

        Args:
            res (object): result row from the dialect-specific query
        """
        if res.rowCount is None:
            # fall back to the base computation; the class is referenced
            # explicitly because super() here would resolve to the abstract
            # compute() and raise NotImplementedError
            return BaseTableMetricComputer.compute(self)
        return res
def compute(self):
"""Default compute behavior for table metrics"""
return self.runner.select_first_from_table(
*[metric().fn() for metric in self.metrics]
)
class SnowflakeTableMetricComputer(BaseTableMetricComputer):
"""Snowflake Table Metric Computer"""
def compute(self):
"""Compute table metrics for snowflake"""
columns = [
Column("ROW_COUNT").label("rowCount"),
Column("BYTES").label("sizeInBytes"),
Column("CREATED").label("createDateTime"),
*self._get_col_names_and_count(),
]
where_clause = [
func.lower(Column("TABLE_CATALOG")) == self.database.lower(),
func.lower(Column("TABLE_SCHEMA")) == self.schema_name.lower(),
func.lower(Column("TABLE_NAME")) == self.table_name.lower(),
]
query = self._build_query(
columns,
self._build_table("TABLES", f"{self.database}.INFORMATION_SCHEMA"),
where_clause,
)
        res = self.runner._session.execute(query).first()
        if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
            return super().compute()
        return res
class OracleTableMetricComputer(BaseTableMetricComputer):
"""Oracle Table Metric Computer"""
def compute(self):
"""Compute table metrics for oracle"""
create_date = cte(
self._build_query(
[
Column("owner"),
Column("object_name").label("table_name"),
Column("created"),
],
self._build_table("dba_objects", None),
[
func.lower(Column("owner")) == self.schema_name.lower(),
func.lower(Column("object_name")) == self.table_name.lower(),
],
)
)
row_count = cte(
self._build_query(
[
Column("owner"),
Column("table_name"),
Column("NUM_ROWS"),
],
self._build_table("all_tables", None),
[
func.lower(Column("owner")) == self.schema_name.lower(),
func.lower(Column("table_name")) == self.table_name.lower(),
],
)
)
columns = [
Column("NUM_ROWS").label("rowCount"),
Column("created").label("createDateTime"),
*self._get_col_names_and_count(),
]
query = self._build_query(columns, row_count).join(
create_date,
and_(
row_count.c.table_name == create_date.c.table_name,
row_count.c.owner == create_date.c.owner,
),
)
res = self.runner._session.execute(query).first()
if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
return super().compute()
return res
class ClickHouseTableMetricComputer(BaseTableMetricComputer):
"""ClickHouse Table Metric Computer"""
def compute(self):
"""compute table metrics for clickhouse"""
columns = [
Column("total_rows").label("rowCount"),
Column("total_bytes").label("sizeInBytes"),
*self._get_col_names_and_count(),
]
where_clause = [
Column("database") == self.schema_name,
Column("name") == self.table_name,
]
query = self._build_query(
columns, self._build_table("tables", "system"), where_clause
)
res = self.runner._session.execute(query).first()
if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
return super().compute()
return res
class BigQueryTableMetricComputer(BaseTableMetricComputer):
"""BigQuery Table Metric Computer"""
def compute(self):
"""compute table metrics for bigquery"""
try:
return self.tables()
except Exception as exc:
# if an error occurs fetching data from `__TABLES__`, fallback to `TABLE_STORAGE`
logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
return self.table_storage()
def table_storage(self):
"""Fall back method if retrieving table metadata from`__TABLES__` fails"""
columns = [
Column("total_rows").label("rowCount"),
Column("total_logical_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
*self._get_col_names_and_count(),
]
where_clause = [
Column("project_id")
== self.conn_config.credentials.gcpConfig.projectId.__root__,
Column("table_schema") == self.schema_name,
Column("table_name") == self.table_name,
]
query = self._build_query(
columns,
self._build_table(
"TABLE_STORAGE",
f"region-{self.conn_config.usageLocation}.INFORMATION_SCHEMA",
),
where_clause,
)
res = self.runner._session.execute(query).first()
if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
return super().compute()
return res
def tables(self):
"""retrieve table metadata from `__TABLES__`"""
columns = [
Column("row_count").label("rowCount"),
Column("size_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
*self._get_col_names_and_count(),
]
where_clause = [
Column("project_id")
== self.conn_config.credentials.gcpConfig.projectId.__root__,
Column("dataset_id") == self.schema_name,
Column("table_id") == self.table_name,
]
query = self._build_query(
columns,
self._build_table(
"__TABLES__",
f"{self.conn_config.credentials.gcpConfig.projectId.__root__}.{self.schema_name}",
),
where_clause,
)
res = self.runner._session.execute(query).first()
if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
return super().compute()
return res
class MySQLTableMetricComputer(BaseTableMetricComputer):
"""MySQL Table Metric Computer"""
def compute(self):
"""compute table metrics for mysql"""
columns = [
Column("TABLE_ROWS").label(ROW_COUNT),
(Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
Column("CREATE_TIME").label(CREATE_DATETIME),
*self._get_col_names_and_count(),
]
where_clause = [
Column("TABLE_SCHEMA") == self.schema_name,
Column("TABLE_NAME") == self.table_name,
]
query = self._build_query(
columns, self._build_table("tables", "information_schema"), where_clause
)
res = self.runner._session.execute(query).first()
if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
return super().compute()
return res
class RedshiftTableMetricComputer(BaseTableMetricComputer):
"""Redshift Table Metric Computer"""
def compute(self):
"""compute table metrics for redshift"""
columns = [
Column("estimated_visible_rows").label(ROW_COUNT),
Column("size").label(SIZE_IN_BYTES),
Column("create_time").label(CREATE_DATETIME),
*self._get_col_names_and_count(),
]
where_clause = [
Column("schema") == self.schema_name,
Column("table") == self.table_name,
]
query = self._build_query(
columns, self._build_table("svv_table_info", "pg_catalog"), where_clause
)
res = self.runner._session.execute(query).first()
if res.rowCount is None:
            # if we don't have any row count, fall back to the base logic
return super().compute()
return res
class TableMetricComputer:
"""Table Metric Construct"""
def __init__(
self, dialect: str, runner: QueryRunner, metrics: List[Metrics], conn_config
):
"""Instantiate table metric computer with a dialect computer"""
self._dialect = dialect
self._runner = runner
self._metrics = metrics
self._conn_config = conn_config
self.table_metric_computer: AbstractTableMetricComputer = (
table_metric_computer_factory.construct(
self._dialect,
runner=self._runner,
metrics=self._metrics,
conn_config=self._conn_config,
)
)
def compute(self):
"""Compute table metrics"""
return self.table_metric_computer.compute()
class TableMetricComputerFactory:
"""Factory returning the correct construct for the table metrics based on dialect"""
def __init__(self):
self._constructs = {}
def register(self, dialect: str, construct: Callable):
"""Register a construct for a dialect"""
self._constructs[dialect] = construct
def construct(self, dialect, **kwargs):
"""Construct the query"""
# check if we have registered a construct for the dialect
construct = self._constructs.get(dialect)
if not construct:
construct = self._constructs["base"]
return construct(**kwargs)
try:
construct_instance: AbstractTableMetricComputer = construct(**kwargs)
construct_instance._set_table_and_schema_name()
return construct_instance
except Exception:
# if an error occurs, fallback to the base construct
logger.debug(traceback.format_exc())
return self._constructs["base"](**kwargs)
table_metric_computer_factory = TableMetricComputerFactory()
table_metric_computer_factory.register("base", BaseTableMetricComputer)
table_metric_computer_factory.register(Dialects.Redshift, RedshiftTableMetricComputer)
table_metric_computer_factory.register(Dialects.MySQL, MySQLTableMetricComputer)
table_metric_computer_factory.register(Dialects.BigQuery, BigQueryTableMetricComputer)
table_metric_computer_factory.register(
Dialects.ClickHouse, ClickHouseTableMetricComputer
)
table_metric_computer_factory.register(Dialects.Oracle, OracleTableMetricComputer)
table_metric_computer_factory.register(Dialects.Snowflake, SnowflakeTableMetricComputer)
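
The registry above gives a single extension point per dialect, with "base" as the catch-all both for unregistered engines and for computers that fail during construction. A self-contained sketch of that resolution logic in miniature (dialect keys and return values here are illustrative, not the shipped API):

from typing import Callable, Dict


class MiniFactory:
    """Miniature of TableMetricComputerFactory's resolution rules."""

    def __init__(self) -> None:
        self._constructs: Dict[str, Callable[[], str]] = {}

    def register(self, dialect: str, construct: Callable[[], str]) -> None:
        self._constructs[dialect] = construct

    def construct(self, dialect: str) -> str:
        # unknown dialects resolve to "base"; construction failures degrade to "base" too
        construct = self._constructs.get(dialect, self._constructs["base"])
        try:
            return construct()
        except Exception:
            return self._constructs["base"]()


factory = MiniFactory()
factory.register("base", lambda: "base computer")
factory.register("mysql", lambda: "mysql computer")

print(factory.construct("mysql"))    # mysql computer
print(factory.construct("vertica"))  # base computer (unregistered dialect)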


@@ -1,395 +0,0 @@
# pylint: disable=unused-argument,protected-access
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run profiler metrics on the table
"""
import traceback
from typing import Callable, List, Optional, Tuple, cast
from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select
from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.sql.expression import ColumnOperators, and_, cte
from sqlalchemy.types import String
from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
BigQueryConnection,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.registry import Dialects
from metadata.profiler.processor.runner import QueryRunner
from metadata.utils.logger import profiler_interface_registry_logger
logger = profiler_interface_registry_logger()
COLUMN_COUNT = "columnCount"
COLUMN_NAMES = "columnNames"
ROW_COUNT = "rowCount"
SIZE_IN_BYTES = "sizeInBytes"
CREATE_DATETIME = "createDateTime"
ERROR_MSG = (
"Schema/Table name not found in table args. Falling back to default computation"
)
def _get_table_and_schema_name(table: DeclarativeMeta) -> Tuple[str, str]:
"""get table and schema name from table args
Args:
        table (DeclarativeMeta): the declarative table object
"""
schema_name = table.__table_args__.get("schema")
table_name = table.__tablename__
return schema_name, table_name
def _build_table(table_name: str, schema_name: Optional[str] = None) -> Table:
"""build table object from table name and schema name
Args:
table_name (str): table name
schema_name (str): schema name
Returns:
Table
"""
if schema_name:
return Table(table_name, MetaData(), schema=schema_name)
return Table(table_name, MetaData())
def _get_col_names_and_count(table: DeclarativeMeta) -> Tuple[str, int]:
"""get column names and count from table
Args:
table (DeclarativeMeta): table object
Returns:
Tuple[str, int]
"""
col_names = literal(",".join(inspect(table).c.keys()), type_=String).label(
COLUMN_NAMES
)
col_count = literal(len(inspect(table).c)).label(COLUMN_COUNT)
return col_names, col_count
def _build_query(
columns: List[Column],
table: Table,
where_clause: Optional[List[ColumnOperators]] = None,
):
query = select(*columns).select_from(table)
if where_clause:
query = query.where(*where_clause)
return query
def base_table_construct(runner: QueryRunner, metrics: List[Metrics], **kwargs):
"""base table construct for table metrics
Args:
runner (QueryRunner): runner object to execute query
metrics (List[Metrics]): list of metrics
"""
return runner.select_first_from_table(*[metric().fn() for metric in metrics])
def redshift_table_construct(runner: QueryRunner, **kwargs):
"""redshift table construct for table metrics
Args:
runner (QueryRunner): runner object to execute query
"""
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
except AttributeError:
raise AttributeError(ERROR_MSG)
columns = [
Column("estimated_visible_rows").label(ROW_COUNT),
Column("size").label(SIZE_IN_BYTES),
Column("create_time").label(CREATE_DATETIME),
*_get_col_names_and_count(runner.table),
]
where_clause = [
Column("schema") == schema_name,
Column("table") == table_name,
]
query = _build_query(
columns, _build_table("svv_table_info", "pg_catalog"), where_clause
)
return runner._session.execute(query).first()
def mysql_table_construct(runner: QueryRunner, **kwargs):
"""MySQL table construct for table metrics
Args:
runner (QueryRunner): query runner object
"""
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
except AttributeError:
raise AttributeError(ERROR_MSG)
tables = _build_table("tables", "information_schema")
col_names, col_count = _get_col_names_and_count(runner.table)
columns = [
Column("TABLE_ROWS").label(ROW_COUNT),
(Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
Column("CREATE_TIME").label(CREATE_DATETIME),
col_names,
col_count,
]
where_clause = [
Column("TABLE_SCHEMA") == schema_name,
Column("TABLE_NAME") == table_name,
]
query = _build_query(columns, tables, where_clause)
return runner._session.execute(query).first()
def bigquery_table_construct(runner: QueryRunner, **kwargs):
"""bigquery table construct for table metrics
Args:
runner (QueryRunner): query runner object
"""
conn_config = kwargs.get("conn_config")
conn_config = cast(BigQueryConnection, conn_config)
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
project_id = conn_config.credentials.gcpConfig.projectId.__root__
except AttributeError:
raise AttributeError(ERROR_MSG)
col_names, col_count = _get_col_names_and_count(runner.table)
def table_storage():
"""Fall back method if retrieving table metadata from`__TABLES__` fails"""
table_storage = _build_table(
"TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
)
columns = [
Column("total_rows").label("rowCount"),
Column("total_logical_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
col_names,
col_count,
]
where_clause = [
Column("project_id") == project_id,
Column("table_schema") == schema_name,
Column("table_name") == table_name,
]
query = _build_query(columns, table_storage, where_clause)
return runner._session.execute(query).first()
def tables():
"""retrieve table metadata from `__TABLES__`"""
table_meta = _build_table("__TABLES__", f"{project_id}.{schema_name}")
columns = [
Column("row_count").label("rowCount"),
Column("size_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
col_names,
col_count,
]
where_clause = [
Column("project_id") == project_id,
Column("dataset_id") == schema_name,
Column("table_id") == table_name,
]
query = _build_query(columns, table_meta, where_clause)
return runner._session.execute(query).first()
try:
return tables()
except Exception as exc:
logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
return table_storage()
def clickhouse_table_construct(runner: QueryRunner, **kwargs):
"""clickhouse table construct for table metrics
Args:
runner (QueryRunner): query runner object
"""
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
except AttributeError:
raise AttributeError(ERROR_MSG)
tables = _build_table("tables", "system")
col_names, col_count = _get_col_names_and_count(runner.table)
columns = [
Column("total_rows").label("rowCount"),
Column("total_bytes").label("sizeInBytes"),
col_names,
col_count,
]
where_clause = [
Column("database") == schema_name,
Column("name") == table_name,
]
query = _build_query(columns, tables, where_clause)
return runner._session.execute(query).first()
def oracle_table_construct(runner: QueryRunner, **kwargs):
"""oracle table construct for table metrics
Args:
runner (QueryRunner): query runner object
"""
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
except AttributeError:
raise AttributeError(ERROR_MSG)
dba_objects = _build_table("dba_objects", None)
all_tables = _build_table("all_tables", None)
col_names, col_count = _get_col_names_and_count(runner.table)
create_date = cte(
_build_query(
[
Column("owner"),
Column("object_name").label("table_name"),
Column("created"),
],
dba_objects,
[
func.lower(Column("owner")) == schema_name.lower(),
func.lower(Column("object_name")) == table_name.lower(),
],
)
)
row_count = cte(
_build_query(
[
Column("owner"),
Column("table_name"),
Column("NUM_ROWS"),
],
all_tables,
[
func.lower(Column("owner")) == schema_name.lower(),
func.lower(Column("table_name")) == table_name.lower(),
],
)
)
columns = [
Column("NUM_ROWS").label("rowCount"),
Column("created").label("createDateTime"),
col_names,
col_count,
]
query = _build_query(columns, row_count).join(
create_date,
and_(
row_count.c.table_name == create_date.c.table_name,
row_count.c.owner == create_date.c.owner,
),
)
return runner._session.execute(query).first()
def snowflake_table_construct(runner: QueryRunner, **kwargs):
"""Snowflake table construct for table metrics
Args:
runner (QueryRunner): query runner object
"""
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
except AttributeError:
raise AttributeError(ERROR_MSG)
database = runner._session.get_bind().url.database
table_storage = _build_table("TABLES", f"{database}.INFORMATION_SCHEMA")
col_names, col_count = _get_col_names_and_count(runner.table)
columns = [
Column("ROW_COUNT").label("rowCount"),
Column("BYTES").label("sizeInBytes"),
Column("CREATED").label("createDateTime"),
col_names,
col_count,
]
where_clause = [
func.lower(Column("TABLE_CATALOG")) == database.lower(),
func.lower(Column("TABLE_SCHEMA")) == schema_name.lower(),
func.lower(Column("TABLE_NAME")) == table_name.lower(),
]
query = _build_query(columns, table_storage, where_clause)
return runner._session.execute(query).first()
class TableMetricConstructFactory:
"""Factory returning the correct construct for the table metrics based on dialect"""
def __init__(self):
self._constructs = {}
def register(self, dialect: str, construct: Callable):
"""Register a construct for a dialect"""
self._constructs[dialect] = construct
def construct(self, dialect, **kwargs):
"""Construct the query"""
# check if we have registered a construct for the dialect
construct = self._constructs.get(dialect)
if not construct:
construct = self._constructs["base"]
try:
return construct(**kwargs)
except Exception:
# if an error occurs, fallback to the base construct
logger.debug(traceback.format_exc())
return self._constructs["base"](**kwargs)
table_metric_construct_factory = TableMetricConstructFactory()
table_metric_construct_factory.register("base", base_table_construct)
table_metric_construct_factory.register(Dialects.Redshift, redshift_table_construct)
table_metric_construct_factory.register(Dialects.MySQL, mysql_table_construct)
table_metric_construct_factory.register(Dialects.BigQuery, bigquery_table_construct)
table_metric_construct_factory.register(Dialects.ClickHouse, clickhouse_table_construct)
table_metric_construct_factory.register(Dialects.Oracle, oracle_table_construct)
table_metric_construct_factory.register(Dialects.Snowflake, snowflake_table_construct)