From c1a8553e07cf5defdfd8cdd62979a391a3b9aa57 Mon Sep 17 00:00:00 2001 From: Milan Bariya <52292922+MilanBariya@users.noreply.github.com> Date: Fri, 24 Feb 2023 11:43:11 +0530 Subject: [PATCH] Airflow API change for test connection (#10182) * Airflow API change for test connection * Improve logic * Pydantic change * Improve logic --- .../ingestion/connections/test_connections.py | 42 ++++++++++++------- .../source/database/athena/connection.py | 4 +- .../source/database/clickhouse/connection.py | 17 ++++++-- .../source/database/databricks/connection.py | 4 +- .../source/database/hive/connection.py | 4 +- .../source/database/mssql/connection.py | 4 +- .../source/database/mysql/connection.py | 18 ++++++-- .../source/database/postgres/connection.py | 4 +- .../source/database/redshift/connection.py | 4 +- .../operations/test_connection.py | 10 +++++ 10 files changed, 77 insertions(+), 34 deletions(-) diff --git a/ingestion/src/metadata/ingestion/connections/test_connections.py b/ingestion/src/metadata/ingestion/connections/test_connections.py index aac3385e1fc..2e58be085ce 100644 --- a/ingestion/src/metadata/ingestion/connections/test_connections.py +++ b/ingestion/src/metadata/ingestion/connections/test_connections.py @@ -55,31 +55,39 @@ class TestConnectionStep(BaseModel): mandatory: bool = True -def test_connection_steps(steps: List[TestConnectionStep]) -> None: +class TestConnectionResult(BaseModel): + failed: List[str] = [] + success: List[str] = [] + warning: List[str] = [] + + +def test_connection_steps(steps: List[TestConnectionStep]) -> str: """ Run all the function steps and raise any errors """ - errors = {} + + test_connection_result = TestConnectionResult() for step in steps: try: step.function() - except Exception as exc: - msg = f"Faild to {step.name}, {exc}" - if step.mandatory: - errors[ - step.name - ] = f"{msg} This is a mandatory step and we won't be able to extract necessary metadata" - else: - errors[ - step.name - ] = f"{msg} This is a optional. The ingestion will continue to work as expected" + test_connection_result.success.append(f"'{step.name}': Pass") - if errors: - raise SourceConnectionException(errors) + except Exception: + if step.mandatory: + test_connection_result.failed.append( + f"'{step.name}': This is a mandatory step and we won't be able to extract necessary metadata" + ) + + else: + test_connection_result.warning.append( + f"'{step.name}': This is a optional and the ingestion will continue to work as expected" + ) + + return test_connection_result @timeout(seconds=120) -def test_connection_db_common(connection: Engine, steps=None) -> None: +def test_connection_db_common(connection: Engine, steps=None) -> str: """ Default implementation is the engine to test. @@ -91,7 +99,7 @@ def test_connection_db_common(connection: Engine, steps=None) -> None: with connection.connect() as conn: conn.execute(ConnTestFn()) if steps: - test_connection_steps(steps) + return test_connection_steps(steps) except SourceConnectionException as exc: raise exc except OperationalError as err: @@ -100,3 +108,5 @@ def test_connection_db_common(connection: Engine, steps=None) -> None: except Exception as exc: msg = f"Unknown error connecting with {connection}: {exc}." raise SourceConnectionException(msg) from exc + + return None diff --git a/ingestion/src/metadata/ingestion/source/database/athena/connection.py b/ingestion/src/metadata/ingestion/source/database/athena/connection.py index ac0eb80290b..d4c1344bfff 100644 --- a/ingestion/src/metadata/ingestion/source/database/athena/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/athena/connection.py @@ -60,7 +60,7 @@ def get_connection(connection: AthenaConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test connection """ @@ -80,4 +80,4 @@ def test_connection(engine: Engine) -> None: mandatory=False, ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/clickhouse/connection.py b/ingestion/src/metadata/ingestion/source/database/clickhouse/connection.py index 95a923f235a..8fcc33df302 100644 --- a/ingestion/src/metadata/ingestion/source/database/clickhouse/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/clickhouse/connection.py @@ -45,7 +45,7 @@ def get_connection(connection: ClickhouseConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test Clickhouse connection """ @@ -55,13 +55,24 @@ def test_connection(engine: Engine) -> None: return list(cursor.all()) inspector = inspect(engine) + + def custom_executor_for_tables(): + schema_name = inspector.get_schema_names() + + if schema_name: + for schema in schema_name: + if schema not in ("INFORMATION_SCHEMA", "system"): + table_name = inspector.get_table_names(schema) + return table_name + return None + steps = [ TestConnectionStep( function=inspector.get_schema_names, name="Get Schemas", ), TestConnectionStep( - function=inspector.get_table_names, + function=partial(custom_executor_for_tables), name="Get Tables", ), TestConnectionStep( @@ -80,4 +91,4 @@ def test_connection(engine: Engine) -> None: ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/connection.py b/ingestion/src/metadata/ingestion/source/database/databricks/connection.py index 7bab6f7dfe3..ab833481c40 100644 --- a/ingestion/src/metadata/ingestion/source/database/databricks/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/databricks/connection.py @@ -52,7 +52,7 @@ def get_connection(connection: DatabricksConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test connection """ @@ -90,4 +90,4 @@ def test_connection(engine: Engine) -> None: ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/hive/connection.py b/ingestion/src/metadata/ingestion/source/database/hive/connection.py index 72eae83f322..881de56fa48 100644 --- a/ingestion/src/metadata/ingestion/source/database/hive/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/hive/connection.py @@ -80,7 +80,7 @@ def get_connection(connection: HiveConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test connection """ @@ -100,4 +100,4 @@ def test_connection(engine: Engine) -> None: mandatory=False, ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/mssql/connection.py b/ingestion/src/metadata/ingestion/source/database/mssql/connection.py index d0487da9411..e1452d047e8 100644 --- a/ingestion/src/metadata/ingestion/source/database/mssql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/mssql/connection.py @@ -52,7 +52,7 @@ def get_connection(connection: MssqlConnection) -> Engine: ) -def test_connection(engine: MssqlConnection) -> None: +def test_connection(engine: MssqlConnection) -> str: """ Test connection """ @@ -95,4 +95,4 @@ def test_connection(engine: MssqlConnection) -> None: ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py index e471c23ea1e..01c08501bb8 100644 --- a/ingestion/src/metadata/ingestion/source/database/mysql/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/mysql/connection.py @@ -12,6 +12,8 @@ """ Source connection handler """ +from functools import partial + from sqlalchemy.engine import Engine from sqlalchemy.inspection import inspect @@ -40,18 +42,28 @@ def get_connection(connection: MysqlConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test connection """ inspector = inspect(engine) + + def custom_executor(): + schema_name = inspector.get_schema_names() + if schema_name: + for schema in schema_name: + if schema not in ("information_schema", "performance_schema"): + table_name = inspector.get_table_names(schema) + return table_name + return None + steps = [ TestConnectionStep( function=inspector.get_schema_names, name="Get Schemas", ), TestConnectionStep( - function=inspector.get_table_names, + function=partial(custom_executor), name="Get Tables", ), TestConnectionStep( @@ -60,4 +72,4 @@ def test_connection(engine: Engine) -> None: mandatory=False, ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/postgres/connection.py b/ingestion/src/metadata/ingestion/source/database/postgres/connection.py index f0541627254..6d896d1ebee 100644 --- a/ingestion/src/metadata/ingestion/source/database/postgres/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/postgres/connection.py @@ -47,7 +47,7 @@ def get_connection(connection: PostgresConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test connection """ @@ -99,4 +99,4 @@ def test_connection(engine: Engine) -> None: ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/ingestion/src/metadata/ingestion/source/database/redshift/connection.py b/ingestion/src/metadata/ingestion/source/database/redshift/connection.py index e0a697c33c0..135cc31faa8 100644 --- a/ingestion/src/metadata/ingestion/source/database/redshift/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/redshift/connection.py @@ -48,7 +48,7 @@ def get_connection(connection: RedshiftConnection) -> Engine: ) -def test_connection(engine: Engine) -> None: +def test_connection(engine: Engine) -> str: """ Test connection """ @@ -109,4 +109,4 @@ def test_connection(engine: Engine) -> None: ), ] - test_connection_db_common(engine, steps) + return test_connection_db_common(engine, steps) diff --git a/openmetadata-airflow-apis/openmetadata_managed_apis/operations/test_connection.py b/openmetadata-airflow-apis/openmetadata_managed_apis/operations/test_connection.py index 2e35a772e20..8cc022a5a2a 100644 --- a/openmetadata-airflow-apis/openmetadata_managed_apis/operations/test_connection.py +++ b/openmetadata-airflow-apis/openmetadata_managed_apis/operations/test_connection.py @@ -54,6 +54,16 @@ def test_source_connection( ) test_connection_fn(connection) + if test_connection_fn(connection): + msg = test_connection_fn(connection) + if msg.failed: + return ApiResponse.error( + status=ApiResponse.STATUS_SERVER_ERROR, + error=msg.json(), + ) + elif msg.success: + return ApiResponse.success({"message": msg.json()}) + except SourceConnectionException as exc: msg = f"Connection error from [{connection}]: {exc}" logger.debug(traceback.format_exc())