MINOR - Fix sqa table reference (#18839)

* fix: sqa table reference

* style: ran python linting

* fix: added raw dataset to query runner

* fix: get table and schema name from orm object

* fix: get table level config for table tests
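
The net effect, in short: QueryRunner now receives the raw (unsampled) ORM table alongside the possibly sampled dataset, and callers read table and schema names through new properties instead of reaching into runner.table. A minimal sketch of the resulting call shape, where session and sampler stand in for objects the profiler interface already holds:

    # Sketch only; session and sampler are placeholders for existing state.
    runner = QueryRunner(
        session=session,
        dataset=sampler.get_dataset(),    # sampled/partitioned dataset used in queries
        raw_dataset=sampler.raw_dataset,  # new: the raw ORM table, kept for naming
    )
    full_name = f"{runner.schema_name}.{runner.table_name}"  # new properties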
Teddy 2024-11-28 18:49:11 +01:00 committed by GitHub
parent da176767a8
commit ac2f6d7132
30 changed files with 66 additions and 57 deletions


@@ -104,6 +104,7 @@ class SQATestSuiteInterface(SQAInterfaceMixin, TestSuiteInterface):
         QueryRunner(
             session=self.session,
             dataset=self.dataset,
+            raw_dataset=self.sampler.raw_dataset,
             partition_details=self.table_partition_config,
             profile_sample_query=self.table_sample_query,
         )


@@ -40,7 +40,7 @@ class ColumnValueLengthsToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
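
This same one-line swap repeats across the column validators below. It works because SQLAlchemy's inspect() yields uniform .c column access for both shapes the runner's dataset can now take: a mapped class (via its Mapper) or a Core selectable such as a CTE (which inspects to itself). A small self-contained sketch, with a hypothetical Users model standing in for the profiled table:

    from sqlalchemy import Column, Integer, inspect, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Users(Base):  # hypothetical stand-in for the profiled table
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)

    print([c.name for c in inspect(Users).c])   # Mapper exposes .c
    sample = select(Users).cte("users_sample")
    print([c.name for c in inspect(sample).c])  # a CTE inspects to itself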


@@ -38,7 +38,7 @@ class ColumnValueMaxToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValueMeanToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValueMedianToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValueMinToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValueStdDevToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -42,7 +42,7 @@ class ColumnValuesMissingCountValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValuesSumToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -37,7 +37,7 @@ class ColumnValuesToBeAtExpectedLocationValidator(
     def _fetch_data(self, columns: List[str]) -> Iterator:
         """Fetch data from the runner object"""
         self.runner = cast(QueryRunner, self.runner)
-        inspection = inspect(self.runner.table)
+        inspection = inspect(self.runner.dataset)
         table_columns: List[Column] = inspection.c if inspection is not None else []
         cols = [col for col in table_columns if col.name in columns]
         for col in cols:


@@ -39,7 +39,7 @@ class ColumnValuesToBeBetweenValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValuesToBeInSetValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:


@@ -39,7 +39,7 @@ class ColumnValuesToBeNotInSetValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:


@@ -42,7 +42,7 @@ class ColumnValuesToBeNotNullValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:


@@ -17,7 +17,6 @@ from typing import Optional

 from sqlalchemy import Column, inspect
 from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm.util import AliasedClass

 from metadata.data_quality.validations.column.base.columnValuesToBeUnique import (
     BaseColumnValuesToBeUniqueValidator,
@@ -41,7 +40,7 @@ class ColumnValuesToBeUniqueValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
@@ -53,12 +52,7 @@ class ColumnValuesToBeUniqueValidator(
         """
         count = Metrics.COUNT.value(column).fn()
         unique_count = Metrics.UNIQUE_COUNT.value(column).query(
-            sample=self.runner._sample  # pylint: disable=protected-access
-            if isinstance(
-                self.runner._sample,  # pylint: disable=protected-access
-                AliasedClass,
-            )
-            else self.runner.table,
+            sample=self.runner.dataset,
             session=self.runner._session,  # pylint: disable=protected-access
         )  # type: ignore


@@ -43,7 +43,7 @@ class ColumnValuesToMatchRegexValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(


@@ -43,7 +43,7 @@ class ColumnValuesToNotMatchRegexValidator(
         """
         return self.get_column_name(
             self.test_case.entityLink.root,
-            inspect(self.runner.table).c,
+            inspect(self.runner.dataset).c,
         )

     def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:


@@ -13,7 +13,7 @@
 Validator for table row inserted count to be between test case
 """

-from sqlalchemy import Column, text
+from sqlalchemy import Column, inspect, text

 from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
     SQAValidatorMixin,
@@ -52,7 +52,7 @@ class TableRowInsertedCountToBeBetweenValidator(
         date_or_datetime_fn = dispatch_to_date_or_datetime(
             range_interval,
             text(range_type),
-            get_partition_col_type(column_name.name, self.runner.table.c),  # type: ignore
+            get_partition_col_type(column_name.name, inspect(self.runner.dataset).c),  # type: ignore
         )

         return dict(


@@ -22,7 +22,7 @@ class BigQueryProfiler(BigQueryProfilerInterface):
         **kwargs,
     ) -> List[SystemProfile]:
         return self.system_metrics_computer.get_system_metrics(
-            table=runner.table,
+            table=runner.dataset,
             usage_location=self.service_connection_config.usageLocation,
         )


@@ -32,7 +32,7 @@ class DB2ProfilerInterface(SQAProfilerInterface):
         # pylint: disable=protected-access
         if exc.orig and "overflow" in exc.orig._message:
             logger.info(
-                f"Computing metrics without sum for {runner.table.name}.{column.name}"
+                f"Computing metrics without sum for {runner.table_name}.{column.name}"
             )
             return self._compute_static_metrics_wo_sum(metrics, runner, session, column)
         return None


@@ -77,11 +77,11 @@ class MariaDBProfilerInterface(SQAProfilerInterface):
             return dict(row)
         except ProgrammingError:
             logger.info(
-                f"Skipping window metrics for {runner.table.name}.{column.name} due to overflow"
+                f"Skipping window metrics for {runner.table_name}.{column.name} due to overflow"
             )
             return None
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None


@@ -156,7 +156,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             )
             return dict(row)
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None
@@ -194,7 +194,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.warning(
-                f"Error trying to compute profile for {runner.table.name}: {exc}"  # type: ignore
+                f"Error trying to compute profile for {runner.table_name}: {exc}"  # type: ignore
             )
             session.rollback()
             raise RuntimeError(exc)
@@ -231,7 +231,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
                 runner, column, exc, session, metrics
             )
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None
@@ -274,10 +274,10 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
                 runner._session.get_bind().dialect.name
                 != Dialects.Druid
             ):
-                msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+                msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
                 handle_query_exception(msg, exc, session)
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None
@@ -310,10 +310,10 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             return dict(row)
         except ProgrammingError as exc:
             logger.info(
-                f"Skipping metrics for {runner.table.name}.{column.name} due to {exc}"
+                f"Skipping metrics for {runner.table_name}.{column.name} due to {exc}"
             )
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None
@@ -347,7 +347,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             )
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{metric.columnName}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{metric.columnName}: {exc}"
             logger.debug(traceback.format_exc())
             logger.warning(msg)
         if custom_metrics:
@@ -371,8 +371,8 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
         Returns:
             dictionnary of results
         """
-        logger.debug(f"Computing system metrics for {runner.table.name}")
-        return self.system_metrics_computer.get_system_metrics(table=runner.table)
+        logger.debug(f"Computing system metrics for {runner.table_name}")
+        return self.system_metrics_computer.get_system_metrics(table=runner.dataset)

     def _create_thread_safe_runner(self, session, column=None):
         """Create thread safe runner"""
@@ -380,6 +380,7 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             thread_local.runner = QueryRunner(
                 session=session,
                 dataset=self.sampler.get_dataset(column=column),
+                raw_dataset=self.sampler.raw_dataset,
                 partition_details=self.sampler.partition_details,
                 profile_sample_query=self.sampler.sample_query,
             )
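
For context on the last hunk: _create_thread_safe_runner caches one runner per worker thread in a threading.local, so each thread queries through its own session while the shared sampler is only read. A hedged sketch of that pattern, with hypothetical factory callables rather than the project's real signatures:

    import threading

    thread_local = threading.local()  # separate attribute namespace per thread

    def thread_safe_runner(session_factory, runner_factory):
        # Lazily build and cache one runner per worker thread.
        if not hasattr(thread_local, "runner"):
            thread_local.runner = runner_factory(session_factory())
        return thread_local.runner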


@@ -76,11 +76,11 @@ class SingleStoreProfilerInterface(SQAProfilerInterface):
             return dict(row)
         except ProgrammingError:
             logger.info(
-                f"Skipping window metrics for {runner.table.name}.{column.name} due to overflow"
+                f"Skipping window metrics for {runner.table_name}.{column.name} due to overflow"
             )
             return None
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None


@@ -41,7 +41,7 @@ class SnowflakeProfilerInterface(SQAProfilerInterface):
             session.bind.dialect.name
         ):
             logger.info(
-                f"Computing metrics without sum for {runner.table.name}.{column.name}"
+                f"Computing metrics without sum for {runner.table_name}.{column.name}"
             )
             return self._compute_static_metrics_wo_sum(metrics, runner, session, column)
         return None


@@ -79,8 +79,8 @@ class ProfilerWithStatistics(SQAProfilerInterface, StoredStatisticsSource):
             list,
             partition(self.is_statistic_metric, metrics),
         )
-        schema = runner.table.schema
-        table_name = runner.table.name
+        schema = runner.schema_name
+        table_name = runner.table_name
         logger.debug(
             "Getting statistics for column: %s.%s.%s",
             schema,
@@ -118,8 +118,8 @@ class ProfilerWithStatistics(SQAProfilerInterface, StoredStatisticsSource):
             list,
             partition(self.is_statistic_metric, metrics),
         )
-        schema = runner.table.schema
-        table_name = runner.table.name
+        schema = runner.schema_name
+        table_name = runner.table_name
         logger.debug("Geting statistics for table: %s.%s", schema, table_name)
         result.update(
             super().get_table_statistics(stat_metrics, schema, table_name)


@@ -76,11 +76,11 @@ class TrinoProfilerInterface(ProfilerWithStatistics, TrinoStoredStatisticsSource
             return dict(row)
         except ProgrammingError as err:
             logger.info(
-                f"Skipping window metrics for {runner.table.name}.{column.name} due to {err}"
+                f"Skipping window metrics for {runner.table_name}.{column.name} due to {err}"
             )
             return None
         except Exception as exc:
-            msg = f"Error trying to compute profile for {runner.table.name}.{column.name}: {exc}"
+            msg = f"Error trying to compute profile for {runner.table_name}.{column.name}: {exc}"
             handle_query_exception(msg, exc, session)
         return None


@@ -53,7 +53,7 @@ class AbstractTableMetricComputer(ABC):
         self._metrics = metrics
         self._conn_config = conn_config
         self._database = self._runner._session.get_bind().url.database
-        self._table = self._runner.table
+        self._table = self._runner.dataset
         self._entity = entity

     @property


@@ -44,6 +44,7 @@ class QueryRunner:
         self,
         session: Session,
         dataset: Union[DeclarativeMeta, AliasedClass],
+        raw_dataset: Table,
         partition_details: Optional[Dict] = None,
         profile_sample_query: Optional[str] = None,
     ):
@@ -51,11 +52,12 @@ class QueryRunner:
         self._dataset = dataset
         self.partition_details = partition_details
         self.profile_sample_query = profile_sample_query
+        self.raw_dataset = raw_dataset

     @property
     def table(self) -> Table:
         """Backward compatibility table attribute access"""
-        return self._dataset.__table__
+        return self.raw_dataset

     @property
     def _sample(self):
@@ -71,6 +73,16 @@ class QueryRunner:
     def dataset(self, dataset):
         self._dataset = dataset

+    @property
+    def table_name(self):
+        """Table name attribute access"""
+        return self.raw_dataset.__table__.name
+
+    @property
+    def schema_name(self):
+        """Schema name attribute access"""
+        return self.raw_dataset.__table__.schema
+
     def _build_query(self, *entities, **kwargs) -> Query:
         """Build query object


@@ -16,7 +16,7 @@ import traceback
 from typing import List, Optional, Union, cast

 from sqlalchemy import Column, inspect, text
-from sqlalchemy.orm import DeclarativeMeta, Query, aliased
+from sqlalchemy.orm import DeclarativeMeta, Query
 from sqlalchemy.orm.util import AliasedClass
 from sqlalchemy.schema import Table
 from sqlalchemy.sql.sqltypes import Enum
@@ -145,13 +145,12 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
             and self.sample_config.profile_sample_type == ProfileSampleType.PERCENTAGE
         ):
             if self.partition_details:
-                return self._partitioned_table()
+                partitioned = self._partitioned_table()
+                return partitioned.cte(f"{self.raw_dataset.__tablename__}_partitioned")

             return self.raw_dataset

-        sampled = self.get_sample_query(column=column)
-        return aliased(self.raw_dataset, sampled)
+        return self.get_sample_query(column=column)

     def fetch_sample_data(self, columns: Optional[List[Column]] = None) -> TableData:
         """
@@ -230,7 +229,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
     def _partitioned_table(self) -> Query:
         """Return the Query object for partitioned tables"""
-        return aliased(self.raw_dataset, self.get_partitioned_query().subquery())
+        return self.get_partitioned_query()

     def get_partitioned_query(self, query=None) -> Query:
         """Return the partitioned query"""


@@ -94,7 +94,9 @@ class RunnerTest(TestCase):
         )

         cls.dataset = sampler.get_dataset()
-        cls.raw_runner = QueryRunner(session=cls.session, dataset=cls.dataset)
+        cls.raw_runner = QueryRunner(
+            session=cls.session, dataset=cls.dataset, raw_dataset=sampler.raw_dataset
+        )
         cls.timeout_runner: Timer = cls_timeout(1)(Timer())

         # Insert 30 rows