Fix Snowflake Test Connection when no database passed (#10831)

Co-authored-by: Sachin Chaurasiya <sachinchaurasiyachotey87@gmail.com>
This commit is contained in:
Mayur Singal 2023-03-29 23:49:22 +05:30 committed by GitHub
parent 54b635dd60
commit ec0ca7010e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,13 +12,15 @@
""" """
Source connection handler Source connection handler
""" """
from typing import Optional from functools import partial
from typing import Any, Optional
from urllib.parse import quote_plus from urllib.parse import quote_plus
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from pydantic import SecretStr from pydantic import BaseModel, SecretStr
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.inspection import inspect
from metadata.generated.schema.entity.automations.workflow import ( from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow, Workflow as AutomationWorkflow,
@ -32,7 +34,11 @@ from metadata.ingestion.connections.builders import (
get_connection_options_dict, get_connection_options_dict,
init_empty_connection_arguments, init_empty_connection_arguments,
) )
from metadata.ingestion.connections.test_connections import test_connection_db_common from metadata.ingestion.connections.test_connections import (
test_connection_engine_step,
test_connection_steps,
test_query,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.snowflake.queries import ( from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_GET_DATABASES, SNOWFLAKE_GET_DATABASES,
@ -44,6 +50,12 @@ from metadata.utils.logger import ingestion_logger
logger = ingestion_logger() logger = ingestion_logger()
class SnowflakeEngineWrapper(BaseModel):
service_connection: SnowflakeConnection
engine: Any
is_use_executed: bool = False
def get_connection_url(connection: SnowflakeConnection) -> str: def get_connection_url(connection: SnowflakeConnection) -> str:
""" """
Set the connection URL Set the connection URL
@ -132,16 +144,50 @@ def test_connection(
Test connection. This can be executed either as part Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow of a metadata workflow or during an Automation Workflow
""" """
engine_wrapper = SnowflakeEngineWrapper(
queries = { service_connection=service_connection, engine=engine, is_use_executed=False
"GetQueries": SNOWFLAKE_TEST_GET_QUERIES,
"GetDatabases": SNOWFLAKE_GET_DATABASES,
"GetTags": SNOWFLAKE_TEST_FETCH_TAG,
}
test_connection_db_common(
metadata=metadata,
engine=engine,
service_connection=service_connection,
automation_workflow=automation_workflow,
queries=queries,
) )
test_fn = {
"CheckAccess": partial(test_connection_engine_step, engine),
"GetDatabases": partial(
test_query, statement=SNOWFLAKE_GET_DATABASES, engine=engine
),
"GetSchemas": partial(
execute_inspector_func, engine_wrapper, "get_schema_names"
),
"GetTables": partial(execute_inspector_func, engine_wrapper, "get_table_names"),
"GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"),
"GetQueries": partial(
test_query, statement=SNOWFLAKE_TEST_GET_QUERIES, engine=engine
),
"GetTags": partial(
test_query, statement=SNOWFLAKE_TEST_FETCH_TAG, engine=engine
),
}
test_connection_steps(
metadata=metadata,
test_fn=test_fn,
service_fqn=service_connection.type.value,
automation_workflow=automation_workflow,
)
def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str):
"""
Method to test connection via inspector functions,
this function creates the inspector object and fetches
the function with name `func_name` and executes it
"""
if (
not engine_wrapper.service_connection.database
and not engine_wrapper.is_use_executed
):
databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES)
for database in databases:
engine_wrapper.engine.execute(f"USE DATABASE {database.name}")
engine_wrapper.is_use_executed = True
break
inspector = inspect(engine_wrapper.engine)
inspector_fn = getattr(inspector, func_name)
inspector_fn()