From 06b8d8e7ce0cd9aa54bc0f4bc62a19c251db533a Mon Sep 17 00:00:00 2001 From: Teddy Date: Wed, 5 Apr 2023 18:47:18 +0200 Subject: [PATCH] fix: explicitly state `USE CATALOG` for databricks connection (#10940) --- .../sqlalchemy/sqa_test_suite_interface.py | 1 + .../src/metadata/mixins/sqalchemy/sqa_mixin.py | 16 ++++++++++++++++ .../sqlalchemy/sqa_profiler_interface.py | 2 ++ 3 files changed, 19 insertions(+) diff --git a/ingestion/src/metadata/data_quality/interface/sqlalchemy/sqa_test_suite_interface.py b/ingestion/src/metadata/data_quality/interface/sqlalchemy/sqa_test_suite_interface.py index 5e665b4576b..bfe94034646 100644 --- a/ingestion/src/metadata/data_quality/interface/sqlalchemy/sqa_test_suite_interface.py +++ b/ingestion/src/metadata/data_quality/interface/sqlalchemy/sqa_test_suite_interface.py @@ -68,6 +68,7 @@ class SQATestSuiteInterface(SQAInterfaceMixin, TestSuiteProtocol): get_connection(self.service_connection_config) ) self.set_session_tag(self.session) + self.set_catalog(self.session) self._table = self._convert_table_to_orm_object(sqa_metadata_obj) diff --git a/ingestion/src/metadata/mixins/sqalchemy/sqa_mixin.py b/ingestion/src/metadata/mixins/sqalchemy/sqa_mixin.py index b0fdc38a7cf..48c233566ef 100644 --- a/ingestion/src/metadata/mixins/sqalchemy/sqa_mixin.py +++ b/ingestion/src/metadata/mixins/sqalchemy/sqa_mixin.py @@ -20,6 +20,9 @@ from typing import Optional from sqlalchemy import Column, MetaData, inspect from sqlalchemy.orm import DeclarativeMeta +from metadata.generated.schema.entity.services.connections.database.databricksConnection import ( + DatabricksConnection, +) from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( SnowflakeType, ) @@ -80,6 +83,19 @@ class SQAInterfaceMixin: ) ) + def set_catalog(self, session) -> None: + """Set catalog for the session. Right now only databricks requires it + + Args: + session (Session): sqa session object + """ + if isinstance(self.service_connection_config, DatabricksConnection): + bind = session.get_bind() + bind.execute( + "USE CATALOG %(catalog)s;", + {"catalog": self.service_connection_config.catalog}, + ).first() + def close(self): """close session""" self.session.close() diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/sqa_profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/sqa_profiler_interface.py index 78d22f8406a..0c96a3dc61f 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/sqa_profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/sqa_profiler_interface.py @@ -100,6 +100,7 @@ class SQAProfilerInterface(ProfilerProtocol, SQAInterfaceMixin): self.session_factory = self._session_factory(service_connection_config) self.session = self.session_factory() self.set_session_tag(self.session) + self.set_catalog(self.session) self.profile_sample_config = profile_sample_config self.profile_query = sample_query @@ -392,6 +393,7 @@ class SQAProfilerInterface(ProfilerProtocol, SQAInterfaceMixin): Session = self.session_factory # pylint: disable=invalid-name with Session() as session: self.set_session_tag(session) + self.set_catalog(session) sampler = self._create_thread_safe_sampler( session, table,