def masked(s: str, mask: bool = True) -> str:
    """Return a fixed placeholder instead of ``s`` unless masking is disabled.

    Used when writing potentially sensitive values to the logs. Passing
    ``mask=False`` exposes the raw value; the original note on this helper
    says that is for development only and must not be done in production.

    Args:
        s: the string that would be logged
        mask: when truthy (the default) the value is hidden

    Returns:
        ``"***"`` when ``mask`` is truthy, otherwise ``s`` unchanged
    """
    if not mask:
        return s
    return "***"
def safe_table_diff_iterator(self) -> Iterable[tuple]:
    """Yield the rows of the table diff and always close the diff afterwards.

    Wraps ``self.get_table_diff()`` in a generator so that ``gen.diff.close()``
    runs once iteration ends: when the rows are exhausted, when the generator
    is closed before exhaustion, or when an error is raised while iterating.
    According to the original note on this method, data_diff otherwise keeps
    the connection open and eventually raises a ``KeyError``.

    Yields:
        Each item produced by iterating the object returned by
        ``self.get_table_diff()``.
    """
    gen = self.get_table_diff()
    try:
        yield from gen
    finally:
        try:
            gen.diff.close()
        except KeyError as ex:
            # KeyError(2) is a known data_diff issue raised when the diff
            # object is already closed, so it is ignored on purpose.
            if str(ex) != "2":
                # Closing stays best-effort, but any other KeyError is
                # reported instead of being silently discarded.
                logger.warning(
                    "Unexpected KeyError while closing table diff: %s", ex
                )