diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py index 48e08300fd2..1313baf3435 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py @@ -54,6 +54,7 @@ logger = ingestion_logger() class SnowflakeEngineWrapper(BaseModel): service_connection: SnowflakeConnection engine: Any + database_name: Optional[str] def get_connection_url(connection: SnowflakeConnection) -> str: @@ -160,7 +161,7 @@ def test_connection( `get_table_names` function with our custom queries. """ engine_wrapper = SnowflakeEngineWrapper( - service_connection=service_connection, engine=engine + service_connection=service_connection, engine=engine, database_name=None ) test_fn = { "CheckAccess": partial(test_connection_engine_step, engine), @@ -171,7 +172,9 @@ def test_connection( execute_inspector_func, engine_wrapper, "get_schema_names" ), "GetTables": partial( - test_query, statement=SNOWFLAKE_TEST_GET_TABLES, engine=engine + test_table_query, + statement=SNOWFLAKE_TEST_GET_TABLES, + engine_wrapper=engine_wrapper, ), "GetViews": partial(execute_inspector_func, engine_wrapper, "get_view_names"), "GetQueries": partial( @@ -190,17 +193,37 @@ def test_connection( ) +def _init_database(engine_wrapper: SnowflakeEngineWrapper): + """ + Initialize database + """ + if not engine_wrapper.service_connection.database: + if not engine_wrapper.database_name: + databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES) + for database in databases: + engine_wrapper.database_name = database.name + break + + def execute_inspector_func(engine_wrapper: SnowflakeEngineWrapper, func_name: str): """ Method to test connection via inspector functions, this function creates the inspector object and fetches the function with name `func_name` and executes it """ - if not engine_wrapper.service_connection.database: - databases = engine_wrapper.engine.execute(SNOWFLAKE_GET_DATABASES) - for database in databases: - engine_wrapper.engine.execute(f"USE DATABASE {database.name}") - break + _init_database(engine_wrapper) + engine_wrapper.engine.execute(f"USE DATABASE {engine_wrapper.database_name}") inspector = inspect(engine_wrapper.engine) inspector_fn = getattr(inspector, func_name) inspector_fn() + + +def test_table_query(engine_wrapper: SnowflakeEngineWrapper, statement: str): + """ + Test Table queries + """ + _init_database(engine_wrapper) + test_query( + engine=engine_wrapper.engine, + statement=statement.format(database_name=engine_wrapper.database_name), + ) diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py b/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py index d0c90efc0ce..f2c58ccbb24 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py @@ -118,7 +118,7 @@ SELECT query_text from snowflake.account_usage.query_history limit 1 """ SNOWFLAKE_TEST_GET_TABLES = """ -SELECT TABLE_NAME FROM information_schema.tables LIMIT 1 +SELECT TABLE_NAME FROM "{database_name}".information_schema.tables LIMIT 1 """ SNOWFLAKE_GET_DATABASES = "SHOW DATABASES"