diff --git a/ingestion/src/metadata/sampler/processor.py b/ingestion/src/metadata/sampler/processor.py index fe982c788a8..16f0acce563 100644 --- a/ingestion/src/metadata/sampler/processor.py +++ b/ingestion/src/metadata/sampler/processor.py @@ -99,7 +99,7 @@ class SamplerProcessor(Processor): data=sampler_interface.generate_sample_data(), store=self.source_config.storeSampleData, ) - + sampler_interface.close() return Either( right=SamplerResponse( table=entity, diff --git a/ingestion/src/metadata/sampler/sampler_interface.py b/ingestion/src/metadata/sampler/sampler_interface.py index d2b15969b13..6bee6ef5de4 100644 --- a/ingestion/src/metadata/sampler/sampler_interface.py +++ b/ingestion/src/metadata/sampler/sampler_interface.py @@ -245,3 +245,6 @@ class SamplerInterface(ABC): logger.debug(traceback.format_exc()) logger.warning(f"Error fetching sample data: {err}") raise err + + def close(self): + """Default noop""" diff --git a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py index e030e9c51f6..27d51398598 100644 --- a/ingestion/src/metadata/sampler/sqlalchemy/sampler.py +++ b/ingestion/src/metadata/sampler/sqlalchemy/sampler.py @@ -99,16 +99,16 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): """ # only sample the column if we are computing a column metric to limit the amount of data scaned selectable = self.set_tablesample(self.raw_dataset.__table__) + client = self.get_client() entity = selectable if column is None else selectable.c.get(column.key) - with self.get_client() as client: - if label is not None: - query = client.query(entity, label) - else: - query = client.query(entity) + if label is not None: + query = client.query(entity, label) + else: + query = client.query(entity) - if self.partition_details: - query = self.get_partitioned_query(query) + if self.partition_details: + query = self.get_partitioned_query(query) return query def get_sampler_table_name(self) -> str: @@ -124,19 +124,18 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): def get_sample_query(self, *, column=None) -> Query: """get query for sample data""" + client = self.get_client() if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE: rnd = self._base_sample_query( column, (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL), ).cte(f"{self.get_sampler_table_name()}_rnd") - with self.get_client() as client: - session_query = client.query(rnd) + session_query = client.query(rnd) return session_query.where( rnd.c.random <= self.sample_config.profileSample ).cte(f"{self.get_sampler_table_name()}_sample") - with self.get_client() as client: - table_query = client.query(self.raw_dataset) + table_query = client.query(self.raw_dataset) session_query = self._base_sample_query( column, (ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL) @@ -197,6 +196,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): for col in inspect(ds).c if col.name != RANDOM_LABEL and col.name in names ] + with self.get_client() as client: sqa_sample = ( client.query(*sqa_columns) @@ -204,10 +204,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): .limit(self.sample_limit) .all() ) - return TableData( - columns=[column.name for column in sqa_columns], - rows=[list(row) for row in sqa_sample], - ) + + return TableData( + columns=[column.name for column in sqa_columns], + rows=[list(row) for row in sqa_sample], + ) def _fetch_sample_data_from_user_query(self) -> TableData: """Returns a table data object using results from query execution""" @@ -232,10 +233,11 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): stmt = text(f"{self.sample_query}") stmt = stmt.columns(*list(inspect(self.raw_dataset).c)) - with self.get_client() as client: - return client.query(stmt.subquery()).cte( - f"{self.get_sampler_table_name()}_user_sampled" - ) + return ( + self.get_client() + .query(stmt.subquery()) + .cte(f"{self.get_sampler_table_name()}_user_sampled") + ) def _partitioned_table(self) -> Query: """Return the Query object for partitioned tables""" @@ -252,9 +254,13 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin): ) if query is not None: return query.filter(partition_filter) - with self.get_client() as client: - return client.query(self.raw_dataset).filter(partition_filter) + return self.get_client().query(self.raw_dataset).filter(partition_filter) def get_columns(self): """get columns from entity""" return list(inspect(self.raw_dataset).c) + + def close(self): + """Close the connection""" + self.get_client().close() + self.connection.pool.dispose() diff --git a/ingestion/tests/integration/mysql/conftest.py b/ingestion/tests/integration/mysql/conftest.py index 5c1b7d2bf97..60349c73c31 100644 --- a/ingestion/tests/integration/mysql/conftest.py +++ b/ingestion/tests/integration/mysql/conftest.py @@ -54,7 +54,7 @@ def mysql_container(tmp_path_factory): engine.dispose() assert_dangling_connections(container, 1) yield container - # TODO: We are still leaving some connections open. Should be fixed in the future. + # Needs to be handled for Test Cases https://github.com/open-metadata/OpenMetadata/issues/21187 assert_dangling_connections(container, 9)