mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-31 04:14:34 +00:00
Airflow API change for test connection (#10182)
* Airflow API change for test connection * Improve logic * Pydantic change * Improve logic
This commit is contained in:
parent
b2bd2e1463
commit
c1a8553e07
@ -55,31 +55,39 @@ class TestConnectionStep(BaseModel):
|
||||
mandatory: bool = True
|
||||
|
||||
|
||||
def test_connection_steps(steps: List[TestConnectionStep]) -> None:
|
||||
class TestConnectionResult(BaseModel):
|
||||
failed: List[str] = []
|
||||
success: List[str] = []
|
||||
warning: List[str] = []
|
||||
|
||||
|
||||
def test_connection_steps(steps: List[TestConnectionStep]) -> str:
|
||||
"""
|
||||
Run all the function steps and raise any errors
|
||||
"""
|
||||
errors = {}
|
||||
|
||||
test_connection_result = TestConnectionResult()
|
||||
for step in steps:
|
||||
try:
|
||||
step.function()
|
||||
except Exception as exc:
|
||||
msg = f"Faild to {step.name}, {exc}"
|
||||
if step.mandatory:
|
||||
errors[
|
||||
step.name
|
||||
] = f"{msg} This is a mandatory step and we won't be able to extract necessary metadata"
|
||||
else:
|
||||
errors[
|
||||
step.name
|
||||
] = f"{msg} This is a optional. The ingestion will continue to work as expected"
|
||||
test_connection_result.success.append(f"'{step.name}': Pass")
|
||||
|
||||
if errors:
|
||||
raise SourceConnectionException(errors)
|
||||
except Exception:
|
||||
if step.mandatory:
|
||||
test_connection_result.failed.append(
|
||||
f"'{step.name}': This is a mandatory step and we won't be able to extract necessary metadata"
|
||||
)
|
||||
|
||||
else:
|
||||
test_connection_result.warning.append(
|
||||
f"'{step.name}': This is a optional and the ingestion will continue to work as expected"
|
||||
)
|
||||
|
||||
return test_connection_result
|
||||
|
||||
|
||||
@timeout(seconds=120)
|
||||
def test_connection_db_common(connection: Engine, steps=None) -> None:
|
||||
def test_connection_db_common(connection: Engine, steps=None) -> str:
|
||||
"""
|
||||
Default implementation is the engine to test.
|
||||
|
||||
@ -91,7 +99,7 @@ def test_connection_db_common(connection: Engine, steps=None) -> None:
|
||||
with connection.connect() as conn:
|
||||
conn.execute(ConnTestFn())
|
||||
if steps:
|
||||
test_connection_steps(steps)
|
||||
return test_connection_steps(steps)
|
||||
except SourceConnectionException as exc:
|
||||
raise exc
|
||||
except OperationalError as err:
|
||||
@ -100,3 +108,5 @@ def test_connection_db_common(connection: Engine, steps=None) -> None:
|
||||
except Exception as exc:
|
||||
msg = f"Unknown error connecting with {connection}: {exc}."
|
||||
raise SourceConnectionException(msg) from exc
|
||||
|
||||
return None
|
||||
|
@ -60,7 +60,7 @@ def get_connection(connection: AthenaConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
@ -80,4 +80,4 @@ def test_connection(engine: Engine) -> None:
|
||||
mandatory=False,
|
||||
),
|
||||
]
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -45,7 +45,7 @@ def get_connection(connection: ClickhouseConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test Clickhouse connection
|
||||
"""
|
||||
@ -55,13 +55,24 @@ def test_connection(engine: Engine) -> None:
|
||||
return list(cursor.all())
|
||||
|
||||
inspector = inspect(engine)
|
||||
|
||||
def custom_executor_for_tables():
|
||||
schema_name = inspector.get_schema_names()
|
||||
|
||||
if schema_name:
|
||||
for schema in schema_name:
|
||||
if schema not in ("INFORMATION_SCHEMA", "system"):
|
||||
table_name = inspector.get_table_names(schema)
|
||||
return table_name
|
||||
return None
|
||||
|
||||
steps = [
|
||||
TestConnectionStep(
|
||||
function=inspector.get_schema_names,
|
||||
name="Get Schemas",
|
||||
),
|
||||
TestConnectionStep(
|
||||
function=inspector.get_table_names,
|
||||
function=partial(custom_executor_for_tables),
|
||||
name="Get Tables",
|
||||
),
|
||||
TestConnectionStep(
|
||||
@ -80,4 +91,4 @@ def test_connection(engine: Engine) -> None:
|
||||
),
|
||||
]
|
||||
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -52,7 +52,7 @@ def get_connection(connection: DatabricksConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
@ -90,4 +90,4 @@ def test_connection(engine: Engine) -> None:
|
||||
),
|
||||
]
|
||||
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -80,7 +80,7 @@ def get_connection(connection: HiveConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
@ -100,4 +100,4 @@ def test_connection(engine: Engine) -> None:
|
||||
mandatory=False,
|
||||
),
|
||||
]
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -52,7 +52,7 @@ def get_connection(connection: MssqlConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: MssqlConnection) -> None:
|
||||
def test_connection(engine: MssqlConnection) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
@ -95,4 +95,4 @@ def test_connection(engine: MssqlConnection) -> None:
|
||||
),
|
||||
]
|
||||
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -12,6 +12,8 @@
|
||||
"""
|
||||
Source connection handler
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
@ -40,18 +42,28 @@ def get_connection(connection: MysqlConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
inspector = inspect(engine)
|
||||
|
||||
def custom_executor():
|
||||
schema_name = inspector.get_schema_names()
|
||||
if schema_name:
|
||||
for schema in schema_name:
|
||||
if schema not in ("information_schema", "performance_schema"):
|
||||
table_name = inspector.get_table_names(schema)
|
||||
return table_name
|
||||
return None
|
||||
|
||||
steps = [
|
||||
TestConnectionStep(
|
||||
function=inspector.get_schema_names,
|
||||
name="Get Schemas",
|
||||
),
|
||||
TestConnectionStep(
|
||||
function=inspector.get_table_names,
|
||||
function=partial(custom_executor),
|
||||
name="Get Tables",
|
||||
),
|
||||
TestConnectionStep(
|
||||
@ -60,4 +72,4 @@ def test_connection(engine: Engine) -> None:
|
||||
mandatory=False,
|
||||
),
|
||||
]
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -47,7 +47,7 @@ def get_connection(connection: PostgresConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
@ -99,4 +99,4 @@ def test_connection(engine: Engine) -> None:
|
||||
),
|
||||
]
|
||||
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -48,7 +48,7 @@ def get_connection(connection: RedshiftConnection) -> Engine:
|
||||
)
|
||||
|
||||
|
||||
def test_connection(engine: Engine) -> None:
|
||||
def test_connection(engine: Engine) -> str:
|
||||
"""
|
||||
Test connection
|
||||
"""
|
||||
@ -109,4 +109,4 @@ def test_connection(engine: Engine) -> None:
|
||||
),
|
||||
]
|
||||
|
||||
test_connection_db_common(engine, steps)
|
||||
return test_connection_db_common(engine, steps)
|
||||
|
@ -54,6 +54,16 @@ def test_source_connection(
|
||||
)
|
||||
test_connection_fn(connection)
|
||||
|
||||
if test_connection_fn(connection):
|
||||
msg = test_connection_fn(connection)
|
||||
if msg.failed:
|
||||
return ApiResponse.error(
|
||||
status=ApiResponse.STATUS_SERVER_ERROR,
|
||||
error=msg.json(),
|
||||
)
|
||||
elif msg.success:
|
||||
return ApiResponse.success({"message": msg.json()})
|
||||
|
||||
except SourceConnectionException as exc:
|
||||
msg = f"Connection error from [{connection}]: {exc}"
|
||||
logger.debug(traceback.format_exc())
|
||||
|
Loading…
x
Reference in New Issue
Block a user