From 9c4ec6c76078dc2b9d6547292e6b0e19643e433c Mon Sep 17 00:00:00 2001 From: Teddy Date: Thu, 18 Jan 2024 20:03:57 +0100 Subject: [PATCH] fix: added fall back to table metric computation (#14771) --- .../sqlalchemy/profiler_interface.py | 7 +- .../orm/functions/table_metric_computer.py | 463 ++++++++++++++++++ .../orm/functions/table_metric_construct.py | 395 --------------- 3 files changed, 466 insertions(+), 399 deletions(-) create mode 100644 ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py delete mode 100644 ingestion/src/metadata/profiler/orm/functions/table_metric_construct.py diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py index 4117db37205..e9f0763b3cf 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py @@ -37,9 +37,7 @@ from metadata.profiler.metrics.registry import Metrics from metadata.profiler.metrics.static.mean import Mean from metadata.profiler.metrics.static.stddev import StdDev from metadata.profiler.metrics.static.sum import Sum -from metadata.profiler.orm.functions.table_metric_construct import ( - table_metric_construct_factory, -) +from metadata.profiler.orm.functions.table_metric_computer import TableMetricComputer from metadata.profiler.orm.registry import Dialects from metadata.profiler.processor.runner import QueryRunner from metadata.utils.constants import SAMPLE_DATA_DEFAULT_COUNT @@ -186,12 +184,13 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin): # pylint: disable=protected-access try: dialect = runner._session.get_bind().dialect.name - row = table_metric_construct_factory.construct( + table_metric_computer: TableMetricComputer = TableMetricComputer( dialect, runner=runner, metrics=metrics, conn_config=self.service_connection_config, ) + row = 
table_metric_computer.compute() if row: return dict(row) return None diff --git a/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py b/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py new file mode 100644 index 00000000000..75fb14c647f --- /dev/null +++ b/ingestion/src/metadata/profiler/orm/functions/table_metric_computer.py @@ -0,0 +1,463 @@ +# pylint: disable=protected-access,attribute-defined-outside-init +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run profiler metrics on the table +""" + +import traceback +from abc import ABC, abstractmethod +from typing import Callable, List, Optional, Tuple + +from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select +from sqlalchemy.sql.expression import ColumnOperators, and_, cte +from sqlalchemy.types import String + +from metadata.profiler.metrics.registry import Metrics +from metadata.profiler.orm.registry import Dialects +from metadata.profiler.processor.runner import QueryRunner +from metadata.utils.logger import profiler_interface_registry_logger + +logger = profiler_interface_registry_logger() + +COLUMN_COUNT = "columnCount" +COLUMN_NAMES = "columnNames" +ROW_COUNT = "rowCount" +SIZE_IN_BYTES = "sizeInBytes" +CREATE_DATETIME = "createDateTime" + +ERROR_MSG = ( + "Schema/Table name not found in table args. 
ERROR_MSG = (
    "Schema/Table name not found in table args. "
    "Falling back to default computation"
)


class AbstractTableMetricComputer(ABC):
    """Base table computer"""

    def __init__(self, runner: QueryRunner, metrics: List[Metrics], conn_config):
        """Instantiate base table computer

        Args:
            runner (QueryRunner): runner object used to execute queries
            metrics (List[Metrics]): list of metrics to compute
            conn_config: service connection configuration object
        """
        self._runner = runner
        self._metrics = metrics
        self._conn_config = conn_config
        self._database = self._runner._session.get_bind().url.database
        self._table = self._runner.table

    @property
    def database(self):
        return self._database

    @property
    def table(self):
        return self._table

    @property
    def runner(self):
        return self._runner

    @property
    def metrics(self):
        return self._metrics

    @property
    def table_name(self):
        return self._table_name

    @property
    def schema_name(self):
        return self._schema_name

    @property
    def conn_config(self):
        return self._conn_config

    def _set_table_and_schema_name(self):
        """Set `_schema_name` and `_table_name` from the ORM table args.

        Raises:
            AttributeError: if the ORM table does not expose
                `__table_args__` / `__tablename__`
        """
        try:
            self._schema_name = self.table.__table_args__.get("schema")
            self._table_name = self.table.__tablename__
        except AttributeError as exc:
            # chain the original error so the root cause is not lost
            raise AttributeError(ERROR_MSG) from exc

    def _build_table(self, table, schema) -> Table:
        """Build a lightweight Table object from table and schema names.

        Args:
            table (str): table name
            schema (Optional[str]): schema name (omitted when falsy)

        Returns:
            Table
        """
        if schema:
            return Table(table, MetaData(), schema=schema)
        return Table(table, MetaData())

    def _get_col_names_and_count(self) -> Tuple[str, int]:
        """Get column names and count from the ORM table.

        Returns:
            Tuple[str, int]: literal expressions labeled
                `columnNames` (comma-joined names) and `columnCount`
        """
        col_names = literal(",".join(inspect(self.table).c.keys()), type_=String).label(
            COLUMN_NAMES
        )
        col_count = literal(len(inspect(self.table).c)).label(COLUMN_COUNT)
        return col_names, col_count

    def _build_query(
        self,
        columns: List[Column],
        table: Table,
        where_clause: Optional[List[ColumnOperators]] = None,
    ):
        """Build a SELECT over `table` with an optional WHERE clause."""
        query = select(*columns).select_from(table)
        if where_clause:
            query = query.where(*where_clause)

        return query

    @abstractmethod
    def compute(self):
        """Default compute behavior for table metrics"""
        raise NotImplementedError


class BaseTableMetricComputer(AbstractTableMetricComputer):
    """Base table computer"""

    def _check_and_return(self, res):
        """Return `res` when it carries a row count, else fall back to the
        default (metric-query) computation.

        Args:
            res: row returned by a dialect-specific system-table query
                 (may be None when the query matched no rows)
        """
        # NOTE: `super().compute()` here would resolve to the *abstract*
        # method and raise NotImplementedError — call the base
        # implementation explicitly instead.
        if res is None or res.rowCount is None:
            return BaseTableMetricComputer.compute(self)
        return res

    def compute(self):
        """Default compute behavior for table metrics"""
        return self.runner.select_first_from_table(
            *[metric().fn() for metric in self.metrics]
        )


class SnowflakeTableMetricComputer(BaseTableMetricComputer):
    """Snowflake Table Metric Computer"""

    def compute(self):
        """Compute table metrics for snowflake"""
        columns = [
            Column("ROW_COUNT").label(ROW_COUNT),
            Column("BYTES").label(SIZE_IN_BYTES),
            Column("CREATED").label(CREATE_DATETIME),
            *self._get_col_names_and_count(),
        ]
        where_clause = [
            func.lower(Column("TABLE_CATALOG")) == self.database.lower(),
            func.lower(Column("TABLE_SCHEMA")) == self.schema_name.lower(),
            func.lower(Column("TABLE_NAME")) == self.table_name.lower(),
        ]
        query = self._build_query(
            columns,
            self._build_table("TABLES", f"{self.database}.INFORMATION_SCHEMA"),
            where_clause,
        )

        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)


class OracleTableMetricComputer(BaseTableMetricComputer):
    """Oracle Table Metric Computer"""

    def compute(self):
        """Compute table metrics for oracle"""
        # creation date lives in `dba_objects`, row count in `all_tables`;
        # join the two CTEs on (owner, table_name)
        create_date = cte(
            self._build_query(
                [
                    Column("owner"),
                    Column("object_name").label("table_name"),
                    Column("created"),
                ],
                self._build_table("dba_objects", None),
                [
                    func.lower(Column("owner")) == self.schema_name.lower(),
                    func.lower(Column("object_name")) == self.table_name.lower(),
                ],
            )
        )

        row_count = cte(
            self._build_query(
                [
                    Column("owner"),
                    Column("table_name"),
                    Column("NUM_ROWS"),
                ],
                self._build_table("all_tables", None),
                [
                    func.lower(Column("owner")) == self.schema_name.lower(),
                    func.lower(Column("table_name")) == self.table_name.lower(),
                ],
            )
        )

        columns = [
            Column("NUM_ROWS").label(ROW_COUNT),
            Column("created").label(CREATE_DATETIME),
            *self._get_col_names_and_count(),
        ]
        query = self._build_query(columns, row_count).join(
            create_date,
            and_(
                row_count.c.table_name == create_date.c.table_name,
                row_count.c.owner == create_date.c.owner,
            ),
        )

        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)


class ClickHouseTableMetricComputer(BaseTableMetricComputer):
    """ClickHouse Table Metric Computer"""

    def compute(self):
        """compute table metrics for clickhouse"""
        columns = [
            Column("total_rows").label(ROW_COUNT),
            Column("total_bytes").label(SIZE_IN_BYTES),
            *self._get_col_names_and_count(),
        ]

        where_clause = [
            Column("database") == self.schema_name,
            Column("name") == self.table_name,
        ]

        query = self._build_query(
            columns, self._build_table("tables", "system"), where_clause
        )

        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)


class BigQueryTableMetricComputer(BaseTableMetricComputer):
    """BigQuery Table Metric Computer"""

    def compute(self):
        """compute table metrics for bigquery"""
        try:
            return self.tables()
        except Exception as exc:
            # if an error occurs fetching data from `__TABLES__`,
            # fall back to `TABLE_STORAGE`
            logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
            return self.table_storage()

    def table_storage(self):
        """Fall back method if retrieving table metadata from `__TABLES__` fails"""
        columns = [
            Column("total_rows").label(ROW_COUNT),
            Column("total_logical_bytes").label(SIZE_IN_BYTES),
            Column("creation_time").label(CREATE_DATETIME),
            *self._get_col_names_and_count(),
        ]

        where_clause = [
            Column("project_id")
            == self.conn_config.credentials.gcpConfig.projectId.__root__,
            Column("table_schema") == self.schema_name,
            Column("table_name") == self.table_name,
        ]

        query = self._build_query(
            columns,
            self._build_table(
                "TABLE_STORAGE",
                f"region-{self.conn_config.usageLocation}.INFORMATION_SCHEMA",
            ),
            where_clause,
        )

        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)

    def tables(self):
        """retrieve table metadata from `__TABLES__`"""
        columns = [
            Column("row_count").label(ROW_COUNT),
            Column("size_bytes").label(SIZE_IN_BYTES),
            Column("creation_time").label(CREATE_DATETIME),
            *self._get_col_names_and_count(),
        ]
        where_clause = [
            Column("project_id")
            == self.conn_config.credentials.gcpConfig.projectId.__root__,
            Column("dataset_id") == self.schema_name,
            Column("table_id") == self.table_name,
        ]

        query = self._build_query(
            columns,
            self._build_table(
                "__TABLES__",
                f"{self.conn_config.credentials.gcpConfig.projectId.__root__}.{self.schema_name}",
            ),
            where_clause,
        )
        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)


class MySQLTableMetricComputer(BaseTableMetricComputer):
    """MySQL Table Metric Computer"""

    def compute(self):
        """compute table metrics for mysql"""
        columns = [
            Column("TABLE_ROWS").label(ROW_COUNT),
            (Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES),
            Column("CREATE_TIME").label(CREATE_DATETIME),
            *self._get_col_names_and_count(),
        ]
        where_clause = [
            Column("TABLE_SCHEMA") == self.schema_name,
            Column("TABLE_NAME") == self.table_name,
        ]
        query = self._build_query(
            columns, self._build_table("tables", "information_schema"), where_clause
        )

        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)


class RedshiftTableMetricComputer(BaseTableMetricComputer):
    """Redshift Table Metric Computer"""

    def compute(self):
        """compute table metrics for redshift"""
        columns = [
            Column("estimated_visible_rows").label(ROW_COUNT),
            Column("size").label(SIZE_IN_BYTES),
            Column("create_time").label(CREATE_DATETIME),
            *self._get_col_names_and_count(),
        ]

        where_clause = [
            Column("schema") == self.schema_name,
            Column("table") == self.table_name,
        ]

        query = self._build_query(
            columns, self._build_table("svv_table_info", "pg_catalog"), where_clause
        )
        res = self.runner._session.execute(query).first()
        # if we don't have any row count, fall back to the base logic
        return self._check_and_return(res)


class TableMetricComputer:
    """Table Metric Computer — dispatches to a dialect-specific computer"""

    def __init__(
        self, dialect: str, runner: QueryRunner, metrics: List[Metrics], conn_config
    ):
        """Instantiate table metric computer with a dialect computer"""
        self._dialect = dialect
        self._runner = runner
        self._metrics = metrics
        self._conn_config = conn_config
        self.table_metric_computer: AbstractTableMetricComputer = (
            table_metric_computer_factory.construct(
                self._dialect,
                runner=self._runner,
                metrics=self._metrics,
                conn_config=self._conn_config,
            )
        )

    def compute(self):
        """Compute table metrics"""
        return self.table_metric_computer.compute()


class TableMetricComputerFactory:
    """Factory returning the correct computer for the table metrics based on dialect"""

    def __init__(self):
        self._constructs = {}

    def register(self, dialect: str, construct: Callable):
        """Register a construct for a dialect"""
        self._constructs[dialect] = construct

    def construct(self, dialect, **kwargs):
        """Construct the dialect-specific computer.

        Falls back to the "base" computer when no dialect is registered or
        when resolving the schema/table name fails.
        """
        # check if we have registered a construct for the dialect
        construct = self._constructs.get(dialect)
        if not construct:
            # base computer needs no schema/table name resolution
            construct = self._constructs["base"]
            return construct(**kwargs)

        try:
            construct_instance: AbstractTableMetricComputer = construct(**kwargs)
            construct_instance._set_table_and_schema_name()
            return construct_instance
        except Exception:
            # if an error occurs, fall back to the base construct
            logger.debug(traceback.format_exc())
            return self._constructs["base"](**kwargs)


table_metric_computer_factory = TableMetricComputerFactory()
table_metric_computer_factory.register("base", BaseTableMetricComputer)
table_metric_computer_factory.register(Dialects.Redshift, RedshiftTableMetricComputer)
table_metric_computer_factory.register(Dialects.MySQL, MySQLTableMetricComputer)
table_metric_computer_factory.register(Dialects.BigQuery, BigQueryTableMetricComputer)
table_metric_computer_factory.register(
    Dialects.ClickHouse, ClickHouseTableMetricComputer
)
table_metric_computer_factory.register(Dialects.Oracle, OracleTableMetricComputer)
table_metric_computer_factory.register(Dialects.Snowflake, SnowflakeTableMetricComputer)
use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Run profiler metrics on the table -""" - -import traceback -from typing import Callable, List, Optional, Tuple, cast - -from sqlalchemy import Column, MetaData, Table, func, inspect, literal, select -from sqlalchemy.orm import DeclarativeMeta -from sqlalchemy.sql.expression import ColumnOperators, and_, cte -from sqlalchemy.types import String - -from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import ( - BigQueryConnection, -) -from metadata.profiler.metrics.registry import Metrics -from metadata.profiler.orm.registry import Dialects -from metadata.profiler.processor.runner import QueryRunner -from metadata.utils.logger import profiler_interface_registry_logger - -logger = profiler_interface_registry_logger() - -COLUMN_COUNT = "columnCount" -COLUMN_NAMES = "columnNames" -ROW_COUNT = "rowCount" -SIZE_IN_BYTES = "sizeInBytes" -CREATE_DATETIME = "createDateTime" - -ERROR_MSG = ( - "Schema/Table name not found in table args. 
Falling back to default computation" -) - - -def _get_table_and_schema_name(table: DeclarativeMeta) -> Tuple[str, str]: - """get table and schema name from table args - - Args: - table (DeclarativeMeta): _description_ - """ - schema_name = table.__table_args__.get("schema") - table_name = table.__tablename__ - return schema_name, table_name - - -def _build_table(table_name: str, schema_name: Optional[str] = None) -> Table: - """build table object from table name and schema name - - Args: - table_name (str): table name - schema_name (str): schema name - - Returns: - Table - """ - if schema_name: - return Table(table_name, MetaData(), schema=schema_name) - return Table(table_name, MetaData()) - - -def _get_col_names_and_count(table: DeclarativeMeta) -> Tuple[str, int]: - """get column names and count from table - - Args: - table (DeclarativeMeta): table object - - Returns: - Tuple[str, int] - """ - col_names = literal(",".join(inspect(table).c.keys()), type_=String).label( - COLUMN_NAMES - ) - col_count = literal(len(inspect(table).c)).label(COLUMN_COUNT) - return col_names, col_count - - -def _build_query( - columns: List[Column], - table: Table, - where_clause: Optional[List[ColumnOperators]] = None, -): - - query = select(*columns).select_from(table) - if where_clause: - query = query.where(*where_clause) - - return query - - -def base_table_construct(runner: QueryRunner, metrics: List[Metrics], **kwargs): - """base table construct for table metrics - - Args: - runner (QueryRunner): runner object to execute query - metrics (List[Metrics]): list of metrics - """ - return runner.select_first_from_table(*[metric().fn() for metric in metrics]) - - -def redshift_table_construct(runner: QueryRunner, **kwargs): - """redshift table construct for table metrics - - Args: - runner (QueryRunner): runner object to execute query - """ - try: - schema_name, table_name = _get_table_and_schema_name(runner.table) - except AttributeError: - raise AttributeError(ERROR_MSG) - - 
columns = [ - Column("estimated_visible_rows").label(ROW_COUNT), - Column("size").label(SIZE_IN_BYTES), - Column("create_time").label(CREATE_DATETIME), - *_get_col_names_and_count(runner.table), - ] - - where_clause = [ - Column("schema") == schema_name, - Column("table") == table_name, - ] - - query = _build_query( - columns, _build_table("svv_table_info", "pg_catalog"), where_clause - ) - return runner._session.execute(query).first() - - -def mysql_table_construct(runner: QueryRunner, **kwargs): - """MySQL table construct for table metrics - - Args: - runner (QueryRunner): query runner object - """ - try: - schema_name, table_name = _get_table_and_schema_name(runner.table) - except AttributeError: - raise AttributeError(ERROR_MSG) - - tables = _build_table("tables", "information_schema") - col_names, col_count = _get_col_names_and_count(runner.table) - - columns = [ - Column("TABLE_ROWS").label(ROW_COUNT), - (Column("data_length") + Column("index_length")).label(SIZE_IN_BYTES), - Column("CREATE_TIME").label(CREATE_DATETIME), - col_names, - col_count, - ] - where_clause = [ - Column("TABLE_SCHEMA") == schema_name, - Column("TABLE_NAME") == table_name, - ] - query = _build_query(columns, tables, where_clause) - - return runner._session.execute(query).first() - - -def bigquery_table_construct(runner: QueryRunner, **kwargs): - """bigquery table construct for table metrics - - Args: - runner (QueryRunner): query runner object - """ - conn_config = kwargs.get("conn_config") - conn_config = cast(BigQueryConnection, conn_config) - try: - schema_name, table_name = _get_table_and_schema_name(runner.table) - project_id = conn_config.credentials.gcpConfig.projectId.__root__ - except AttributeError: - raise AttributeError(ERROR_MSG) - - col_names, col_count = _get_col_names_and_count(runner.table) - - def table_storage(): - """Fall back method if retrieving table metadata from`__TABLES__` fails""" - table_storage = _build_table( - "TABLE_STORAGE", 
f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA" - ) - - columns = [ - Column("total_rows").label("rowCount"), - Column("total_logical_bytes").label("sizeInBytes"), - Column("creation_time").label("createDateTime"), - col_names, - col_count, - ] - - where_clause = [ - Column("project_id") == project_id, - Column("table_schema") == schema_name, - Column("table_name") == table_name, - ] - - query = _build_query(columns, table_storage, where_clause) - - return runner._session.execute(query).first() - - def tables(): - """retrieve table metadata from `__TABLES__`""" - table_meta = _build_table("__TABLES__", f"{project_id}.{schema_name}") - columns = [ - Column("row_count").label("rowCount"), - Column("size_bytes").label("sizeInBytes"), - Column("creation_time").label("createDateTime"), - col_names, - col_count, - ] - where_clause = [ - Column("project_id") == project_id, - Column("dataset_id") == schema_name, - Column("table_id") == table_name, - ] - - query = _build_query(columns, table_meta, where_clause) - return runner._session.execute(query).first() - - try: - return tables() - except Exception as exc: - logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}") - return table_storage() - - -def clickhouse_table_construct(runner: QueryRunner, **kwargs): - """clickhouse table construct for table metrics - - Args: - runner (QueryRunner): query runner object - """ - try: - schema_name, table_name = _get_table_and_schema_name(runner.table) - except AttributeError: - raise AttributeError(ERROR_MSG) - - tables = _build_table("tables", "system") - col_names, col_count = _get_col_names_and_count(runner.table) - - columns = [ - Column("total_rows").label("rowCount"), - Column("total_bytes").label("sizeInBytes"), - col_names, - col_count, - ] - - where_clause = [ - Column("database") == schema_name, - Column("name") == table_name, - ] - - query = _build_query(columns, tables, where_clause) - - return runner._session.execute(query).first() - - -def 
oracle_table_construct(runner: QueryRunner, **kwargs): - """oracle table construct for table metrics - - Args: - runner (QueryRunner): query runner object - """ - try: - schema_name, table_name = _get_table_and_schema_name(runner.table) - except AttributeError: - raise AttributeError(ERROR_MSG) - - dba_objects = _build_table("dba_objects", None) - all_tables = _build_table("all_tables", None) - col_names, col_count = _get_col_names_and_count(runner.table) - - create_date = cte( - _build_query( - [ - Column("owner"), - Column("object_name").label("table_name"), - Column("created"), - ], - dba_objects, - [ - func.lower(Column("owner")) == schema_name.lower(), - func.lower(Column("object_name")) == table_name.lower(), - ], - ) - ) - - row_count = cte( - _build_query( - [ - Column("owner"), - Column("table_name"), - Column("NUM_ROWS"), - ], - all_tables, - [ - func.lower(Column("owner")) == schema_name.lower(), - func.lower(Column("table_name")) == table_name.lower(), - ], - ) - ) - - columns = [ - Column("NUM_ROWS").label("rowCount"), - Column("created").label("createDateTime"), - col_names, - col_count, - ] - query = _build_query(columns, row_count).join( - create_date, - and_( - row_count.c.table_name == create_date.c.table_name, - row_count.c.owner == create_date.c.owner, - ), - ) - - return runner._session.execute(query).first() - - -def snowflake_table_construct(runner: QueryRunner, **kwargs): - """Snowflake table construct for table metrics - - Args: - runner (QueryRunner): query runner object - """ - try: - schema_name, table_name = _get_table_and_schema_name(runner.table) - except AttributeError: - raise AttributeError(ERROR_MSG) - - database = runner._session.get_bind().url.database - - table_storage = _build_table("TABLES", f"{database}.INFORMATION_SCHEMA") - col_names, col_count = _get_col_names_and_count(runner.table) - - columns = [ - Column("ROW_COUNT").label("rowCount"), - Column("BYTES").label("sizeInBytes"), - 
Column("CREATED").label("createDateTime"), - col_names, - col_count, - ] - where_clause = [ - func.lower(Column("TABLE_CATALOG")) == database.lower(), - func.lower(Column("TABLE_SCHEMA")) == schema_name.lower(), - func.lower(Column("TABLE_NAME")) == table_name.lower(), - ] - query = _build_query(columns, table_storage, where_clause) - - return runner._session.execute(query).first() - - -class TableMetricConstructFactory: - """Factory returning the correct construct for the table metrics based on dialect""" - - def __init__(self): - self._constructs = {} - - def register(self, dialect: str, construct: Callable): - """Register a construct for a dialect""" - self._constructs[dialect] = construct - - def construct(self, dialect, **kwargs): - """Construct the query""" - - # check if we have registered a construct for the dialect - construct = self._constructs.get(dialect) - if not construct: - construct = self._constructs["base"] - - try: - return construct(**kwargs) - except Exception: - # if an error occurs, fallback to the base construct - logger.debug(traceback.format_exc()) - return self._constructs["base"](**kwargs) - - -table_metric_construct_factory = TableMetricConstructFactory() -table_metric_construct_factory.register("base", base_table_construct) -table_metric_construct_factory.register(Dialects.Redshift, redshift_table_construct) -table_metric_construct_factory.register(Dialects.MySQL, mysql_table_construct) -table_metric_construct_factory.register(Dialects.BigQuery, bigquery_table_construct) -table_metric_construct_factory.register(Dialects.ClickHouse, clickhouse_table_construct) -table_metric_construct_factory.register(Dialects.Oracle, oracle_table_construct) -table_metric_construct_factory.register(Dialects.Snowflake, snowflake_table_construct)