From 75b3d824a3abde9b32a6b87b9e0490a653f0043b Mon Sep 17 00:00:00 2001
From: IceS2
Date: Thu, 15 May 2025 12:19:58 +0200
Subject: [PATCH] Fixes #21095: Handle Conn Retry and implement is_disconnect for MSSQL (#21185)

* Handle Conn Retry and implement is_disconnect for MSSQL

* Change log to debug

(cherry picked from commit 87463df51df5691b3fd345659986e07067fb3ea5)
---
 .../sqlalchemy/profiler_interface.py           | 98 ++++++++++++-------
 .../source/database/mssql/profiler_source.py   | 86 ++++++++++++++++
 .../source/fetcher/profiler_source_factory.py  | 13 +++
 3 files changed, 163 insertions(+), 34 deletions(-)
 create mode 100644 ingestion/src/metadata/profiler/source/database/mssql/profiler_source.py

diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py
index 931e3294a49..0acfc27c742 100644
--- a/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py
+++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py
@@ -18,6 +18,7 @@ supporting sqlalchemy abstraction layer
 import concurrent.futures
 import math
 import threading
+import time
 import traceback
 from collections import defaultdict
 from datetime import datetime
@@ -397,43 +398,72 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             f"Running profiler for {metric_func.table.__tablename__} on thread {threading.current_thread()}"
         )
         Session = self.session_factory  # pylint: disable=invalid-name
-        with Session() as session:
-            self.set_session_tag(session)
-            self.set_catalog(session)
-            runner = self._create_thread_safe_runner(session, metric_func.column)
-            row = None
-            try:
-                row = self._get_metric_fn[metric_func.metric_type.value](
-                    metric_func.metrics,
-                    runner=runner,
-                    session=session,
-                    column=metric_func.column,
-                    sample=runner.dataset,
-                )
-                if isinstance(row, dict):
-                    row = self._validate_nulls(row)
-                if isinstance(row, list):
-                    row = [
-                        self._validate_nulls(r) if isinstance(r, dict) else r
-                        for r in row
-                    ]
+        max_retries = 3
+        retry_count = 0
+        initial_backoff = 5
+        max_backoff = 30
+        row = None
 
-            except Exception as exc:
-                error = (
-                    f"{metric_func.column if metric_func.column is not None else metric_func.table.__tablename__} "
-                    f"metric_type.value: {exc}"
-                )
-                logger.error(error)
-                self.status.failed_profiler(error, traceback.format_exc())
+        while retry_count < max_retries:
+            with Session() as session:
+                self.set_session_tag(session)
+                self.set_catalog(session)
+                runner = self._create_thread_safe_runner(session, metric_func.column)
+                try:
+                    row = self._get_metric_fn[metric_func.metric_type.value](
+                        metric_func.metrics,
+                        runner=runner,
+                        session=session,
+                        column=metric_func.column,
+                        sample=runner.dataset,
+                    )
+                    if isinstance(row, dict):
+                        row = self._validate_nulls(row)
+                    if isinstance(row, list):
+                        row = [
+                            self._validate_nulls(r) if isinstance(r, dict) else r
+                            for r in row
+                        ]
 
-            if metric_func.column is not None:
-                column = metric_func.column.name
-                self.status.scanned(f"{metric_func.table.__tablename__}.{column}")
-            else:
-                self.status.scanned(metric_func.table.__tablename__)
-                column = None
+                    # On success, log the scan and break out of the retry loop
+                    if metric_func.column is not None:
+                        column = metric_func.column.name
+                        self.status.scanned(
+                            f"{metric_func.table.__tablename__}.{column}"
+                        )
+                    else:
+                        self.status.scanned(metric_func.table.__tablename__)
+                        column = None
 
-            return row, column, metric_func.metric_type.value
+                    return row, column, metric_func.metric_type.value
+
+                except Exception as exc:
+                    dialect = session.get_bind().dialect
+                    if dialect.is_disconnect(exc, session.get_bind(), None):
+                        retry_count += 1
+                        if retry_count < max_retries:
+                            backoff = min(
+                                initial_backoff * (2 ** (retry_count - 1)), max_backoff
+                            )
+                            logger.debug(
+                                f"Connection error detected, retrying ({retry_count}/{max_retries}) "
+                                f"after {backoff:.2f} seconds..."
+                            )
+                            session.rollback()
+                            time.sleep(backoff)
+                            continue
+                        logger.error(
+                            f"Max retries ({max_retries}) exceeded for disconnection"
+                        )
+                    error = (
+                        f"{metric_func.column if metric_func.column is not None else metric_func.table.__tablename__} "
+                        f"metric_type.value: {exc}"
+                    )
+                    logger.error(error)
+                    self.status.failed_profiler(error, traceback.format_exc())
+
+        # If we've exhausted all retries without success, return a tuple of None values
+        return None, None, None
 
     @staticmethod
     def _validate_nulls(row: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/ingestion/src/metadata/profiler/source/database/mssql/profiler_source.py b/ingestion/src/metadata/profiler/source/database/mssql/profiler_source.py
new file mode 100644
index 00000000000..85da6ef2edc
--- /dev/null
+++ b/ingestion/src/metadata/profiler/source/database/mssql/profiler_source.py
@@ -0,0 +1,86 @@
+"""Extend the ProfilerSource class to add support for MSSQL is_disconnect SQA method"""
+from metadata.generated.schema.configuration.profilerConfiguration import (
+    ProfilerConfiguration,
+)
+from metadata.generated.schema.entity.data.database import Database
+from metadata.generated.schema.entity.services.connections.database.mssqlConnection import (
+    MssqlScheme,
+)
+from metadata.generated.schema.metadataIngestion.workflow import (
+    OpenMetadataWorkflowConfig,
+)
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
+from metadata.profiler.source.database.base.profiler_source import ProfilerSource
+
+
+def is_disconnect(is_disconnect_original):
+    """Wrapper to add custom is_disconnect method for the MSSQL dialects"""
+
+    def inner_is_disconnect(self, e, connection, cursor):
+        """is_disconnect method for the MSSQL dialects"""
+        error_str = str(e)
+
+        mssql_disconnect_codes = [
+            "08S01",  # Communication link failure
+            "08001",  # Cannot connect
+            "HY000",  # General error often used for connection issues
+        ]
+
+        mssql_disconnect_messages = [
+            "Server closed connection",
+            "ClosedConnectionError",
+            "Connection is closed",
+            "Connection reset by peer",
+            "Timeout expired",
+            "Socket closed",
+        ]
+
+        if any(code in error_str for code in mssql_disconnect_codes) or any(
+            message in error_str for message in mssql_disconnect_messages
+        ):
+            return True
+
+        # If none of our custom checks match, fall back to SQLAlchemy's built-in detection
+        return is_disconnect_original(self, e, connection, cursor)
+
+    return inner_is_disconnect
+
+
+class MssqlProfilerSource(ProfilerSource):
+    """MSSQL Profiler source"""
+
+    def __init__(
+        self,
+        config: OpenMetadataWorkflowConfig,
+        database: Database,
+        ometa_client: OpenMetadata,
+        global_profiler_config: ProfilerConfiguration,
+    ):
+        super().__init__(config, database, ometa_client, global_profiler_config)
+        self.set_is_disconnect(config)
+
+    def set_is_disconnect(self, config: OpenMetadataWorkflowConfig):
+        """Set the is_disconnect method based on the configured connection scheme"""
+        # pylint: disable=import-outside-toplevel
+
+        # Get the configured scheme from the source connection
+        scheme = config.source.serviceConnection.root.config.scheme
+
+        # Set the appropriate is_disconnect method based on the scheme
+        if scheme == MssqlScheme.mssql_pytds:
+            from sqlalchemy_pytds.dialect import MSDialect_pytds
+
+            original_is_disconnect = MSDialect_pytds.is_disconnect
+            MSDialect_pytds.is_disconnect = is_disconnect(original_is_disconnect)
+        elif scheme == MssqlScheme.mssql_pyodbc:
+            from sqlalchemy.dialects.mssql.pyodbc import MSDialect_pyodbc
+
+            original_is_disconnect = MSDialect_pyodbc.is_disconnect
+            MSDialect_pyodbc.is_disconnect = is_disconnect(original_is_disconnect)
+        elif scheme == MssqlScheme.mssql_pymssql:
+            from sqlalchemy.dialects.mssql.pymssql import MSDialect_pymssql
+
+            original_is_disconnect = MSDialect_pymssql.is_disconnect
+            MSDialect_pymssql.is_disconnect = is_disconnect(original_is_disconnect)
+        else:
+            raise ValueError(f"Unsupported MSSQL scheme: {scheme}")
diff --git a/ingestion/src/metadata/profiler/source/fetcher/profiler_source_factory.py b/ingestion/src/metadata/profiler/source/fetcher/profiler_source_factory.py
index e317af31f98..cc3714a04d8 100644
--- a/ingestion/src/metadata/profiler/source/fetcher/profiler_source_factory.py
+++ b/ingestion/src/metadata/profiler/source/fetcher/profiler_source_factory.py
@@ -21,6 +21,9 @@ from metadata.generated.schema.entity.services.connections.database.bigQueryConn
 from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
     DatabricksType,
 )
+from metadata.generated.schema.entity.services.connections.database.mssqlConnection import (
+    MssqlType,
+)
 from metadata.profiler.source.profiler_source_interface import ProfilerSourceInterface
 
 
@@ -79,10 +82,20 @@ class ProfilerSourceFactory:
 
         return DataBricksProfilerSource
 
+    @staticmethod
+    def mssql() -> Type[ProfilerSourceInterface]:
+        """Lazy loading of the MSSQL source"""
+        from metadata.profiler.source.database.mssql.profiler_source import (
+            MssqlProfilerSource,
+        )
+
+        return MssqlProfilerSource
+
 
 source = {
     BigqueryType.BigQuery.value.lower(): ProfilerSourceFactory.bigquery,
     DatabricksType.Databricks.value.lower(): ProfilerSourceFactory.databricks,
+    MssqlType.Mssql.value.lower(): ProfilerSourceFactory.mssql,
 }
 
 profiler_source_factory = ProfilerSourceFactory()
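
Reviewer notes (editorial additions, not part of the patch):

The retry loop in profiler_interface.py backs off exponentially between reconnect attempts. A minimal standalone sketch of the schedule it produces, assuming the defaults hard-coded in the patch (max_retries=3, initial_backoff=5, max_backoff=30); only the first two failures sleep and retry, the third gives up:

    # Standalone Python sketch reproducing the backoff arithmetic from the patch.
    max_retries = 3
    initial_backoff = 5
    max_backoff = 30

    for retry_count in range(1, max_retries):
        # Same formula as the patch: double the delay each retry, capped at max_backoff.
        backoff = min(initial_backoff * (2 ** (retry_count - 1)), max_backoff)
        print(f"retry {retry_count}/{max_retries}: sleeping {backoff}s before reconnecting")

    # Output:
    # retry 1/3: sleeping 5s before reconnecting
    # retry 2/3: sleeping 10s before reconnecting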
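
The MSSQL profiler source wires disconnect detection into SQLAlchemy by wrapping the dialect's is_disconnect method: MSSQL-specific error codes and messages are checked first, then the original implementation is used as a fallback. A trimmed, self-contained sketch of that composition; fake_original stands in for the real dialect method (e.g. MSDialect_pytds.is_disconnect) so the sketch runs without SQLAlchemy installed:

    # Standalone sketch of the wrapper pattern used in mssql/profiler_source.py.
    def is_disconnect(is_disconnect_original):
        def inner_is_disconnect(self, e, connection, cursor):
            error_str = str(e)
            # Subset of the codes checked in the patch (08S01, 08001, HY000).
            if any(code in error_str for code in ("08S01", "08001", "HY000")):
                return True
            # Fall back to the dialect's own detection.
            return is_disconnect_original(self, e, connection, cursor)

        return inner_is_disconnect

    def fake_original(self, e, connection, cursor):
        return False  # stand-in: pretend the dialect does not recognise the error

    check = is_disconnect(fake_original)
    print(check(None, Exception("08S01 communication link failure"), None, None))  # True
    print(check(None, Exception("Incorrect syntax near 'FROM'"), None, None))      # False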
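
Finally, set_is_disconnect patches a different dialect class depending on the configured connection scheme, and the factory registers the new source under the lowercase MSSQL service type. A small sketch summarising the scheme-to-dialect dispatch; the class paths come from the imports in the patch, while the scheme string values ("mssql+pytds", etc.) are assumed from the usual SQLAlchemy driver names and may differ from the actual MssqlScheme enum values:

    # Standalone summary of which dialect's is_disconnect gets wrapped per scheme
    # (illustration only; the real code uses an if/elif chain with local imports).
    SCHEME_TO_PATCHED_DIALECT = {
        "mssql+pytds": "sqlalchemy_pytds.dialect.MSDialect_pytds",
        "mssql+pyodbc": "sqlalchemy.dialects.mssql.pyodbc.MSDialect_pyodbc",
        "mssql+pymssql": "sqlalchemy.dialects.mssql.pymssql.MSDialect_pymssql",
    }

    for scheme, dialect_path in SCHEME_TO_PATCHED_DIALECT.items():
        print(f"{scheme:15} -> {dialect_path}.is_disconnect is wrapped")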