mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-02 03:29:03 +00:00
Fix snowflake test connection (#13354)
* Fix: Flaky Test Connection for Snowflake * optimize code * pyformat
This commit is contained in:
parent
74e29a9f16
commit
1dee004dee
@ -54,6 +54,7 @@ logger = ingestion_logger()
|
||||
class SnowflakeEngineWrapper(BaseModel):
|
||||
service_connection: SnowflakeConnection
|
||||
engine: Any
|
||||
database_name: Optional[str]
|
||||
|
||||
|
||||
def get_connection_url(connection: SnowflakeConnection) -> str:
|
||||
@ -160,7 +161,7 @@ def test_connection(
|
||||
`get_table_names` function with our custom queries.
|
||||
"""
|
||||
engine_wrapper = SnowflakeEngineWrapper(
|
||||
service_connection=service_connection, engine=engine
|
||||
service_connection=service_connection, engine=engine, database_name=None
|
||||
)
|
||||
test_fn = {
|
||||
"CheckAccess": partial(test_connection_engine_step, engine),
|
||||
@ -171,7 +172,9 @@ def test_connection(
|
||||
execute_inspector_func, engine_wrapper, "get_schema_names"
|
||||
),
|
||||
"GetTables": partial(
|
||||
test_query, statement=SNOWFLAKE_TEST_GET_TABLES, engine=engine
|
||||
test_table_query,
|
||||
statement=SNOWFLAKE_TEST_GET_TABLES,
|
||||
engine_wrapper=engine_wrapper,
|
||||
),
|
||||
"GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"),
|
||||
"GetQueries": partial(
|
||||
@ -190,17 +193,37 @@ def test_connection(
|
||||
)
|
||||
|
||||
|
||||
def _init_database(engine_wrapper: SnowflakeEngineWrapper):
|
||||
"""
|
||||
Initialize database
|
||||
"""
|
||||
if not engine_wrapper.service_connection.database:
|
||||
if not engine_wrapper.database_name:
|
||||
databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES)
|
||||
for database in databases:
|
||||
engine_wrapper.database_name = database.name
|
||||
break
|
||||
|
||||
|
||||
def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str):
|
||||
"""
|
||||
Method to test connection via inspector functions,
|
||||
this function creates the inspector object and fetches
|
||||
the function with name `func_name` and executes it
|
||||
"""
|
||||
if not engine_wrapper.service_connection.database:
|
||||
databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES)
|
||||
for database in databases:
|
||||
engine_wrapper.engine.execute(f"USE DATABASE {database.name}")
|
||||
break
|
||||
_init_database(engine_wrapper)
|
||||
engine_wrapper.engine.execute(f"USE DATABASE {engine_wrapper.database_name}")
|
||||
inspector = inspect(engine_wrapper.engine)
|
||||
inspector_fn = getattr(inspector, func_name)
|
||||
inspector_fn()
|
||||
|
||||
|
||||
def test_table_query(engine_wrapper: SnowflakeEngineWrapper, statement: str):
|
||||
"""
|
||||
Test Table queries
|
||||
"""
|
||||
_init_database(engine_wrapper)
|
||||
test_query(
|
||||
engine=engine_wrapper.engine,
|
||||
statement=statement.format(database_name=engine_wrapper.database_name),
|
||||
)
|
||||
|
||||
@ -118,7 +118,7 @@ SELECT query_text from snowflake.account_usage.query_history limit 1
|
||||
"""
|
||||
|
||||
SNOWFLAKE_TEST_GET_TABLES = """
|
||||
SELECT TABLE_NAME FROM information_schema.tables LIMIT 1
|
||||
SELECT TABLE_NAME FROM "{database_name}".information_schema.tables LIMIT 1
|
||||
"""
|
||||
|
||||
SNOWFLAKE_GET_DATABASES = "SHOW DATABASES"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user