diff --git a/ingestion/src/metadata/interfaces/sqa_interface.py b/ingestion/src/metadata/interfaces/sqa_interface.py index acfa9707f52..4e593824c5d 100644 --- a/ingestion/src/metadata/interfaces/sqa_interface.py +++ b/ingestion/src/metadata/interfaces/sqa_interface.py @@ -26,6 +26,9 @@ from sqlalchemy.engine.row import Row from sqlalchemy.orm import DeclarativeMeta, Session from metadata.generated.schema.entity.data.table import Table +from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( + SnowflakeType, +) from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( OpenMetadataConnection, ) @@ -55,6 +58,7 @@ from metadata.utils.constants import TEN_MIN from metadata.utils.dispatch import enum_register from metadata.utils.helpers import get_start_and_end from metadata.utils.logger import sqa_interface_registry_logger +from metadata.utils.sql_queries import SNOWFLAKE_SESSION_TAG_QUERY from metadata.utils.timeout import cls_timeout logger = sqa_interface_registry_logger() @@ -86,9 +90,11 @@ class SQAInterface(InterfaceProtocol): # Allows SQA Interface to be used without OM server config self.table = table or self._convert_table_to_orm_object() + self.service_connection_config = service_connection_config self.session_factory = self._session_factory(service_connection_config) self.session: Session = self.session_factory() + self.set_session_tag(self.session) self.profile_sample = profile_sample self.profile_query = profile_query @@ -129,6 +135,23 @@ class SQAInterface(InterfaceProtocol): engine = get_connection(service_connection_config) return create_and_bind_thread_safe_session(engine) + def set_session_tag(self, session): + """ + Set session query tag + Args: + service_connection_config: connection details for the specific service + """ + if ( + self.service_connection_config.type.value == SnowflakeType.Snowflake.value + and hasattr(self.service_connection_config, "queryTag") + and self.service_connection_config.queryTag + ): + session.execute( + SNOWFLAKE_SESSION_TAG_QUERY.format( + query_tag=self.service_connection_config.queryTag + ) + ) + def _get_engine(self, service_connection_config): """Get engine for database @@ -229,6 +252,7 @@ class SQAInterface(InterfaceProtocol): ) Session = self.session_factory session = Session() + self.set_session_tag(session) sampler = self._create_thread_safe_sampler( session, table,