Add Query Tag in Snowflake Profiler (#7106)

* Add Query Tag in Snowflake Profiler

* Revert conflict

* Move to session_factory

* Replace engine with session

* remove set_session_tag from init

* set_session_tag with session as args
This commit is contained in:
Mayur Singal 2022-09-02 18:28:57 +05:30 committed by GitHub
parent 52b100fbc6
commit de2904990d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -26,6 +26,9 @@ from sqlalchemy.engine.row import Row
from sqlalchemy.orm import DeclarativeMeta, Session
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeType,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
@ -55,6 +58,7 @@ from metadata.utils.constants import TEN_MIN
from metadata.utils.dispatch import enum_register
from metadata.utils.helpers import get_start_and_end
from metadata.utils.logger import sqa_interface_registry_logger
from metadata.utils.sql_queries import SNOWFLAKE_SESSION_TAG_QUERY
from metadata.utils.timeout import cls_timeout
logger = sqa_interface_registry_logger()
@ -86,9 +90,11 @@ class SQAInterface(InterfaceProtocol):
# Allows SQA Interface to be used without OM server config
self.table = table or self._convert_table_to_orm_object()
self.service_connection_config = service_connection_config
self.session_factory = self._session_factory(service_connection_config)
self.session: Session = self.session_factory()
self.set_session_tag(self.session)
self.profile_sample = profile_sample
self.profile_query = profile_query
@ -129,6 +135,23 @@ class SQAInterface(InterfaceProtocol):
engine = get_connection(service_connection_config)
return create_and_bind_thread_safe_session(engine)
def set_session_tag(self, session):
"""
Set session query tag
Args:
service_connection_config: connection details for the specific service
"""
if (
self.service_connection_config.type.value == SnowflakeType.Snowflake.value
and hasattr(self.service_connection_config, "queryTag")
and self.service_connection_config.queryTag
):
session.execute(
SNOWFLAKE_SESSION_TAG_QUERY.format(
query_tag=self.service_connection_config.queryTag
)
)
def _get_engine(self, service_connection_config):
"""Get engine for database
@ -229,6 +252,7 @@ class SQAInterface(InterfaceProtocol):
)
Session = self.session_factory
session = Session()
self.set_session_tag(session)
sampler = self._create_thread_safe_sampler(
session,
table,