FIX #21955 - Handle sampler SQA sessions (#21994)

* FIX #21955

* FIX #21955
This commit is contained in:
Pere Miquel Brull 2025-06-27 08:58:25 +02:00 committed by GitHub
parent ea382c50db
commit 5f0f32c366
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 38 additions and 8 deletions

View File

@ -103,7 +103,6 @@ class HeuristicPIIClassifier(ColumnClassifier[PIITag]):
column_name: Optional[str] = None, column_name: Optional[str] = None,
column_data_type: Optional[DataType] = None, column_data_type: Optional[DataType] = None,
) -> Mapping[PIITag, float]: ) -> Mapping[PIITag, float]:
if column_data_type is not None and is_non_pii_datatype(column_data_type): if column_data_type is not None and is_non_pii_datatype(column_data_type):
return {} return {}

View File

@ -16,6 +16,10 @@ from typing import List, Protocol, Sequence
from dateutil.parser import parse from dateutil.parser import parse
from presidio_analyzer import RecognizerResult from presidio_analyzer import RecognizerResult
from metadata.utils.logger import pii_logger
logger = pii_logger()
class PresidioRecognizerResultPatcher(Protocol): class PresidioRecognizerResultPatcher(Protocol):
""" """
@ -76,8 +80,11 @@ def date_time_patcher(
# try to parse using dateutils, if it fails, skip the result # try to parse using dateutils, if it fails, skip the result
try: try:
_ = parse(text[result.start : result.end]) _ = parse(text[result.start : result.end])
except ValueError: except (ValueError, OverflowError):
# if parsing fails, skip the result # if parsing fails, skip the result
continue continue
except Exception as e:
logger.info("Unexpected error while parsing date time: %s", e)
continue
patched_result.append(result) patched_result.append(result)
return patched_result return patched_result

View File

@ -18,9 +18,13 @@ from metadata.sampler.sqlalchemy.sampler import SQASampler
class DatabricksSamplerInterface(SQASampler): class DatabricksSamplerInterface(SQASampler):
def __init__(self, *args, **kwargs):
    """Run the parent initializer, then attach one Databricks connection.

    All positional and keyword arguments are forwarded unchanged to the
    parent class. Afterwards ``self.connection`` is set to the value
    returned by ``databricks_get_connection`` for this sampler's
    ``service_connection_config``.
    """
    super().__init__(*args, **kwargs)
    connection_config = self.service_connection_config
    self.connection = databricks_get_connection(connection_config)
def get_client(self): def get_client(self):
"""client is the session for SQA""" """client is the session for SQA"""
self.connection = databricks_get_connection(self.service_connection_config)
client = super().get_client() client = super().get_client()
self.set_catalog(client) self.set_catalog(client)
return client return client

View File

@ -13,6 +13,7 @@ Helper module to handle data sampling
for the profiler for the profiler
""" """
import hashlib import hashlib
from contextlib import contextmanager
from typing import List, Optional, Union, cast from typing import List, Optional, Union, cast
from sqlalchemy import Column, inspect, text from sqlalchemy import Column, inspect, text
@ -71,6 +72,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
self._table = self.build_table_orm( self._table = self.build_table_orm(
self.entity, self.service_connection_config, self.ometa_client self.entity, self.service_connection_config, self.ometa_client
) )
self._active_sessions = set()
@property @property
def raw_dataset(self): def raw_dataset(self):
@ -79,7 +81,20 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
def get_client(self): def get_client(self):
"""Build the SQA Client""" """Build the SQA Client"""
session_factory = create_and_bind_thread_safe_session(self.connection) session_factory = create_and_bind_thread_safe_session(self.connection)
return session_factory() session = session_factory()
self._active_sessions.add(session)
return session
@contextmanager
def get_client_context(self):
    """Yield a session from ``get_client`` and release it on exit.

    Whether the ``with`` body finishes normally or raises, the session is
    dropped from ``self._active_sessions`` (when it is tracked there) and
    then closed.

    Yields:
        The session object returned by ``self.get_client()``.
    """
    client = self.get_client()
    try:
        yield client
    finally:
        # discard() is a no-op for a missing member, so no prior
        # membership check is needed before untracking the session.
        self._active_sessions.discard(client)
        client.close()
def set_tablesample(self, selectable: Table): def set_tablesample(self, selectable: Table):
"""Set the tablesample for the table. To be implemented by the child SQA sampler class """Set the tablesample for the table. To be implemented by the child SQA sampler class
@ -199,7 +214,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
if col.name != RANDOM_LABEL and col.name in names if col.name != RANDOM_LABEL and col.name in names
] ]
with self.get_client() as client: with self.get_client_context() as client:
sqa_sample = ( sqa_sample = (
client.query(*sqa_columns) client.query(*sqa_columns)
.select_from(ds) .select_from(ds)
@ -217,7 +232,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
if not is_safe_sql_query(self.sample_query): if not is_safe_sql_query(self.sample_query):
raise RuntimeError(f"SQL expression is not safe\n\n{self.sample_query}") raise RuntimeError(f"SQL expression is not safe\n\n{self.sample_query}")
with self.get_client() as client: with self.get_client_context() as client:
rnd = client.execute(f"{self.sample_query}") rnd = client.execute(f"{self.sample_query}")
try: try:
columns = [col.name for col in rnd.cursor.description] columns = [col.name for col in rnd.cursor.description]
@ -264,5 +279,6 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
def close(self): def close(self):
"""Close the connection""" """Close the connection"""
self.get_client().close() for session in self._active_sessions:
session.close()
self.connection.pool.dispose() self.connection.pool.dispose()

View File

@ -20,9 +20,13 @@ from metadata.sampler.sqlalchemy.sampler import SQASampler
class UnityCatalogSamplerInterface(SQASampler): class UnityCatalogSamplerInterface(SQASampler):
def __init__(self, *args, **kwargs):
    """Run the parent initializer, then attach one Databricks connection.

    Every argument is passed through to the parent class as-is. Once that
    returns, ``self.connection`` is assigned the result of calling
    ``databricks_get_connection`` with ``self.service_connection_config``.
    """
    super().__init__(*args, **kwargs)
    connection_config = self.service_connection_config
    self.connection = databricks_get_connection(connection_config)
def get_client(self): def get_client(self):
"""client is the session for SQA""" """client is the session for SQA"""
self.connection = databricks_get_connection(self.service_connection_config)
client = super().get_client() client = super().get_client()
self.set_catalog(client) self.set_catalog(client)
return client return client