Airflow API change for test connection (#10182)

* Airflow API change for test connection

* Improve logic

* Pydantic change

* Improve logic
This commit is contained in:
Milan Bariya 2023-02-24 11:43:11 +05:30 committed by GitHub
parent b2bd2e1463
commit c1a8553e07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 77 additions and 34 deletions

View File

@ -55,31 +55,39 @@ class TestConnectionStep(BaseModel):
mandatory: bool = True
def test_connection_steps(steps: List[TestConnectionStep]) -> None:
class TestConnectionResult(BaseModel):
failed: List[str] = []
success: List[str] = []
warning: List[str] = []
def test_connection_steps(steps: List[TestConnectionStep]) -> str:
"""
Run all the function steps and raise any errors
"""
errors = {}
test_connection_result = TestConnectionResult()
for step in steps:
try:
step.function()
except Exception as exc:
msg = f"Faild to {step.name}, {exc}"
if step.mandatory:
errors[
step.name
] = f"{msg} This is a mandatory step and we won't be able to extract necessary metadata"
else:
errors[
step.name
] = f"{msg} This is a optional. The ingestion will continue to work as expected"
test_connection_result.success.append(f"'{step.name}': Pass")
if errors:
raise SourceConnectionException(errors)
except Exception:
if step.mandatory:
test_connection_result.failed.append(
f"'{step.name}': This is a mandatory step and we won't be able to extract necessary metadata"
)
else:
test_connection_result.warning.append(
f"'{step.name}': This is a optional and the ingestion will continue to work as expected"
)
return test_connection_result
@timeout(seconds=120)
def test_connection_db_common(connection: Engine, steps=None) -> None:
def test_connection_db_common(connection: Engine, steps=None) -> str:
"""
Default implementation is the engine to test.
@ -91,7 +99,7 @@ def test_connection_db_common(connection: Engine, steps=None) -> None:
with connection.connect() as conn:
conn.execute(ConnTestFn())
if steps:
test_connection_steps(steps)
return test_connection_steps(steps)
except SourceConnectionException as exc:
raise exc
except OperationalError as err:
@ -100,3 +108,5 @@ def test_connection_db_common(connection: Engine, steps=None) -> None:
except Exception as exc:
msg = f"Unknown error connecting with {connection}: {exc}."
raise SourceConnectionException(msg) from exc
return None

View File

@ -60,7 +60,7 @@ def get_connection(connection: AthenaConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test connection
"""
@ -80,4 +80,4 @@ def test_connection(engine: Engine) -> None:
mandatory=False,
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -45,7 +45,7 @@ def get_connection(connection: ClickhouseConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test Clickhouse connection
"""
@ -55,13 +55,24 @@ def test_connection(engine: Engine) -> None:
return list(cursor.all())
inspector = inspect(engine)
def custom_executor_for_tables():
schema_name = inspector.get_schema_names()
if schema_name:
for schema in schema_name:
if schema not in ("INFORMATION_SCHEMA", "system"):
table_name = inspector.get_table_names(schema)
return table_name
return None
steps = [
TestConnectionStep(
function=inspector.get_schema_names,
name="Get Schemas",
),
TestConnectionStep(
function=inspector.get_table_names,
function=partial(custom_executor_for_tables),
name="Get Tables",
),
TestConnectionStep(
@ -80,4 +91,4 @@ def test_connection(engine: Engine) -> None:
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -52,7 +52,7 @@ def get_connection(connection: DatabricksConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test connection
"""
@ -90,4 +90,4 @@ def test_connection(engine: Engine) -> None:
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -80,7 +80,7 @@ def get_connection(connection: HiveConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test connection
"""
@ -100,4 +100,4 @@ def test_connection(engine: Engine) -> None:
mandatory=False,
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -52,7 +52,7 @@ def get_connection(connection: MssqlConnection) -> Engine:
)
def test_connection(engine: MssqlConnection) -> None:
def test_connection(engine: MssqlConnection) -> str:
"""
Test connection
"""
@ -95,4 +95,4 @@ def test_connection(engine: MssqlConnection) -> None:
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -12,6 +12,8 @@
"""
Source connection handler
"""
from functools import partial
from sqlalchemy.engine import Engine
from sqlalchemy.inspection import inspect
@ -40,18 +42,28 @@ def get_connection(connection: MysqlConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test connection
"""
inspector = inspect(engine)
def custom_executor():
schema_name = inspector.get_schema_names()
if schema_name:
for schema in schema_name:
if schema not in ("information_schema", "performance_schema"):
table_name = inspector.get_table_names(schema)
return table_name
return None
steps = [
TestConnectionStep(
function=inspector.get_schema_names,
name="Get Schemas",
),
TestConnectionStep(
function=inspector.get_table_names,
function=partial(custom_executor),
name="Get Tables",
),
TestConnectionStep(
@ -60,4 +72,4 @@ def test_connection(engine: Engine) -> None:
mandatory=False,
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -47,7 +47,7 @@ def get_connection(connection: PostgresConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test connection
"""
@ -99,4 +99,4 @@ def test_connection(engine: Engine) -> None:
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -48,7 +48,7 @@ def get_connection(connection: RedshiftConnection) -> Engine:
)
def test_connection(engine: Engine) -> None:
def test_connection(engine: Engine) -> str:
"""
Test connection
"""
@ -109,4 +109,4 @@ def test_connection(engine: Engine) -> None:
),
]
test_connection_db_common(engine, steps)
return test_connection_db_common(engine, steps)

View File

@ -54,6 +54,16 @@ def test_source_connection(
)
test_connection_fn(connection)
if test_connection_fn(connection):
msg = test_connection_fn(connection)
if msg.failed:
return ApiResponse.error(
status=ApiResponse.STATUS_SERVER_ERROR,
error=msg.json(),
)
elif msg.success:
return ApiResponse.success({"message": msg.json()})
except SourceConnectionException as exc:
msg = f"Connection error from [{connection}]: {exc}"
logger.debug(traceback.format_exc())