From 1260a0600a16e3262ff4d2a73c7ad443cda81a9b Mon Sep 17 00:00:00 2001 From: IceS2 Date: Mon, 7 Jul 2025 16:38:22 +0200 Subject: [PATCH] MINOR: Update Snowflake Connection (#22167) * Update Snowflake Connection * Extracting needed methods --- .../source/database/snowflake/connection.py | 356 ++++++++++-------- .../source/database/snowflake/service_spec.py | 2 + .../tests/unit/test_source_connection.py | 44 +-- 3 files changed, 211 insertions(+), 191 deletions(-) diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py index a4ee6cf0eed..57716d58b68 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py @@ -26,7 +26,7 @@ from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( - SnowflakeConnection, + SnowflakeConnection as SnowflakeConnectionConfig, ) from metadata.generated.schema.entity.services.connections.testConnectionResult import ( TestConnectionResult, @@ -37,6 +37,7 @@ from metadata.ingestion.connections.builders import ( get_connection_options_dict, init_empty_connection_arguments, ) +from metadata.ingestion.connections.connection import BaseConnection from metadata.ingestion.connections.test_connections import ( test_connection_engine_step, test_connection_steps, @@ -58,171 +59,11 @@ logger = ingestion_logger() class SnowflakeEngineWrapper(BaseModel): - service_connection: SnowflakeConnection + service_connection: SnowflakeConnectionConfig engine: Any database_name: Optional[str] = None -def get_connection_url(connection: SnowflakeConnection) -> str: - """ - Set the connection URL - """ - url = f"{connection.scheme.value}://" - - if connection.username: - url += f"{quote_plus(connection.username)}" - if 
not connection.password: - connection.password = SecretStr("") - url += ( - f":{quote_plus(connection.password.get_secret_value())}" - if connection - else "" - ) - url += "@" - - url += connection.account - url += f"/{connection.database}" if connection.database else "" - - options = get_connection_options_dict(connection) - if options: - if not connection.database: - url += "/" - params = "&".join( - f"{key}={quote_plus(value)}" for (key, value) in options.items() if value - ) - url = f"{url}?{params}" - options = { - "account": connection.account, - "warehouse": connection.warehouse, - "role": connection.role, - } - params = "&".join(f"{key}={value}" for (key, value) in options.items() if value) - if params: - url = f"{url}?{params}" - return url - - -def get_connection(connection: SnowflakeConnection) -> Engine: - """ - Create connection - """ - if not connection.connectionArguments: - connection.connectionArguments = init_empty_connection_arguments() - - if connection.privateKey: - snowflake_private_key_passphrase = ( - connection.snowflakePrivatekeyPassphrase.get_secret_value() - if connection.snowflakePrivatekeyPassphrase - else "" - ) - - if not snowflake_private_key_passphrase: - logger.warning( - "Snowflake Private Key Passphrase not found, replacing it with empty string" - ) - p_key = serialization.load_pem_private_key( - bytes(connection.privateKey.get_secret_value(), "utf-8"), - password=snowflake_private_key_passphrase.encode() or None, - backend=default_backend(), - ) - pkb = p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - connection.connectionArguments.root["private_key"] = pkb - - if connection.clientSessionKeepAlive: - connection.connectionArguments.root[ - "client_session_keep_alive" - ] = connection.clientSessionKeepAlive - - engine = create_generic_db_connection( - connection=connection, - 
get_connection_url_fn=get_connection_url, - get_connection_args_fn=get_connection_args_common, - ) - if connection.connectionArguments.root and connection.connectionArguments.root.get( - "private_key" - ): - del connection.connectionArguments.root["private_key"] - return engine - - -def test_connection( - metadata: OpenMetadata, - engine: Engine, - service_connection: SnowflakeConnection, - automation_workflow: Optional[AutomationWorkflow] = None, - timeout_seconds: Optional[int] = THREE_MIN, -) -> TestConnectionResult: - """ - Test connection. This can be executed either as part - of a metadata workflow or during an Automation Workflow. - - Note how we run a custom GetTables query: - - The default inspector `get_table_names` runs a SHOW which - has a limit on 10000 rows in the result set: - https://github.com/open-metadata/OpenMetadata/issues/12798 - - This can cause errors if we are running tests against schemas - with more tables than that. There is no issues during the metadata - ingestion since in metadata.py we are overriding the default - `get_table_names` function with our custom queries. 
- """ - engine_wrapper = SnowflakeEngineWrapper( - service_connection=service_connection, engine=engine, database_name=None - ) - test_fn = { - "CheckAccess": partial(test_connection_engine_step, engine), - "GetDatabases": partial( - test_query, statement=SNOWFLAKE_GET_DATABASES, engine=engine - ), - "GetSchemas": partial( - execute_inspector_func, engine_wrapper, "get_schema_names" - ), - "GetTables": partial( - test_table_query, - statement=SNOWFLAKE_TEST_GET_TABLES, - engine_wrapper=engine_wrapper, - ), - "GetViews": partial( - test_table_query, - statement=SNOWFLAKE_TEST_GET_VIEWS, - engine_wrapper=engine_wrapper, - ), - "GetStreams": partial( - test_table_query, - statement=SNOWFLAKE_TEST_GET_STREAMS, - engine_wrapper=engine_wrapper, - ), - "GetQueries": partial( - test_query, - statement=SNOWFLAKE_TEST_GET_QUERIES.format( - account_usage=service_connection.accountUsageSchema - ), - engine=engine, - ), - "GetTags": partial( - test_query, - statement=SNOWFLAKE_TEST_FETCH_TAG.format( - account_usage=service_connection.accountUsageSchema - ), - engine=engine, - ), - } - - return test_connection_steps( - metadata=metadata, - test_fn=test_fn, - service_type=service_connection.type.value, - automation_workflow=automation_workflow, - timeout_seconds=timeout_seconds, - ) - - def _init_database(engine_wrapper: SnowflakeEngineWrapper): """ Initialize database @@ -259,3 +100,194 @@ def test_table_query(engine_wrapper: SnowflakeEngineWrapper, statement: str): engine=engine_wrapper.engine, statement=statement.format(database_name=engine_wrapper.database_name), ) + + +class SnowflakeConnection(BaseConnection[SnowflakeConnectionConfig, Engine]): + def _get_client(self) -> Engine: + """ + Return the SQLAlchemy Engine for Snowflake. 
+ """ + return self.get_connection() + + @staticmethod + def get_connection_url(connection: SnowflakeConnectionConfig) -> str: + """ + Build the connection URL, joining all query parameters into a single "?" query string + """ + + url = f"{connection.scheme.value}://" + + if connection.username: + url += f"{quote_plus(connection.username)}" + if not connection.password: + connection.password = SecretStr("") + url += ( + f":{quote_plus(connection.password.get_secret_value())}" + if connection  # NOTE(review): tests the config object, not the password; presumably always true, confirm intent + else "" + ) + url += "@" + + url += connection.account + url += f"/{connection.database}" if connection.database else "" + + options = get_connection_options_dict(connection) + if options: + if not connection.database: + url += "/" + params = "&".join( + f"{key}={quote_plus(value)}" + for (key, value) in options.items() + if value + ) + url = f"{url}?{params}" + options = { + "account": connection.account, + "warehouse": connection.warehouse, + "role": connection.role, + } + params = "&".join(f"{key}={value}" for (key, value) in options.items() if value) + if params: + url = f"{url}{'&' if '?' in url else '?'}{params}"  # use "&" when the options above already opened the query string + return url + + def _get_private_key(self) -> Optional[bytes]: + connection = self.service_connection + if connection.privateKey: + snowflake_private_key_passphrase = ( + connection.snowflakePrivatekeyPassphrase.get_secret_value() + if connection.snowflakePrivatekeyPassphrase + else "" + ) + + if not snowflake_private_key_passphrase: + logger.warning( + "Snowflake Private Key Passphrase not found, replacing it with empty string" + ) + p_key = serialization.load_pem_private_key( + bytes(connection.privateKey.get_secret_value(), "utf-8"), + password=snowflake_private_key_passphrase.encode() or None, + backend=default_backend(), + ) + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return pkb + return None + + def _get_client_session_keep_alive(self) -> Optional[bool]: + connection = self.service_connection + if
connection.clientSessionKeepAlive: + return connection.clientSessionKeepAlive + return None + + def get_connection(self) -> Engine: + """ + Create connection + """ + connection = self.service_connection + if not connection.connectionArguments: + connection.connectionArguments = init_empty_connection_arguments() + + if private_key := self._get_private_key(): + connection.connectionArguments.root["private_key"] = private_key + + if keep_alive := self._get_client_session_keep_alive(): + connection.connectionArguments.root[ + "client_session_keep_alive" + ] = keep_alive + + engine = create_generic_db_connection( + connection=connection, + get_connection_url_fn=self.get_connection_url, + get_connection_args_fn=get_connection_args_common, + ) + if ( + connection.connectionArguments.root + and connection.connectionArguments.root.get("private_key") + ): + del connection.connectionArguments.root["private_key"] + return engine + + def test_connection( + self, + metadata: OpenMetadata, + automation_workflow: Optional[AutomationWorkflow] = None, + timeout_seconds: Optional[int] = THREE_MIN, + ) -> TestConnectionResult: + """ + Test connection. This can be executed either as part + of a metadata workflow or during an Automation Workflow. + + Note how we run a custom GetTables query: + + The default inspector `get_table_names` runs a SHOW which + has a limit on 10000 rows in the result set: + https://github.com/open-metadata/OpenMetadata/issues/12798 + + This can cause errors if we are running tests against schemas + with more tables than that. There is no issues during the metadata + ingestion since in metadata.py we are overriding the default + `get_table_names` function with our custom queries. 
+ """ + engine_wrapper = SnowflakeEngineWrapper( + service_connection=self.service_connection, + engine=self.client, + database_name=None, + ) + test_fn = { + "CheckAccess": partial(test_connection_engine_step, self.client), + "GetDatabases": partial( + test_query, statement=SNOWFLAKE_GET_DATABASES, engine=self.client + ), + "GetSchemas": partial( + execute_inspector_func, engine_wrapper, "get_schema_names" + ), + "GetTables": partial( + test_table_query, + statement=SNOWFLAKE_TEST_GET_TABLES, + engine_wrapper=engine_wrapper, + ), + "GetViews": partial( + test_table_query, + statement=SNOWFLAKE_TEST_GET_VIEWS, + engine_wrapper=engine_wrapper, + ), + "GetStreams": partial( + test_table_query, + statement=SNOWFLAKE_TEST_GET_STREAMS, + engine_wrapper=engine_wrapper, + ), + "GetQueries": partial( + test_query, + statement=SNOWFLAKE_TEST_GET_QUERIES.format( + account_usage=self.service_connection.accountUsageSchema + ), + engine=self.client, + ), + "GetTags": partial( + test_query, + statement=SNOWFLAKE_TEST_FETCH_TAG.format( + account_usage=self.service_connection.accountUsageSchema + ), + engine=self.client, + ), + } + + return test_connection_steps( + metadata=metadata, + test_fn=test_fn, + service_type=self.service_connection.type.value, + automation_workflow=automation_workflow, + timeout_seconds=timeout_seconds, + ) + + def get_connection_dict(self) -> dict: + """ + Return the connection dictionary for this service. 
+ """ + raise NotImplementedError( + "get_connection_dict is not implemented for Snowflake" + ) diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/service_spec.py b/ingestion/src/metadata/ingestion/source/database/snowflake/service_spec.py index c08ba11212a..88e4467e683 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/service_spec.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/service_spec.py @@ -1,6 +1,7 @@ from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import ( SnowflakeTestSuiteInterface, ) +from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection from metadata.ingestion.source.database.snowflake.data_diff.data_diff import ( SnowflakeTableParameter, ) @@ -21,4 +22,5 @@ ServiceSpec = DefaultDatabaseSpec( test_suite_class=SnowflakeTestSuiteInterface, sampler_class=SnowflakeSampler, data_diff=SnowflakeTableParameter, + connection_class=SnowflakeConnection, ) diff --git a/ingestion/tests/unit/test_source_connection.py b/ingestion/tests/unit/test_source_connection.py index a58652772a8..d703f999202 100644 --- a/ingestion/tests/unit/test_source_connection.py +++ b/ingestion/tests/unit/test_source_connection.py @@ -96,7 +96,9 @@ from metadata.generated.schema.entity.services.connections.database.singleStoreC SingleStoreScheme, ) from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( - SnowflakeConnection, + SnowflakeConnection as SnowflakeConnectionConfig, +) +from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( SnowflakeScheme, ) from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( @@ -114,6 +116,7 @@ from metadata.ingestion.connections.builders import ( get_connection_args_common, get_connection_url_common, ) +from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection from 
metadata.ingestion.source.database.trino.connection import TrinoConnection @@ -782,31 +785,9 @@ class SourceConnectionTest(TestCase): assert expected_url == get_connection_url_common(db2_conn_obj) def test_snowflake_url(self): - # connection arguments without db - - from metadata.ingestion.source.database.snowflake.connection import ( - get_connection_url, - ) - - expected_url = "snowflake://coding:Abhi@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH" - snowflake_conn_obj = SnowflakeConnection( - scheme=SnowflakeScheme.snowflake, - username="coding", - password="Abhi", - warehouse="COMPUTE_WH", - account="ue18849.us-east-2.aws", - ) - - assert expected_url == get_connection_url(snowflake_conn_obj) - - def test_snowflake_url(self): - from metadata.ingestion.source.database.snowflake.connection import ( - get_connection_url, - ) - # Passing @ in username and password expected_url = "snowflake://coding%40444:Abhi%40123@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH" - snowflake_conn_obj = SnowflakeConnection( + snowflake_conn_obj = SnowflakeConnectionConfig( scheme=SnowflakeScheme.snowflake, username="coding@444", password="Abhi@123", @@ -814,11 +795,13 @@ class SourceConnectionTest(TestCase): account="ue18849.us-east-2.aws", ) - assert expected_url == get_connection_url(snowflake_conn_obj) + assert expected_url == SnowflakeConnection.get_connection_url( + snowflake_conn_obj + ) # connection arguments with db expected_url = "snowflake://coding:Abhi@ue18849.us-east-2.aws/testdb?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH" - snowflake_conn_obj = SnowflakeConnection( + snowflake_conn_obj = SnowflakeConnectionConfig( scheme=SnowflakeScheme.snowflake, username="coding", password="Abhi", @@ -826,7 +809,10 @@ class SourceConnectionTest(TestCase): warehouse="COMPUTE_WH", account="ue18849.us-east-2.aws", ) - assert expected_url == get_connection_url(snowflake_conn_obj) + + assert expected_url == 
SnowflakeConnection.get_connection_url( + snowflake_conn_obj + ) def test_mysql_conn_arguments(self): # connection arguments without connectionArguments @@ -1008,7 +994,7 @@ class SourceConnectionTest(TestCase): def test_snowflake_conn_arguments(self): # connection arguments without connectionArguments expected_args = {} - snowflake_conn_obj = SnowflakeConnection( + snowflake_conn_obj = SnowflakeConnectionConfig( username="user", password="test-pwd", database="tiny", @@ -1020,7 +1006,7 @@ class SourceConnectionTest(TestCase): # connection arguments with connectionArguments expected_args = {"user": "user-to-be-impersonated"} - snowflake_conn_obj = SnowflakeConnection( + snowflake_conn_obj = SnowflakeConnectionConfig( username="user", password="test-pwd", database="tiny",