From ec0ca7010e86787b4b86a7c0e82a76c18c0ee79e Mon Sep 17 00:00:00 2001 From: Mayur Singal <39544459+ulixius9@users.noreply.github.com> Date: Wed, 29 Mar 2023 23:49:22 +0530 Subject: [PATCH] Fix Snowflake Test Connection when no database passed (#10831) Co-authored-by: Sachin Chaurasiya --- .../source/database/snowflake/connection.py | 76 +++++++++++++++---- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py index f6cca091b6e..1d54ebfb0e3 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py @@ -12,13 +12,15 @@ """ Source connection handler """ -from typing import Optional +from functools import partial +from typing import Any, Optional from urllib.parse import quote_plus from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from pydantic import SecretStr +from pydantic import BaseModel, SecretStr from sqlalchemy.engine import Engine +from sqlalchemy.inspection import inspect from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, @@ -32,7 +34,11 @@ from metadata.ingestion.connections.builders import ( get_connection_options_dict, init_empty_connection_arguments, ) -from metadata.ingestion.connections.test_connections import test_connection_db_common +from metadata.ingestion.connections.test_connections import ( + test_connection_engine_step, + test_connection_steps, + test_query, +) from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.database.snowflake.queries import ( SNOWFLAKE_GET_DATABASES, @@ -44,6 +50,12 @@ from metadata.utils.logger import ingestion_logger logger = ingestion_logger() +class SnowflakeEngineWrapper(BaseModel): + service_connection: SnowflakeConnection + engine: Any + is_use_executed: bool = False + + def get_connection_url(connection: SnowflakeConnection) -> str: """ Set the connection URL @@ -132,16 +144,50 @@ def test_connection( Test connection. This can be executed either as part of a metadata workflow or during an Automation Workflow """ - - queries = { - "GetQueries": SNOWFLAKE_TEST_GET_QUERIES, - "GetDatabases": SNOWFLAKE_GET_DATABASES, - "GetTags": SNOWFLAKE_TEST_FETCH_TAG, - } - test_connection_db_common( - metadata=metadata, - engine=engine, - service_connection=service_connection, - automation_workflow=automation_workflow, - queries=queries, + engine_wrapper = SnowflakeEngineWrapper( + service_connection=service_connection, engine=engine, is_use_executed=False ) + test_fn = { + "CheckAccess": partial(test_connection_engine_step, engine), + "GetDatabases": partial( + test_query, statement=SNOWFLAKE_GET_DATABASES, engine=engine + ), + "GetSchemas": partial( + execute_inspector_func, engine_wrapper, "get_schema_names" + ), + "GetTables": partial(execute_inspector_func, engine_wrapper, "get_table_names"), + "GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"), + "GetQueries": partial( + test_query, statement=SNOWFLAKE_TEST_GET_QUERIES, engine=engine + ), + "GetTags": partial( + test_query, statement=SNOWFLAKE_TEST_FETCH_TAG, engine=engine + ), + } + + test_connection_steps( + metadata=metadata, + test_fn=test_fn, + service_fqn=service_connection.type.value, + automation_workflow=automation_workflow, + ) + + +def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str): + """ + Method to test connection via inspector functions, + this function creates the inspector object and fetches + the function with name `func_name` and executes it + """ + if ( + not engine_wrapper.service_connection.database + and not engine_wrapper.is_use_executed + ): + databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES) + for database in databases: + engine_wrapper.engine.execute(f"USE DATABASE {database.name}") + engine_wrapper.is_use_executed = True + break + inspector = inspect(engine_wrapper.engine) + inspector_fn = getattr(inspector, func_name) + inspector_fn()