ISSUE #21146 - Properly handle connection on sampler (#21186)

* fix: properly close connection on sampler ingestion

* fix: dangling connection test

* style: ran python linting

* fix: revert to 9
This commit is contained in:
Teddy 2025-05-15 12:21:01 +02:00 committed by GitHub
parent 87463df51d
commit cd6434dd73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 32 additions and 23 deletions

View File

@ -99,7 +99,7 @@ class SamplerProcessor(Processor):
data=sampler_interface.generate_sample_data(),
store=self.source_config.storeSampleData,
)
sampler_interface.close()
return Either(
right=SamplerResponse(
table=entity,

View File

@ -245,3 +245,6 @@ class SamplerInterface(ABC):
logger.debug(traceback.format_exc())
logger.warning(f"Error fetching sample data: {err}")
raise err
def close(self):
    """Release any resources held by the sampler.

    Default no-op: this base implementation does nothing. Samplers that
    hold a client or connection override it to close them.
    """

View File

@ -99,16 +99,16 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
"""
# only sample the column if we are computing a column metric to limit the amount of data scanned
selectable = self.set_tablesample(self.raw_dataset.__table__)
client = self.get_client()
entity = selectable if column is None else selectable.c.get(column.key)
with self.get_client() as client:
if label is not None:
query = client.query(entity, label)
else:
query = client.query(entity)
if label is not None:
query = client.query(entity, label)
else:
query = client.query(entity)
if self.partition_details:
query = self.get_partitioned_query(query)
if self.partition_details:
query = self.get_partitioned_query(query)
return query
def get_sampler_table_name(self) -> str:
@ -124,19 +124,18 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
def get_sample_query(self, *, column=None) -> Query:
"""get query for sample data"""
client = self.get_client()
if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE:
rnd = self._base_sample_query(
column,
(ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
).cte(f"{self.get_sampler_table_name()}_rnd")
with self.get_client() as client:
session_query = client.query(rnd)
session_query = client.query(rnd)
return session_query.where(
rnd.c.random <= self.sample_config.profileSample
).cte(f"{self.get_sampler_table_name()}_sample")
with self.get_client() as client:
table_query = client.query(self.raw_dataset)
table_query = client.query(self.raw_dataset)
session_query = self._base_sample_query(
column,
(ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL)
@ -197,6 +196,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
for col in inspect(ds).c
if col.name != RANDOM_LABEL and col.name in names
]
with self.get_client() as client:
sqa_sample = (
client.query(*sqa_columns)
@ -204,10 +204,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
.limit(self.sample_limit)
.all()
)
return TableData(
columns=[column.name for column in sqa_columns],
rows=[list(row) for row in sqa_sample],
)
return TableData(
columns=[column.name for column in sqa_columns],
rows=[list(row) for row in sqa_sample],
)
def _fetch_sample_data_from_user_query(self) -> TableData:
"""Returns a table data object using results from query execution"""
@ -232,10 +233,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
stmt = text(f"{self.sample_query}")
stmt = stmt.columns(*list(inspect(self.raw_dataset).c))
with self.get_client() as client:
return client.query(stmt.subquery()).cte(
f"{self.get_sampler_table_name()}_user_sampled"
)
return (
self.get_client()
.query(stmt.subquery())
.cte(f"{self.get_sampler_table_name()}_user_sampled")
)
def _partitioned_table(self) -> Query:
"""Return the Query object for partitioned tables"""
@ -252,9 +254,13 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
)
if query is not None:
return query.filter(partition_filter)
with self.get_client() as client:
return client.query(self.raw_dataset).filter(partition_filter)
return self.get_client().query(self.raw_dataset).filter(partition_filter)
def get_columns(self):
    """Return the columns of the raw dataset as a list."""
    inspected = inspect(self.raw_dataset)
    return list(inspected.c)
def close(self):
    """Close the sampler's client and dispose of the connection pool.

    The pool is disposed in a ``finally`` clause so that pooled connections
    are released even when closing the client raises. Any error raised by
    the client's ``close()`` still propagates to the caller.
    """
    try:
        self.get_client().close()
    finally:
        # Always release pooled connections, even on a failed client close.
        self.connection.pool.dispose()

View File

@ -54,7 +54,7 @@ def mysql_container(tmp_path_factory):
engine.dispose()
assert_dangling_connections(container, 1)
yield container
# TODO: We are still leaving some connections open. Should be fixed in the future.
# Needs to be handled for Test Cases https://github.com/open-metadata/OpenMetadata/issues/21187
assert_dangling_connections(container, 9)