Fix Snowflake Test Connection when no database passed (#10831)

Co-authored-by: Sachin Chaurasiya <sachinchaurasiyachotey87@gmail.com>
This commit is contained in:
Mayur Singal 2023-03-29 23:49:22 +05:30 committed by GitHub
parent 54b635dd60
commit ec0ca7010e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,13 +12,15 @@
"""
Source connection handler
"""
from typing import Optional
from functools import partial
from typing import Any, Optional
from urllib.parse import quote_plus
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import SecretStr
from pydantic import BaseModel, SecretStr
from sqlalchemy.engine import Engine
from sqlalchemy.inspection import inspect
from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
@ -32,7 +34,11 @@ from metadata.ingestion.connections.builders import (
get_connection_options_dict,
init_empty_connection_arguments,
)
from metadata.ingestion.connections.test_connections import test_connection_db_common
from metadata.ingestion.connections.test_connections import (
test_connection_engine_step,
test_connection_steps,
test_query,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_GET_DATABASES,
@ -44,6 +50,12 @@ from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
class SnowflakeEngineWrapper(BaseModel):
service_connection: SnowflakeConnection
engine: Any
is_use_executed: bool = False
def get_connection_url(connection: SnowflakeConnection) -> str:
"""
Set the connection URL
@ -132,16 +144,50 @@ def test_connection(
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow
"""
queries = {
"GetQueries": SNOWFLAKE_TEST_GET_QUERIES,
"GetDatabases": SNOWFLAKE_GET_DATABASES,
"GetTags": SNOWFLAKE_TEST_FETCH_TAG,
}
test_connection_db_common(
metadata=metadata,
engine=engine,
service_connection=service_connection,
automation_workflow=automation_workflow,
queries=queries,
engine_wrapper = SnowflakeEngineWrapper(
service_connection=service_connection, engine=engine, is_use_executed=False
)
test_fn = {
"CheckAccess": partial(test_connection_engine_step, engine),
"GetDatabases": partial(
test_query, statement=SNOWFLAKE_GET_DATABASES, engine=engine
),
"GetSchemas": partial(
execute_inspector_func, engine_wrapper, "get_schema_names"
),
"GetTables": partial(execute_inspector_func, engine_wrapper, "get_table_names"),
"GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"),
"GetQueries": partial(
test_query, statement=SNOWFLAKE_TEST_GET_QUERIES, engine=engine
),
"GetTags": partial(
test_query, statement=SNOWFLAKE_TEST_FETCH_TAG, engine=engine
),
}
test_connection_steps(
metadata=metadata,
test_fn=test_fn,
service_fqn=service_connection.type.value,
automation_workflow=automation_workflow,
)
def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str):
"""
Method to test connection via inspector functions,
this function creates the inspector object and fetches
the function with name `func_name` and executes it
"""
if (
not engine_wrapper.service_connection.database
and not engine_wrapper.is_use_executed
):
databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES)
for database in databases:
engine_wrapper.engine.execute(f"USE DATABASE {database.name}")
engine_wrapper.is_use_executed = True
break
inspector = inspect(engine_wrapper.engine)
inspector_fn = getattr(inspector, func_name)
inspector_fn()