Fix snowflake test connection (#13354)

* Fix: Flaky Test Connection for Snowflake

* optimize code

* pyformat
This commit is contained in:
Mayur Singal 2023-09-27 23:26:20 +05:30 committed by GitHub
parent 74e29a9f16
commit 1dee004dee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 8 deletions

View File

@ -54,6 +54,7 @@ logger = ingestion_logger()
class SnowflakeEngineWrapper(BaseModel): class SnowflakeEngineWrapper(BaseModel):
service_connection: SnowflakeConnection service_connection: SnowflakeConnection
engine: Any engine: Any
database_name: Optional[str]
def get_connection_url(connection: SnowflakeConnection) -> str: def get_connection_url(connection: SnowflakeConnection) -> str:
@ -160,7 +161,7 @@ def test_connection(
`get_table_names` function with our custom queries. `get_table_names` function with our custom queries.
""" """
engine_wrapper = SnowflakeEngineWrapper( engine_wrapper = SnowflakeEngineWrapper(
service_connection=service_connection, engine=engine service_connection=service_connection, engine=engine, database_name=None
) )
test_fn = { test_fn = {
"CheckAccess": partial(test_connection_engine_step, engine), "CheckAccess": partial(test_connection_engine_step, engine),
@ -171,7 +172,9 @@ def test_connection(
execute_inspector_func, engine_wrapper, "get_schema_names" execute_inspector_func, engine_wrapper, "get_schema_names"
), ),
"GetTables": partial( "GetTables": partial(
test_query, statement=SNOWFLAKE_TEST_GET_TABLES, engine=engine test_table_query,
statement=SNOWFLAKE_TEST_GET_TABLES,
engine_wrapper=engine_wrapper,
), ),
"GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"), "GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"),
"GetQueries": partial( "GetQueries": partial(
@ -190,17 +193,37 @@ def test_connection(
) )
def _init_database(engine_wrapper: SnowflakeEngineWrapper):
"""
Initialize database
"""
if not engine_wrapper.service_connection.database:
if not engine_wrapper.database_name:
databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES)
for database in databases:
engine_wrapper.database_name = database.name
break
def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str): def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str):
""" """
Method to test connection via inspector functions, Method to test connection via inspector functions,
this function creates the inspector object and fetches this function creates the inspector object and fetches
the function with name `func_name` and executes it the function with name `func_name` and executes it
""" """
if not engine_wrapper.service_connection.database: _init_database(engine_wrapper)
databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES) engine_wrapper.engine.execute(f"USE DATABASE {engine_wrapper.database_name}")
for database in databases:
engine_wrapper.engine.execute(f"USE DATABASE {database.name}")
break
inspector = inspect(engine_wrapper.engine) inspector = inspect(engine_wrapper.engine)
inspector_fn = getattr(inspector, func_name) inspector_fn = getattr(inspector, func_name)
inspector_fn() inspector_fn()
def test_table_query(engine_wrapper: SnowflakeEngineWrapper, statement: str):
"""
Test Table queries
"""
_init_database(engine_wrapper)
test_query(
engine=engine_wrapper.engine,
statement=statement.format(database_name=engine_wrapper.database_name),
)

View File

@ -118,7 +118,7 @@ SELECT query_text from snowflake.account_usage.query_history limit 1
""" """
SNOWFLAKE_TEST_GET_TABLES = """ SNOWFLAKE_TEST_GET_TABLES = """
SELECT TABLE_NAME FROM information_schema.tables LIMIT 1 SELECT TABLE_NAME FROM "{database_name}".information_schema.tables LIMIT 1
""" """
SNOWFLAKE_GET_DATABASES = "SHOW DATABASES" SNOWFLAKE_GET_DATABASES = "SHOW DATABASES"