Fix #4039 - Airflow REST Trigger & Test Connection (#4072)

Authored by Pere Miquel Brull on 2022-04-12 17:06:49 +02:00; committed by GitHub
parent 32004021c8
commit 87e5854f25
15 changed files with 252 additions and 75 deletions

View File

@@ -35,7 +35,8 @@ def start_docker(docker, start_time, file_path, skip_sample_data):
     logger.info("Ran docker compose for OpenMetadata successfully.")
     if not skip_sample_data:
         logger.info("Waiting for ingestion to complete..")
-        ingest_sample_data(docker)
+        wait_for_containers(docker)
+        ingest_sample_data()
     metadata_config = OpenMetadataServerConfig(
         hostPort="http://localhost:8585/api", authProvider="no-auth"
     )
@@ -208,28 +209,43 @@ def reset_db_om(docker):
         click.secho("OpenMetadata Instance is not up and running", fg="yellow")


-def ingest_sample_data(docker):
-    if docker.container.inspect("openmetadata_server").state.running:
-        base_url = "http://localhost:8080/api"
-        dags = ["sample_data", "sample_usage", "index_metadata"]
-
-        client_config = ClientConfig(
-            base_url=base_url,
-            auth_header="Authorization",
-            auth_token_mode="Basic",
-            access_token=to_native_string(
-                b64encode(b":".join(("admin".encode(), "admin".encode()))).strip()
-            ),
-        )
-        client = REST(client_config)
-
-        for dag in dags:
-            json_sample_data = {
-                "dag_run_id": "{}_{}".format(dag, datetime.now()),
-            }
-            client.post(
-                "/dags/{}/dagRuns".format(dag), data=json.dumps(json_sample_data)
-            )
-    else:
-        click.secho("OpenMetadata Instance is not up and running", fg="yellow")
+def wait_for_containers(docker) -> None:
+    """
+    Wait until docker containers are running
+    """
+    while True:
+        running = (
+            docker.container.inspect("openmetadata_server").state.running
+            and docker.container.inspect("openmetadata_ingestion").state.running
+        )
+        if running:
+            break
+        else:
+            sys.stdout.write(".")
+            sys.stdout.flush()
+            time.sleep(5)
+
+
+def ingest_sample_data() -> None:
+    """
+    Trigger sample data DAGs
+    """
+    base_url = "http://localhost:8080/api"
+    dags = ["sample_data", "sample_usage", "index_metadata"]
+
+    client_config = ClientConfig(
+        base_url=base_url,
+        auth_header="Authorization",
+        auth_token_mode="Basic",
+        access_token=to_native_string(
+            b64encode(b":".join(("admin".encode(), "admin".encode()))).strip()
+        ),
+    )
+    client = REST(client_config)
+
+    for dag in dags:
+        json_sample_data = {
+            "dag_run_id": "{}_{}".format(dag, datetime.now()),
+        }
+        client.post("/dags/{}/dagRuns".format(dag), data=json.dumps(json_sample_data))
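For reference, the CLI call above amounts to the following raw request against the Airflow API. This is a hedged sketch only: requests stands in for the CLI's own REST client, while the admin/admin basic-auth pair, the /api base path, and the dag_run_id format all come from the diff itself.

import json
from base64 import b64encode
from datetime import datetime

import requests  # assumption: illustration only, the CLI ships its own REST client

dag = "sample_data"
token = b64encode(b"admin:admin").decode()

# POST /dags/{dag}/dagRuns with a unique dag_run_id, mirroring ingest_sample_data()
response = requests.post(
    f"http://localhost:8080/api/dags/{dag}/dagRuns",
    headers={"Authorization": f"Basic {token}", "Content-Type": "application/json"},
    data=json.dumps({"dag_run_id": f"{dag}_{datetime.now()}"}),
)
response.raise_for_status()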

View File

@@ -71,7 +71,7 @@ class PostgresSource(SQLSource):
             logger.info(f"Ingesting from database: {row[0]}")
             self.config.database = row[0]
-            self.engine = get_engine(self.config)
+            self.engine = get_engine(self.config.serviceConnection)
             self.connection = self.engine.connect()
             yield inspect(self.engine)

View File

@@ -82,7 +82,7 @@ class SnowflakeSource(SQLSource):
             self.connection.execute(use_db_query)
             logger.info(f"Ingesting from database: {row[1]}")
             self.config.serviceConnection.__root__.config.database = row[1]
-            self.engine = get_engine(self.config)
+            self.engine = get_engine(self.config.serviceConnection)
             yield inspect(self.engine)

     def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]:

View File

@@ -129,7 +129,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
         self.service = get_database_service_or_create(config, metadata_config)
         self.metadata = OpenMetadata(metadata_config)
         self.status = SQLSourceStatus()
-        self.engine = get_engine(workflow_source=self.config)
+        self.engine = get_engine(service_connection=self.config.serviceConnection)
         self._session = None  # We will instantiate this just if needed
         self._connection = None  # Lazy init as well
         self.data_profiler = None

View File

@@ -16,29 +16,37 @@ import logging

 from sqlalchemy import create_engine
 from sqlalchemy.engine.base import Engine
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm.session import Session

 from metadata.generated.schema.entity.services.connections.connectionBasicType import (
     ConnectionOptions,
 )
-from metadata.generated.schema.metadataIngestion.workflow import (
-    Source as WorkflowSource,
+from metadata.generated.schema.entity.services.connections.serviceConnection import (
+    ServiceConnection,
 )
 from metadata.utils.source_connections import get_connection_args, get_connection_url

 logger = logging.getLogger("Utils")


-def get_engine(workflow_source: WorkflowSource, verbose: bool = False) -> Engine:
+class SourceConnectionException(Exception):
+    """
+    Raised when we cannot connect to the source
+    """
+
+
+def get_engine(service_connection: ServiceConnection, verbose: bool = False) -> Engine:
     """
     Given an SQL configuration, build the SQLAlchemy Engine
     """
-    logger.info(f"Building Engine for {workflow_source.serviceName}...")
-    service_connection_config = workflow_source.serviceConnection.__root__.config
+    service_connection_config = service_connection.__root__.config
     options = service_connection_config.connectionOptions
     if not options:
         options = ConnectionOptions()
     engine = create_engine(
         get_connection_url(service_connection_config),
         **options.dict(),
@@ -57,3 +65,22 @@ def create_and_bind_session(engine: Engine) -> Session:
     session = sessionmaker()
     session.configure(bind=engine)
     return session()
+
+
+def test_connection(engine: Engine) -> None:
+    """
+    Test that we can connect to the source using the given engine
+    :param engine: Engine to test
+    :return: None or raise an exception if we cannot connect
+    """
+    try:
+        with engine.connect() as _:
+            pass
+    except OperationalError as err:
+        raise SourceConnectionException(
+            f"Connection error for {engine} - {err}. Check the connection details."
+        )
+    except Exception as err:
+        raise SourceConnectionException(
+            f"Unknown error connecting with {engine} - {err}."
+        )
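Taken together, the helpers in this file compose as below. A minimal sketch: the payload shape accepted by the generated ServiceConnection model is an assumption based on the MySQL example added later in this commit, and the credentials are hypothetical.

from metadata.generated.schema.entity.services.connections.serviceConnection import (
    ServiceConnection,
)
from metadata.utils.engines import (
    SourceConnectionException,
    get_engine,
    test_connection,
)

# Hypothetical MySQL credentials; parse_obj fills the model's __root__ config
service_connection = ServiceConnection.parse_obj(
    {
        "config": {
            "type": "MySQL",
            "username": "openmetadata_user",
            "password": "openmetadata_password",
            "hostPort": "localhost:3306",
        }
    }
)

engine = get_engine(service_connection)
try:
    test_connection(engine)  # raises SourceConnectionException if unreachable
    print("Connection OK")
except SourceConnectionException as err:
    print(f"Could not connect: {err}")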

View File

@@ -18,10 +18,7 @@
       "hostPort": "localhost:3306",
       "database": null,
       "connectionOptions": null,
-      "connectionArguments": null,
-      "supportedPipelineTypes": [
-        "Metadata"
-      ]
+      "connectionArguments": null
     }
   },
   "sourceConfig": {

View File

@@ -0,0 +1,10 @@
{
  "serviceConnection": {
    "config": {
      "type": "MySQL",
      "username": "openmetadata_user",
      "password": "openmetadata_password",
      "hostPort": "localhost:3306"
    }
  }
}
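A hedged sketch of exercising the new test_connection API with this payload. The /rest_api/api?api=<name> route is an assumption, since the plugin's URL routing is not part of this diff.

import requests  # assumption: any HTTP client works here

payload = {
    "serviceConnection": {
        "config": {
            "type": "MySQL",
            "username": "openmetadata_user",
            "password": "openmetadata_password",
            "hostPort": "localhost:3306",
        }
    }
}

# POST the ServiceConnectionModel config to the plugin endpoint
response = requests.post(
    "http://localhost:8080/rest_api/api",
    params={"api": "test_connection"},
    json=payload,
)
print(response.status_code, response.json())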

View File

@@ -0,0 +1,3 @@
{
  "workflow_name": "my_pipeline"
}
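The matching trigger_dag call, under the same route assumption as the sketch above; run_id is optional, as the endpoint implementation below shows.

import requests

response = requests.post(
    "http://localhost:8080/rest_api/api",
    params={"api": "trigger_dag"},
    json={"workflow_name": "my_pipeline"},  # a "run_id" key may optionally be added
)
print(response.status_code, response.json())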

View File

@@ -10,6 +10,7 @@
 # limitations under the License.

 from typing import Any, Dict, Optional

+# TODO DELETE, STATUS (pick it up from airflow directly), LOG (just link v1), ENABLE DAG, DISABLE DAG (play pause)
 APIS_METADATA = [
     {
         "name": "deploy_dag",
@@ -20,7 +21,7 @@ APIS_METADATA = [
         "post_arguments": [
             {
                 "name": "workflow_config",
-                "description": "Workflow config to deploy",
+                "description": "Workflow config to deploy as IngestionPipeline",
                 "form_input_type": "file",
                 "required": True,
             },
@@ -39,6 +40,19 @@ APIS_METADATA = [
             },
         ],
     },
+    {
+        "name": "test_connection",
+        "description": "Test a connection",
+        "http_method": "POST",
+        "arguments": [],
+        "post_arguments": [
+            {
+                "name": "service_connection",
+                "description": "ServiceConnectionModel config to test",
+                "required": True,
+            },
+        ],
+    },
 ]

View File

@@ -14,17 +14,17 @@ Airflow REST API definition
 import logging
 import traceback
+from typing import Optional

 from airflow import settings
-from airflow.api.common.experimental.trigger_dag import trigger_dag
 from airflow.models import DagBag, DagModel
 from airflow.utils import timezone
 from airflow.www.app import csrf
-from flask import request
+from flask import Response, request
 from flask_admin import expose as admin_expose
 from flask_appbuilder import BaseView as AppBuilderBaseView
 from flask_appbuilder import expose as app_builder_expose
-from openmetadata.airflow.deploy import DagDeployer
 from openmetadata.api.apis_metadata import APIS_METADATA, get_metadata_api
 from openmetadata.api.config import (
     AIRFLOW_VERSION,
@@ -34,8 +34,14 @@ from openmetadata.api.config import (
 )
 from openmetadata.api.response import ApiResponse
 from openmetadata.api.utils import jwt_token_secure
+from openmetadata.operations.deploy import DagDeployer
+from openmetadata.operations.test_connection import test_source_connection
+from openmetadata.operations.trigger import trigger
 from pydantic.error_wrappers import ValidationError

+from metadata.generated.schema.entity.services.connections.serviceConnection import (
+    ServiceConnectionModel,
+)
 from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
     IngestionPipeline,
 )
@@ -53,7 +59,7 @@ class REST_API(AppBuilderBaseView):
         return dagbag

     @staticmethod
-    def get_request_arg(req, arg):
+    def get_request_arg(req, arg) -> Optional[str]:
         return req.args.get(arg) or req.form.get(arg)

     # '/' Endpoint where the Admin page is which allows you to view the APIs available and trigger them
@@ -97,16 +103,16 @@ class REST_API(AppBuilderBaseView):
         # Validate that the API is provided
         if not api:
-            logging.warning("api argument not provided")
+            logging.warning("api argument not provided or empty")
             return ApiResponse.bad_request("API should be provided")

         api = api.strip().lower()
-        logging.info("REST_API.api() called (api: " + str(api) + ")")
+        logging.info(f"REST_API.api() called (api: {api})")

         api_metadata = get_metadata_api(api)
         if api_metadata is None:
-            logging.info("api '" + str(api) + "' not supported")
-            return ApiResponse.bad_request("API '" + str(api) + "' was not found")
+            logging.info(f"api [{api}] not supported")
+            return ApiResponse.bad_request(f"API [{api}] was not found")

         # Deciding which function to use based off the API object that was requested.
         # Some functions are custom and need to be manually routed to.
@@ -114,14 +120,14 @@ class REST_API(AppBuilderBaseView):
             return self.deploy_dag()
         if api == "trigger_dag":
             return self.trigger_dag()
+        if api == "test_connection":
+            return self.test_connection()

-        # TODO DELETE, STATUS (pick it up from airflow directly), LOG (just link v1), ENABLE DAG, DISABLE DAG (play pause)
         raise ValueError(
             f"Invalid api param {api}. Expected deploy_dag or trigger_dag."
         )

-    def deploy_dag(self):
+    def deploy_dag(self) -> Response:
         """Custom Function for the deploy_dag API
         Creates workflow dag based on workflow dag file and refreshes
         the session
@@ -140,42 +146,64 @@ class REST_API(AppBuilderBaseView):
             return response
         except ValidationError as err:
-            msg = f"Request Validation Error parsing payload {json_request} - {err}"
-            return ApiResponse.error(status=ApiResponse.STATUS_BAD_REQUEST, error=msg)
+            return ApiResponse.error(
+                status=ApiResponse.STATUS_BAD_REQUEST,
+                error=f"Request Validation Error parsing payload {json_request}. IngestionPipeline expected - {err}",
+            )
         except Exception as err:
-            msg = f"Internal error deploying {json_request} - {err} - {traceback.format_exc()}"
-            return ApiResponse.error(status=ApiResponse.STATUS_SERVER_ERROR, error=msg)
+            return ApiResponse.error(
+                status=ApiResponse.STATUS_SERVER_ERROR,
+                error=f"Internal error deploying {json_request} - {err} - {traceback.format_exc()}",
+            )

     @staticmethod
-    def trigger_dag():
+    def test_connection() -> Response:
+        """
+        Given a WorkflowSource Schema, create the engine
+        and test the connection
+        """
+        json_request = request.get_json()
+
+        try:
+            service_connection_model = ServiceConnectionModel(**json_request)
+            response = test_source_connection(service_connection_model)
+
+            return response
+
+        except ValidationError as err:
+            return ApiResponse.error(
+                status=ApiResponse.STATUS_BAD_REQUEST,
+                error=f"Request Validation Error parsing payload {json_request}. (Workflow)Source expected - {err}",
+            )
+        except Exception as err:
+            return ApiResponse.error(
+                status=ApiResponse.STATUS_SERVER_ERROR,
+                error=f"Internal error testing connection {json_request} - {err} - {traceback.format_exc()}",
+            )
+
+    @staticmethod
+    def trigger_dag() -> Response:
         """
         Trigger a dag run
         """
         logging.info("Running run_dag method")
+        request_json = request.get_json()
+
+        dag_id = request_json.get("workflow_name")
+        if not dag_id:
+            return ApiResponse.bad_request("workflow_name should be informed")
+
         try:
-            request_json = request.get_json()
-            dag_id = request_json["workflow_name"]
-            run_id = request_json["run_id"] if "run_id" in request_json.keys() else None
-            dag_run = trigger_dag(
-                dag_id=dag_id,
-                run_id=run_id,
-                conf=None,
-                execution_date=timezone.utcnow(),
-            )
-            return ApiResponse.success(
-                {
-                    "message": "Workflow [{}] has been triggered {}".format(
-                        dag_id, dag_run
-                    )
-                }
-            )
-        except Exception as e:
+            run_id = request_json.get("run_id")
+            response = trigger(dag_id, run_id)
+
+            return response
+
+        except Exception as exc:
             logging.info(f"Failed to trigger dag {dag_id}")
             return ApiResponse.error(
-                {
-                    "message": "Workflow {} has filed to trigger due to {}".format(
-                        dag_id, e
-                    )
-                }
+                status=ApiResponse.STATUS_SERVER_ERROR,
+                error=f"Workflow {dag_id} has failed to trigger due to {exc} - {traceback.format_exc()}",
             )

View File

@@ -0,0 +1,47 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing the logic to test a connection
from a ServiceConnection
"""
from flask import Response
from openmetadata.api.response import ApiResponse

from metadata.generated.schema.entity.services.connections.serviceConnection import (
    ServiceConnectionModel,
)
from metadata.utils.engines import (
    SourceConnectionException,
    get_engine,
    test_connection,
)


def test_source_connection(
    service_connection_model: ServiceConnectionModel,
) -> Response:
    """
    Create the engine and test the connection
    :param service_connection_model: ServiceConnection model to test
    :return: ApiResponse with the connection test outcome
    """
    engine = get_engine(service_connection_model.serviceConnection)
    try:
        test_connection(engine)
    except SourceConnectionException as err:
        return ApiResponse.error(
            status=ApiResponse.STATUS_SERVER_ERROR,
            error=f"Connection error from {engine} - {err}",
        )

    return ApiResponse.success({"message": f"Connection with {engine} successful!"})

View File

@@ -0,0 +1,33 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing the logic to trigger a DAG
"""
from typing import Optional

from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagModel
from airflow.utils import timezone
from flask import Response
from openmetadata.api.response import ApiResponse


def trigger(dag_id: str, run_id: Optional[str]) -> Response:
    dag_run = trigger_dag(
        dag_id=dag_id,
        run_id=run_id,
        conf=None,
        execution_date=timezone.utcnow(),
    )
    return ApiResponse.success(
        {"message": f"Workflow [{dag_id}] has been triggered {dag_run}"}
    )

View File

@@ -19,7 +19,9 @@ from collections import namedtuple

 from openmetadata.workflows.ingestion.metadata import build_metadata_dag
 from openmetadata.workflows.ingestion.usage import build_usage_dag

-from metadata.generated.schema.operations.pipelines.airflowPipeline import PipelineType
+from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
+    PipelineType,
+)


 def register():
@@ -43,4 +45,4 @@ def register():

 build_registry = register()

 build_registry.add(PipelineType.metadata.value)(build_metadata_dag)
-build_registry.add(PipelineType.queryUsage.value)(build_usage_dag)
+build_registry.add(PipelineType.usage.value)(build_usage_dag)