FIX #21955 - Handle sampler SQA sessions (#21994)

* FIX #21955

* FIX #21955
This commit is contained in:
Pere Miquel Brull 2025-06-27 08:58:25 +02:00 committed by GitHub
parent ea382c50db
commit 5f0f32c366
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 38 additions and 8 deletions

View File

@ -103,7 +103,6 @@ class HeuristicPIIClassifier(ColumnClassifier[PIITag]):
column_name: Optional[str] = None,
column_data_type: Optional[DataType] = None,
) -> Mapping[PIITag, float]:
if column_data_type is not None and is_non_pii_datatype(column_data_type):
return {}

View File

@ -16,6 +16,10 @@ from typing import List, Protocol, Sequence
from dateutil.parser import parse
from presidio_analyzer import RecognizerResult
from metadata.utils.logger import pii_logger
logger = pii_logger()
class PresidioRecognizerResultPatcher(Protocol):
"""
@ -76,8 +80,11 @@ def date_time_patcher(
# try to parse using dateutils, if it fails, skip the result
try:
_ = parse(text[result.start : result.end])
except ValueError:
except (ValueError, OverflowError):
# if parsing fails, skip the result
continue
except Exception as e:
logger.info("Unexpected error while parsing date time: %s", e)
continue
patched_result.append(result)
return patched_result

View File

@ -18,9 +18,13 @@ from metadata.sampler.sqlalchemy.sampler import SQASampler
class DatabricksSamplerInterface(SQASampler):
def __init__(self, *args, **kwargs):
"""Initialize with a single Databricks connection"""
super().__init__(*args, **kwargs)
self.connection = databricks_get_connection(self.service_connection_config)
def get_client(self):
"""client is the session for SQA"""
self.connection = databricks_get_connection(self.service_connection_config)
client = super().get_client()
self.set_catalog(client)
return client

View File

@ -13,6 +13,7 @@ Helper module to handle data sampling
for the profiler
"""
import hashlib
from contextlib import contextmanager
from typing import List, Optional, Union, cast
from sqlalchemy import Column, inspect, text
@ -71,6 +72,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
self._table = self.build_table_orm(
self.entity, self.service_connection_config, self.ometa_client
)
self._active_sessions = set()
@property
def raw_dataset(self):
@ -79,7 +81,20 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
def get_client(self):
"""Build the SQA Client"""
session_factory = create_and_bind_thread_safe_session(self.connection)
return session_factory()
session = session_factory()
self._active_sessions.add(session)
return session
@contextmanager
def get_client_context(self):
"""Get client as context manager for proper cleanup"""
session = self.get_client()
try:
yield session
finally:
if session in self._active_sessions:
self._active_sessions.remove(session)
session.close()
def set_tablesample(self, selectable: Table):
"""Set the tablesample for the table. To be implemented by the child SQA sampler class
@ -199,7 +214,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
if col.name != RANDOM_LABEL and col.name in names
]
with self.get_client() as client:
with self.get_client_context() as client:
sqa_sample = (
client.query(*sqa_columns)
.select_from(ds)
@ -217,7 +232,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
if not is_safe_sql_query(self.sample_query):
raise RuntimeError(f"SQL expression is not safe\n\n{self.sample_query}")
with self.get_client() as client:
with self.get_client_context() as client:
rnd = client.execute(f"{self.sample_query}")
try:
columns = [col.name for col in rnd.cursor.description]
@ -264,5 +279,6 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
def close(self):
"""Close the connection"""
self.get_client().close()
for session in self._active_sessions:
session.close()
self.connection.pool.dispose()

View File

@ -20,9 +20,13 @@ from metadata.sampler.sqlalchemy.sampler import SQASampler
class UnityCatalogSamplerInterface(SQASampler):
def __init__(self, *args, **kwargs):
"""Initialize with a single Databricks connection"""
super().__init__(*args, **kwargs)
self.connection = databricks_get_connection(self.service_connection_config)
def get_client(self):
"""client is the session for SQA"""
self.connection = databricks_get_connection(self.service_connection_config)
client = super().get_client()
self.set_catalog(client)
return client