Fix #4039 - Airflow REST Trigger & Test Connection (#4072)

Pere Miquel Brull 2022-04-12 17:06:49 +02:00 committed by GitHub
parent 32004021c8
commit 87e5854f25
15 changed files with 252 additions and 75 deletions

View File

@ -35,7 +35,8 @@ def start_docker(docker, start_time, file_path, skip_sample_data):
logger.info("Ran docker compose for OpenMetadata successfully.")
if not skip_sample_data:
logger.info("Waiting for ingestion to complete..")
ingest_sample_data(docker)
wait_for_containers(docker)
ingest_sample_data()
metadata_config = OpenMetadataServerConfig(
hostPort="http://localhost:8585/api", authProvider="no-auth"
)
@ -208,28 +209,43 @@ def reset_db_om(docker):
click.secho("OpenMetadata Instance is not up and running", fg="yellow")
def ingest_sample_data(docker):
if docker.container.inspect("openmetadata_server").state.running:
base_url = "http://localhost:8080/api"
dags = ["sample_data", "sample_usage", "index_metadata"]
client_config = ClientConfig(
base_url=base_url,
auth_header="Authorization",
auth_token_mode="Basic",
access_token=to_native_string(
b64encode(b":".join(("admin".encode(), "admin".encode()))).strip()
),
def wait_for_containers(docker) -> None:
"""
Wait until docker containers are running
"""
while True:
running = (
docker.container.inspect("openmetadata_server").state.running
and docker.container.inspect("openmetadata_ingestion").state.running
)
client = REST(client_config)
if running:
break
else:
sys.stdout.write(".")
sys.stdout.flush()
time.sleep(5)
for dag in dags:
json_sample_data = {
"dag_run_id": "{}_{}".format(dag, datetime.now()),
}
client.post(
"/dags/{}/dagRuns".format(dag), data=json.dumps(json_sample_data)
)
else:
click.secho("OpenMetadata Instance is not up and running", fg="yellow")
def ingest_sample_data() -> None:
"""
Trigger sample data DAGs
"""
base_url = "http://localhost:8080/api"
dags = ["sample_data", "sample_usage", "index_metadata"]
client_config = ClientConfig(
base_url=base_url,
auth_header="Authorization",
auth_token_mode="Basic",
access_token=to_native_string(
b64encode(b":".join(("admin".encode(), "admin".encode()))).strip()
),
)
client = REST(client_config)
for dag in dags:
json_sample_data = {
"dag_run_id": "{}_{}".format(dag, datetime.now()),
}
client.post("/dags/{}/dagRuns".format(dag), data=json.dumps(json_sample_data))

View File

@ -71,7 +71,7 @@ class PostgresSource(SQLSource):
logger.info(f"Ingesting from database: {row[0]}")
self.config.database = row[0]
self.engine = get_engine(self.config)
self.engine = get_engine(self.config.serviceConnection)
self.connection = self.engine.connect()
yield inspect(self.engine)

View File

@ -82,7 +82,7 @@ class SnowflakeSource(SQLSource):
self.connection.execute(use_db_query)
logger.info(f"Ingesting from database: {row[1]}")
self.config.serviceConnection.__root__.config.database = row[1]
self.engine = get_engine(self.config)
self.engine = get_engine(self.config.serviceConnection)
yield inspect(self.engine)
def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]:

View File

@ -129,7 +129,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
self.service = get_database_service_or_create(config, metadata_config)
self.metadata = OpenMetadata(metadata_config)
self.status = SQLSourceStatus()
self.engine = get_engine(workflow_source=self.config)
self.engine = get_engine(service_connection=self.config.serviceConnection)
self._session = None # We will instantiate this just if needed
self._connection = None # Lazy init as well
self.data_profiler = None

View File

@ -16,29 +16,37 @@ import logging
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from metadata.generated.schema.entity.services.connections.connectionBasicType import (
ConnectionOptions,
)
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
from metadata.generated.schema.entity.services.connections.serviceConnection import (
ServiceConnection,
)
from metadata.utils.source_connections import get_connection_args, get_connection_url
logger = logging.getLogger("Utils")
def get_engine(workflow_source: WorkflowSource, verbose: bool = False) -> Engine:
class SourceConnectionException(Exception):
"""
Raised when we cannot connect to the source
"""
def get_engine(service_connection: ServiceConnection, verbose: bool = False) -> Engine:
"""
Given an SQL configuration, build the SQLAlchemy Engine
"""
logger.info(f"Building Engine for {workflow_source.serviceName}...")
service_connection_config = workflow_source.serviceConnection.__root__.config
service_connection_config = service_connection.__root__.config
options = service_connection_config.connectionOptions
if not options:
options = ConnectionOptions()
engine = create_engine(
get_connection_url(service_connection_config),
**options.dict(),
@ -57,3 +65,22 @@ def create_and_bind_session(engine: Engine) -> Session:
session = sessionmaker()
session.configure(bind=engine)
return session()
def test_connection(engine: Engine) -> None:
"""
Test that we can connect to the source using the given engine
:param engine: Engine to test
:return: None or raise an exception if we cannot connect
"""
try:
with engine.connect() as _:
pass
except OperationalError as err:
raise SourceConnectionException(
f"Connection error for {engine} - {err}. Check the connection details."
)
except Exception as err:
raise SourceConnectionException(
f"Unknown error connecting with {engine} - {err}."
)
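
Since test_connection only needs a SQLAlchemy Engine, it can be exercised without any OpenMetadata config; a quick sketch (the failing URL is deliberately unreachable, and assumes a pymysql driver is installed):

from sqlalchemy import create_engine

from metadata.utils.engines import SourceConnectionException, test_connection

# Happy path: an in-memory SQLite engine connects fine and returns None
test_connection(create_engine("sqlite://"))

# Failure path: an unreachable host surfaces as SourceConnectionException
# instead of a raw sqlalchemy OperationalError
try:
    test_connection(create_engine("mysql+pymysql://user:pwd@localhost:1/db"))
except SourceConnectionException as err:
    print(err)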

View File

@ -18,10 +18,7 @@
"hostPort": "localhost:3306",
"database": null,
"connectionOptions": null,
"connectionArguments": null,
"supportedPipelineTypes": [
"Metadata"
]
"connectionArguments": null
}
},
"sourceConfig": {

View File

@ -0,0 +1,10 @@
{
"serviceConnection": {
"config": {
"type": "MySQL",
"username": "openmetadata_user",
"password": "openmetadata_password",
"hostPort": "localhost:3306"
}
}
}

View File

@ -0,0 +1,3 @@
{
"workflow_name": "my_pipeline"
}
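
Both example payloads map to the new plugin endpoints; a hedged sketch of calling them over HTTP (the /rest_api/api?api=... path and the admin:admin Basic auth credentials are assumptions about the local deployment):

import requests

AIRFLOW = "http://localhost:8080"
AUTH = ("admin", "admin")

# test_connection expects a ServiceConnectionModel payload (mysql.json above)
connection = {
    "serviceConnection": {
        "config": {
            "type": "MySQL",
            "username": "openmetadata_user",
            "password": "openmetadata_password",
            "hostPort": "localhost:3306",
        }
    }
}
res = requests.post(
    f"{AIRFLOW}/rest_api/api?api=test_connection", json=connection, auth=AUTH
)
print(res.status_code, res.json())

# trigger_dag only needs the workflow_name; run_id is optional
res = requests.post(
    f"{AIRFLOW}/rest_api/api?api=trigger_dag",
    json={"workflow_name": "my_pipeline"},
    auth=AUTH,
)
print(res.status_code, res.json())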

View File

@ -10,6 +10,7 @@
# limitations under the License.
from typing import Any, Dict, Optional
# TODO DELETE, STATUS (pick it up from airflow directly), LOG (just link v1), ENABLE DAG, DISABLE DAG (play pause)
APIS_METADATA = [
{
"name": "deploy_dag",
@ -20,7 +21,7 @@ APIS_METADATA = [
"post_arguments": [
{
"name": "workflow_config",
"description": "Workflow config to deploy",
"description": "Workflow config to deploy as IngestionPipeline",
"form_input_type": "file",
"required": True,
},
@ -39,6 +40,19 @@ APIS_METADATA = [
},
],
},
{
"name": "test_connection",
"description": "Test a connection",
"http_method": "POST",
"arguments": [],
"post_arguments": [
{
"name": "service_connection",
"description": "ServiceConnectionModel config to test",
"required": True,
},
],
},
]
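
The view below resolves the api query argument against this list via get_metadata_api; presumably a simple name lookup, along these lines (illustrative, not the actual implementation):

from typing import Any, Dict, Optional

def get_metadata_api(name: str) -> Optional[Dict[str, Any]]:
    # Return the metadata entry whose "name" matches, or None if unsupported
    return next((api for api in APIS_METADATA if api["name"] == name), None)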

View File

@ -14,17 +14,17 @@ Airflow REST API definition
import logging
import traceback
from typing import Optional
from airflow import settings
from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagModel
from airflow.utils import timezone
from airflow.www.app import csrf
from flask import request
from flask import Response, request
from flask_admin import expose as admin_expose
from flask_appbuilder import BaseView as AppBuilderBaseView
from flask_appbuilder import expose as app_builder_expose
from openmetadata.airflow.deploy import DagDeployer
from openmetadata.api.apis_metadata import APIS_METADATA, get_metadata_api
from openmetadata.api.config import (
AIRFLOW_VERSION,
@ -34,8 +34,14 @@ from openmetadata.api.config import (
)
from openmetadata.api.response import ApiResponse
from openmetadata.api.utils import jwt_token_secure
from openmetadata.operations.deploy import DagDeployer
from openmetadata.operations.test_connection import test_source_connection
from openmetadata.operations.trigger import trigger
from pydantic.error_wrappers import ValidationError
from metadata.generated.schema.entity.services.connections.serviceConnection import (
ServiceConnectionModel,
)
from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
IngestionPipeline,
)
@ -53,7 +59,7 @@ class REST_API(AppBuilderBaseView):
return dagbag
@staticmethod
def get_request_arg(req, arg):
def get_request_arg(req, arg) -> Optional[str]:
return req.args.get(arg) or req.form.get(arg)
# '/' Endpoint serving the Admin page, which lists the available APIs and lets you trigger them
@ -97,16 +103,16 @@ class REST_API(AppBuilderBaseView):
# Validate that the API is provided
if not api:
logging.warning("api argument not provided")
logging.warning("api argument not provided or empty")
return ApiResponse.bad_request("API should be provided")
api = api.strip().lower()
logging.info("REST_API.api() called (api: " + str(api) + ")")
logging.info(f"REST_API.api() called (api: {api})")
api_metadata = get_metadata_api(api)
if api_metadata is None:
logging.info("api '" + str(api) + "' not supported")
return ApiResponse.bad_request("API '" + str(api) + "' was not found")
logging.info(f"api [{api}] not supported")
return ApiResponse.bad_request(f"API [{api}] was not found")
# Deciding which function to use based off the API object that was requested.
# Some functions are custom and need to be manually routed to.
@ -114,14 +120,14 @@ class REST_API(AppBuilderBaseView):
return self.deploy_dag()
if api == "trigger_dag":
return self.trigger_dag()
# TODO DELETE, STATUS (pick it up from airflow directly), LOG (just link v1), ENABLE DAG, DISABLE DAG (play pause)
if api == "test_connection":
return self.test_connection()
raise ValueError(
f"Invalid api param {api}. Expected deploy_dag, trigger_dag or test_connection."
)
def deploy_dag(self):
def deploy_dag(self) -> Response:
"""Custom Function for the deploy_dag API
Creates workflow dag based on workflow dag file and refreshes
the session
@ -140,42 +146,64 @@ class REST_API(AppBuilderBaseView):
return response
except ValidationError as err:
msg = f"Request Validation Error parsing payload {json_request} - {err}"
return ApiResponse.error(status=ApiResponse.STATUS_BAD_REQUEST, error=msg)
return ApiResponse.error(
status=ApiResponse.STATUS_BAD_REQUEST,
error=f"Request Validation Error parsing payload {json_request}. IngestionPipeline expected - {err}",
)
except Exception as err:
msg = f"Internal error deploying {json_request} - {err} - {traceback.format_exc()}"
return ApiResponse.error(status=ApiResponse.STATUS_SERVER_ERROR, error=msg)
return ApiResponse.error(
status=ApiResponse.STATUS_SERVER_ERROR,
error=f"Internal error deploying {json_request} - {err} - {traceback.format_exc()}",
)
@staticmethod
def trigger_dag():
def test_connection() -> Response:
"""
Given a ServiceConnectionModel, create the engine
and test the connection
"""
json_request = request.get_json()
try:
service_connection_model = ServiceConnectionModel(**json_request)
response = test_source_connection(service_connection_model)
return response
except ValidationError as err:
return ApiResponse.error(
status=ApiResponse.STATUS_BAD_REQUEST,
error=f"Request Validation Error parsing payload {json_request}. (Workflow)Source expected - {err}",
)
except Exception as err:
return ApiResponse.error(
status=ApiResponse.STATUS_SERVER_ERROR,
error=f"Internal error testing connection {json_request} - {err} - {traceback.format_exc()}",
)
@staticmethod
def trigger_dag() -> Response:
"""
Trigger a dag run
"""
logging.info("Running run_dag method")
request_json = request.get_json()
dag_id = request_json.get("workflow_name")
if not dag_id:
return ApiResponse.bad_request("workflow_name should be informed")
try:
request_json = request.get_json()
dag_id = request_json["workflow_name"]
run_id = request_json["run_id"] if "run_id" in request_json.keys() else None
dag_run = trigger_dag(
dag_id=dag_id,
run_id=run_id,
conf=None,
execution_date=timezone.utcnow(),
)
return ApiResponse.success(
{
"message": "Workflow [{}] has been triggered {}".format(
dag_id, dag_run
)
}
)
except Exception as e:
run_id = request_json.get("run_id")
response = trigger(dag_id, run_id)
return response
except Exception as exc:
logging.info(f"Failed to trigger dag {dag_id}")
return ApiResponse.error(
{
"message": "Workflow {} has filed to trigger due to {}".format(
dag_id, e
)
}
status=ApiResponse.STATUS_SERVER_ERROR,
error=f"Workflow {dag_id} has filed to trigger due to {exc} - {traceback.format_exc()}",
)

View File

@ -0,0 +1,47 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing the logic to test a connection
from a WorkflowSource
"""
from flask import Response
from openmetadata.api.response import ApiResponse
from metadata.generated.schema.entity.services.connections.serviceConnection import (
ServiceConnectionModel,
)
from metadata.utils.engines import (
SourceConnectionException,
get_engine,
test_connection,
)
def test_source_connection(
service_connection_model: ServiceConnectionModel,
) -> Response:
"""
Create the engine and test the connection
:param service_connection_model: Service connection to test
:return: None or exception
"""
engine = get_engine(service_connection_model.serviceConnection)
try:
test_connection(engine)
except SourceConnectionException as err:
return ApiResponse.error(
status=ApiResponse.STATUS_SERVER_ERROR,
error=f"Connection error from {engine} - {err}",
)
return ApiResponse.success({"message": f"Connection with {engine} successful!"})
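
In-process usage mirrors what the endpoint does after validation; a sketch, assuming the generated pydantic model parses the same payload shape as the mysql.json example above:

from metadata.generated.schema.entity.services.connections.serviceConnection import (
    ServiceConnectionModel,
)
from openmetadata.operations.test_connection import test_source_connection

model = ServiceConnectionModel.parse_obj(
    {
        "serviceConnection": {
            "config": {
                "type": "MySQL",
                "username": "openmetadata_user",
                "password": "openmetadata_password",
                "hostPort": "localhost:3306",
            }
        }
    }
)
response = test_source_connection(model)  # flask.Response: 200 on success, 500 on failure
print(response.status_code)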

View File

@ -0,0 +1,33 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing the logic to trigger a DAG
"""
from typing import Optional
from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagModel
from airflow.utils import timezone
from flask import Response
from openmetadata.api.response import ApiResponse
def trigger(dag_id: str, run_id: Optional[str]) -> Response:
dag_run = trigger_dag(
dag_id=dag_id,
run_id=run_id,
conf=None,
execution_date=timezone.utcnow(),
)
return ApiResponse.success(
{"message": f"Workflow [{dag_id}] has been triggered {dag_run}"}
)

View File

@ -19,7 +19,9 @@ from collections import namedtuple
from openmetadata.workflows.ingestion.metadata import build_metadata_dag
from openmetadata.workflows.ingestion.usage import build_usage_dag
from metadata.generated.schema.operations.pipelines.airflowPipeline import PipelineType
from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
PipelineType,
)
def register():
@ -43,4 +45,4 @@ def register():
build_registry = register()
build_registry.add(PipelineType.metadata.value)(build_metadata_dag)
build_registry.add(PipelineType.queryUsage.value)(build_usage_dag)
build_registry.add(PipelineType.usage.value)(build_usage_dag)
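
For context, register() builds a tiny decorator-based registry keyed by pipeline type; a minimal sketch of the pattern (names here are illustrative, not the actual helper):

from collections import namedtuple

def register():
    registry = {}

    def add(name: str):
        def inner(fn):
            registry[name] = fn  # map the pipeline type to its DAG builder
            return fn
        return inner

    return namedtuple("Register", ["add", "registry"])(add, registry)

build_registry = register()

@build_registry.add("metadata")
def build_metadata_dag_stub():
    return "metadata DAG"

print(build_registry.registry["metadata"]())  # dispatch by PipelineType value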