fix: explicitly state USE CATALOG for databricks connection (#10940)

This commit is contained in:
Teddy 2023-04-05 18:47:18 +02:00 committed by GitHub
parent 4683bee91a
commit 06b8d8e7ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 0 deletions

View File

@ -68,6 +68,7 @@ class SQATestSuiteInterface(SQAInterfaceMixin, TestSuiteProtocol):
get_connection(self.service_connection_config)
)
self.set_session_tag(self.session)
self.set_catalog(self.session)
self._table = self._convert_table_to_orm_object(sqa_metadata_obj)

View File

@ -20,6 +20,9 @@ from typing import Optional
from sqlalchemy import Column, MetaData, inspect
from sqlalchemy.orm import DeclarativeMeta
from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
DatabricksConnection,
)
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeType,
)
@ -80,6 +83,19 @@ class SQAInterfaceMixin:
)
)
def set_catalog(self, session) -> None:
"""Set catalog for the session. Right now only databricks requires it
Args:
session (Session): sqa session object
"""
if isinstance(self.service_connection_config, DatabricksConnection):
bind = session.get_bind()
bind.execute(
"USE CATALOG %(catalog)s;",
{"catalog": self.service_connection_config.catalog},
).first()
def close(self):
"""close session"""
self.session.close()

View File

@ -100,6 +100,7 @@ class SQAProfilerInterface(ProfilerProtocol, SQAInterfaceMixin):
self.session_factory = self._session_factory(service_connection_config)
self.session = self.session_factory()
self.set_session_tag(self.session)
self.set_catalog(self.session)
self.profile_sample_config = profile_sample_config
self.profile_query = sample_query
@ -392,6 +393,7 @@ class SQAProfilerInterface(ProfilerProtocol, SQAInterfaceMixin):
Session = self.session_factory # pylint: disable=invalid-name
with Session() as session:
self.set_session_tag(session)
self.set_catalog(session)
sampler = self._create_thread_safe_sampler(
session,
table,