Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-09-01 13:13:10 +00:00)
* fix: properly close connection on sampler ingestion
* fix: dangling connection test
* style: ran python linting
* fix: revert to 9
This commit is contained in:
parent 87463df51d
commit cd6434dd73
@@ -99,7 +99,7 @@ class SamplerProcessor(Processor):
                 data=sampler_interface.generate_sample_data(),
                 store=self.source_config.storeSampleData,
             )
+            sampler_interface.close()
             return Either(
                 right=SamplerResponse(
                     table=entity,
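For context, the idea behind the added `sampler_interface.close()` call is that the processor, which owns the sampler for the duration of one record, releases the sampler's database resources once the sample data has been produced. Below is a minimal sketch of that ownership pattern; `SamplerLike` and `process` are hypothetical names, and the `try/finally` variant shown here additionally covers the error path:

```python
# Minimal sketch, not the OpenMetadata implementation.
class SamplerLike:
    def generate_sample_data(self):
        return []

    def close(self):
        pass


def process(sampler: SamplerLike):
    try:
        return sampler.generate_sample_data()
    finally:
        # Release the underlying session/engine even if sampling raised;
        # the commit itself calls close() right after a successful generate_sample_data().
        sampler.close()
```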
@@ -245,3 +245,6 @@ class SamplerInterface(ABC):
             logger.debug(traceback.format_exc())
             logger.warning(f"Error fetching sample data: {err}")
             raise err
+
+    def close(self):
+        """Default noop"""
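Declaring `close()` on the abstract `SamplerInterface` with a no-op body means callers can always invoke it, while only samplers that actually hold a connection need to override it. A short sketch of that base-class-default pattern, with illustrative names only:

```python
from abc import ABC, abstractmethod


class SamplerSketch(ABC):
    """Illustrative stand-in for a sampler interface, not the real class."""

    @abstractmethod
    def generate_sample_data(self):
        ...

    def close(self):
        """Default noop: samplers without a persistent connection need no cleanup."""


class InMemorySampler(SamplerSketch):
    def generate_sample_data(self):
        return [{"id": 1}]
    # close() is inherited as a no-op, so the processor can call it unconditionally.
```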
@@ -99,16 +99,16 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
         """
         # only sample the column if we are computing a column metric to limit the amount of data scaned
         selectable = self.set_tablesample(self.raw_dataset.__table__)
 
-        client = self.get_client()
-
         entity = selectable if column is None else selectable.c.get(column.key)
-        if label is not None:
-            query = client.query(entity, label)
-        else:
-            query = client.query(entity)
+        with self.get_client() as client:
+            if label is not None:
+                query = client.query(entity, label)
+            else:
+                query = client.query(entity)
 
         if self.partition_details:
             query = self.get_partitioned_query(query)
         return query
 
     def get_sampler_table_name(self) -> str:
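The `with self.get_client() as client:` blocks assume that `get_client()` hands back something usable as a context manager that closes its SQLAlchemy session on exit. A minimal sketch of such a factory follows; the engine URL, names, and wiring are assumptions for illustration, not the project's actual code:

```python
from contextlib import contextmanager

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")          # placeholder engine for the sketch
SessionFactory = sessionmaker(bind=engine)


@contextmanager
def get_client():
    """Yield a Session and always close it, returning its connection to the pool."""
    session: Session = SessionFactory()
    try:
        yield session
    finally:
        session.close()


# Usage mirrors the diff: the session is released as soon as the block exits.
with get_client() as client:
    client.execute(text("SELECT 1"))
```

Closing the session checks its connection back into the engine's pool instead of leaving it open between profiler runs, which is the leak this commit targets.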
@@ -124,19 +124,18 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
 
     def get_sample_query(self, *, column=None) -> Query:
         """get query for sample data"""
-        client = self.get_client()
         if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE:
             rnd = self._base_sample_query(
                 column,
                 (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
             ).cte(f"{self.get_sampler_table_name()}_rnd")
-            session_query = client.query(rnd)
+            with self.get_client() as client:
+                session_query = client.query(rnd)
             return session_query.where(
                 rnd.c.random <= self.sample_config.profileSample
             ).cte(f"{self.get_sampler_table_name()}_sample")
 
-        table_query = client.query(self.raw_dataset)
+        with self.get_client() as client:
+            table_query = client.query(self.raw_dataset)
         session_query = self._base_sample_query(
             column,
             (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL)
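In the percentage branch, a row is kept when `ModuloFn(RandomNumFn(), 100)` lands at or below `profileSample`, so a `profileSample` of 20 keeps roughly 20% of the rows (marginally more, since the bound is inclusive of 0 through 20). A plain-Python stand-in for that filter, just to make the arithmetic concrete:

```python
import random

profile_sample = 20          # the PERCENTAGE setting
total_rows = 100_000

# Stand-in for "ModuloFn(RandomNumFn(), 100) <= profileSample" evaluated per row.
kept = sum(
    1 for _ in range(total_rows) if random.randrange(2**31) % 100 <= profile_sample
)
print(f"kept ~{100 * kept / total_rows:.1f}% of rows")   # prints roughly 21.0
```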
@@ -197,6 +196,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
             for col in inspect(ds).c
             if col.name != RANDOM_LABEL and col.name in names
         ]
+
         with self.get_client() as client:
             sqa_sample = (
                 client.query(*sqa_columns)
@@ -204,10 +204,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
                 .limit(self.sample_limit)
                 .all()
             )
-            return TableData(
-                columns=[column.name for column in sqa_columns],
-                rows=[list(row) for row in sqa_sample],
-            )
+
+        return TableData(
+            columns=[column.name for column in sqa_columns],
+            rows=[list(row) for row in sqa_sample],
+        )
 
     def _fetch_sample_data_from_user_query(self) -> TableData:
         """Returns a table data object using results from query execution"""
@@ -232,10 +233,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
 
         stmt = text(f"{self.sample_query}")
         stmt = stmt.columns(*list(inspect(self.raw_dataset).c))
-        return (
-            self.get_client()
-            .query(stmt.subquery())
-            .cte(f"{self.get_sampler_table_name()}_user_sampled")
-        )
+        with self.get_client() as client:
+            return client.query(stmt.subquery()).cte(
+                f"{self.get_sampler_table_name()}_user_sampled"
+            )
 
     def _partitioned_table(self) -> Query:
         """Return the Query object for partitioned tables"""
@@ -252,9 +254,13 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
         )
         if query is not None:
             return query.filter(partition_filter)
-        return self.get_client().query(self.raw_dataset).filter(partition_filter)
+        with self.get_client() as client:
+            return client.query(self.raw_dataset).filter(partition_filter)
 
     def get_columns(self):
         """get columns from entity"""
         return list(inspect(self.raw_dataset).c)
+
+    def close(self):
+        """Close the connection"""
+        self.get_client().close()
+        self.connection.pool.dispose()
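The new `close()` pulls two levers: closing the client releases the active session, and `self.connection.pool.dispose()` drops whatever connections SQLAlchemy is still pooling. A small sketch of what disposal means at the engine level, using standard SQLAlchemy behaviour and a placeholder URL (how `self.connection` is wired into the sampler is not shown in this diff):

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")      # placeholder URL for the sketch

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))       # checks a connection out of the pool, then returns it

# dispose() closes the connections currently held in the pool (and replaces the pool),
# so no idle sockets linger against the database after the sampler is done.
engine.dispose()
```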
@@ -54,7 +54,7 @@ def mysql_container(tmp_path_factory):
     engine.dispose()
     assert_dangling_connections(container, 1)
     yield container
-    # TODO: We are still leaving some connections open. Should be fixed in the future.
+    # Needs to be handled for Test Cases https://github.com/open-metadata/OpenMetadata/issues/21187
     assert_dangling_connections(container, 9)
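`assert_dangling_connections` itself is defined elsewhere in the test suite. One plausible way such a check can be implemented against a MySQL test container is to count the server-side sessions in `information_schema.processlist`; the sketch below assumes that approach and hypothetical connection parameters, and is not the project's actual helper:

```python
import pymysql


def count_connections(host: str, port: int, user: str, password: str) -> int:
    """Count the connections the MySQL server currently sees for `user`, minus this probe."""
    conn = pymysql.connect(host=host, port=port, user=user, password=password)
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT COUNT(*) FROM information_schema.processlist WHERE user = %s",
                (user,),
            )
            (total,) = cur.fetchone()
        return total - 1   # exclude the connection opened for this query
    finally:
        conn.close()


def assert_dangling_connections_sketch(expected: int, **mysql_kwargs) -> None:
    assert count_connections(**mysql_kwargs) == expected
```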