mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-31 20:51:26 +00:00
* fix: properly close connection on sampler ingestion * fix: dangling connection test * style: ran python linting * fix: revert to 9
This commit is contained in:
parent
87463df51d
commit
cd6434dd73
@ -99,7 +99,7 @@ class SamplerProcessor(Processor):
|
||||
data=sampler_interface.generate_sample_data(),
|
||||
store=self.source_config.storeSampleData,
|
||||
)
|
||||
|
||||
sampler_interface.close()
|
||||
return Either(
|
||||
right=SamplerResponse(
|
||||
table=entity,
|
||||
|
@ -245,3 +245,6 @@ class SamplerInterface(ABC):
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.warning(f"Error fetching sample data: {err}")
|
||||
raise err
|
||||
|
||||
def close(self):
|
||||
"""Default noop"""
|
||||
|
@ -99,16 +99,16 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
"""
|
||||
# only sample the column if we are computing a column metric to limit the amount of data scaned
|
||||
selectable = self.set_tablesample(self.raw_dataset.__table__)
|
||||
client = self.get_client()
|
||||
|
||||
entity = selectable if column is None else selectable.c.get(column.key)
|
||||
with self.get_client() as client:
|
||||
if label is not None:
|
||||
query = client.query(entity, label)
|
||||
else:
|
||||
query = client.query(entity)
|
||||
if label is not None:
|
||||
query = client.query(entity, label)
|
||||
else:
|
||||
query = client.query(entity)
|
||||
|
||||
if self.partition_details:
|
||||
query = self.get_partitioned_query(query)
|
||||
if self.partition_details:
|
||||
query = self.get_partitioned_query(query)
|
||||
return query
|
||||
|
||||
def get_sampler_table_name(self) -> str:
|
||||
@ -124,19 +124,18 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
|
||||
def get_sample_query(self, *, column=None) -> Query:
|
||||
"""get query for sample data"""
|
||||
client = self.get_client()
|
||||
if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE:
|
||||
rnd = self._base_sample_query(
|
||||
column,
|
||||
(ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
|
||||
).cte(f"{self.get_sampler_table_name()}_rnd")
|
||||
with self.get_client() as client:
|
||||
session_query = client.query(rnd)
|
||||
session_query = client.query(rnd)
|
||||
return session_query.where(
|
||||
rnd.c.random <= self.sample_config.profileSample
|
||||
).cte(f"{self.get_sampler_table_name()}_sample")
|
||||
|
||||
with self.get_client() as client:
|
||||
table_query = client.query(self.raw_dataset)
|
||||
table_query = client.query(self.raw_dataset)
|
||||
session_query = self._base_sample_query(
|
||||
column,
|
||||
(ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL)
|
||||
@ -197,6 +196,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
for col in inspect(ds).c
|
||||
if col.name != RANDOM_LABEL and col.name in names
|
||||
]
|
||||
|
||||
with self.get_client() as client:
|
||||
sqa_sample = (
|
||||
client.query(*sqa_columns)
|
||||
@ -204,10 +204,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
.limit(self.sample_limit)
|
||||
.all()
|
||||
)
|
||||
return TableData(
|
||||
columns=[column.name for column in sqa_columns],
|
||||
rows=[list(row) for row in sqa_sample],
|
||||
)
|
||||
|
||||
return TableData(
|
||||
columns=[column.name for column in sqa_columns],
|
||||
rows=[list(row) for row in sqa_sample],
|
||||
)
|
||||
|
||||
def _fetch_sample_data_from_user_query(self) -> TableData:
|
||||
"""Returns a table data object using results from query execution"""
|
||||
@ -232,10 +233,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
|
||||
stmt = text(f"{self.sample_query}")
|
||||
stmt = stmt.columns(*list(inspect(self.raw_dataset).c))
|
||||
with self.get_client() as client:
|
||||
return client.query(stmt.subquery()).cte(
|
||||
f"{self.get_sampler_table_name()}_user_sampled"
|
||||
)
|
||||
return (
|
||||
self.get_client()
|
||||
.query(stmt.subquery())
|
||||
.cte(f"{self.get_sampler_table_name()}_user_sampled")
|
||||
)
|
||||
|
||||
def _partitioned_table(self) -> Query:
|
||||
"""Return the Query object for partitioned tables"""
|
||||
@ -252,9 +254,13 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
|
||||
)
|
||||
if query is not None:
|
||||
return query.filter(partition_filter)
|
||||
with self.get_client() as client:
|
||||
return client.query(self.raw_dataset).filter(partition_filter)
|
||||
return self.get_client().query(self.raw_dataset).filter(partition_filter)
|
||||
|
||||
def get_columns(self):
|
||||
"""get columns from entity"""
|
||||
return list(inspect(self.raw_dataset).c)
|
||||
|
||||
def close(self):
|
||||
"""Close the connection"""
|
||||
self.get_client().close()
|
||||
self.connection.pool.dispose()
|
||||
|
@ -54,7 +54,7 @@ def mysql_container(tmp_path_factory):
|
||||
engine.dispose()
|
||||
assert_dangling_connections(container, 1)
|
||||
yield container
|
||||
# TODO: We are still leaving some connections open. Should be fixed in the future.
|
||||
# Needs to be handled for Test Cases https://github.com/open-metadata/OpenMetadata/issues/21187
|
||||
assert_dangling_connections(container, 9)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user