diff --git a/ingestion/src/metadata/ingestion/source/database/athena/connection.py b/ingestion/src/metadata/ingestion/source/database/athena/connection.py index dcce9d912fc..c151be6fd69 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/athena/connection.py @@ -12,10 +12,13 @@ """ Source connection handler """ +from functools import partial from typing import Optional from urllib.parse import quote_plus from sqlalchemy.engine import Engine +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.inspection import inspect from metadata.clients.aws_client import AWSClient from metadata.generated.schema.entity.automations.workflow import ( @@ -28,7 +31,11 @@ from metadata.ingestion.connections.builders import ( create_generic_db_connection, get_connection_args_common, ) -from metadata.ingestion.connections.test_connections import test_connection_db_common +from metadata.ingestion.connections.test_connections import ( + execute_inspector_func, + test_connection_engine_step, + test_connection_steps, +) from metadata.ingestion.ometa.ometa_api import OpenMetadata @@ -92,9 +99,31 @@ def test_connection( Test connection. This can be executed either as part of a metadata workflow or during an Automation Workflow """ - test_connection_db_common( + + def get_test_schema(inspector: Inspector): + all_schemas = inspector.get_schema_names() + return all_schemas[0] if all_schemas else None + + def custom_executor_for_table(): + inspector = inspect(engine) + test_schema = get_test_schema(inspector) + return inspector.get_table_names(test_schema) if test_schema else [] + + def custom_executor_for_view(): + inspector = inspect(engine) + test_schema = get_test_schema(inspector) + return inspector.get_view_names(test_schema) if test_schema else [] + + test_fn = { + "CheckAccess": partial(test_connection_engine_step, engine), + "GetSchemas": partial(execute_inspector_func, engine, "get_schema_names"), + "GetTables": custom_executor_for_table, + "GetViews": custom_executor_for_view, + } + + test_connection_steps( metadata=metadata, - engine=engine, - service_connection=service_connection, + test_fn=test_fn, + service_type=service_connection.type.value, automation_workflow=automation_workflow, )