From dc4579c5642d4dd1bd5fc84dc066816ab3ba6ba8 Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Wed, 13 Jul 2022 14:43:35 +0200 Subject: [PATCH] Clean dag_id if ingestion pipeline name has weird characters (#6028) --- .../src/openmetadata/api/rest_api.py | 47 +++++++++++-------- .../src/openmetadata/helpers.py | 23 +++++++++ .../src/openmetadata/operations/deploy.py | 25 ++++------ .../workflows/ingestion/common.py | 3 +- .../test_workflow_creation.py | 10 ++++ 5 files changed, 70 insertions(+), 38 deletions(-) create mode 100644 openmetadata-airflow-apis/src/openmetadata/helpers.py diff --git a/openmetadata-airflow-apis/src/openmetadata/api/rest_api.py b/openmetadata-airflow-apis/src/openmetadata/api/rest_api.py index 625fe068412..306d951e1e0 100644 --- a/openmetadata-airflow-apis/src/openmetadata/api/rest_api.py +++ b/openmetadata-airflow-apis/src/openmetadata/api/rest_api.py @@ -32,6 +32,7 @@ from openmetadata.api.config import ( ) from openmetadata.api.response import ApiResponse from openmetadata.api.utils import jwt_token_secure +from openmetadata.helpers import clean_dag_id from openmetadata.operations.delete import delete_dag_id from openmetadata.operations.deploy import DagDeployer from openmetadata.operations.last_dag_logs import last_dag_logs @@ -222,10 +223,12 @@ class REST_API(AppBuilderBaseView): """ request_json = request.get_json() - dag_id = request_json.get("workflow_name") - if not dag_id: + raw_dag_id = request_json.get("workflow_name") + if not raw_dag_id: return ApiResponse.bad_request("workflow_name should be informed") + dag_id = clean_dag_id(raw_dag_id) + try: run_id = request_json.get("run_id") response = trigger(dag_id, run_id) @@ -243,13 +246,12 @@ class REST_API(AppBuilderBaseView): """ Check the status of a DAG runs """ - dag_id: str = self.get_request_arg(request, "dag_id") + raw_dag_id: str = self.get_request_arg(request, "dag_id") - if not dag_id: - return ApiResponse.error( - status=ApiResponse.STATUS_BAD_REQUEST, - error=f"Missing dag_id argument in the request", - ) + if not raw_dag_id: + return ApiResponse.bad_request("Missing dag_id argument in the request") + + dag_id = clean_dag_id(raw_dag_id) try: return status(dag_id) @@ -270,11 +272,13 @@ class REST_API(AppBuilderBaseView): "workflow_name": "my_ingestion_pipeline3" } """ - dag_id: str = self.get_request_arg(request, "dag_id") + raw_dag_id: str = self.get_request_arg(request, "dag_id") - if not dag_id: + if not raw_dag_id: return ApiResponse.bad_request("workflow_name should be informed") + dag_id = clean_dag_id(raw_dag_id) + try: return delete_dag_id(dag_id) @@ -289,13 +293,12 @@ class REST_API(AppBuilderBaseView): """ Retrieve all logs from the task instances of a last DAG run """ - dag_id: str = self.get_request_arg(request, "dag_id") + raw_dag_id: str = self.get_request_arg(request, "dag_id") - if not dag_id: - return ApiResponse.error( - status=ApiResponse.STATUS_BAD_REQUEST, - error=f"Missing dag_id argument in the request", - ) + if not raw_dag_id: + ApiResponse.bad_request("Missing dag_id parameter in the request") + + dag_id = clean_dag_id(raw_dag_id) try: return last_dag_logs(dag_id) @@ -314,10 +317,12 @@ class REST_API(AppBuilderBaseView): """ request_json = request.get_json() - dag_id = request_json.get("dag_id") - if not dag_id: + raw_dag_id = request_json.get("dag_id") + if not raw_dag_id: return ApiResponse.bad_request(f"Missing dag_id argument in the request") + dag_id = clean_dag_id(raw_dag_id) + try: return enable_dag(dag_id) @@ -335,10 +340,12 @@ class REST_API(AppBuilderBaseView): """ request_json = request.get_json() - dag_id = request_json.get("dag_id") - if not dag_id: + raw_dag_id = request_json.get("dag_id") + if not raw_dag_id: return ApiResponse.bad_request(f"Missing dag_id argument in the request") + dag_id = clean_dag_id(raw_dag_id) + try: return disable_dag(dag_id) diff --git a/openmetadata-airflow-apis/src/openmetadata/helpers.py b/openmetadata-airflow-apis/src/openmetadata/helpers.py new file mode 100644 index 00000000000..76cb3325e22 --- /dev/null +++ b/openmetadata-airflow-apis/src/openmetadata/helpers.py @@ -0,0 +1,23 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Helper functions +""" +import re + + +def clean_dag_id(raw_dag_id: str) -> str: + """ + Given a string we want to use as a dag_id, we should + give it a cleanup as Airflow does not support anything + that is not alphanumeric for the name + """ + return re.sub("[^0-9a-zA-Z-_]+", "_", raw_dag_id) diff --git a/openmetadata-airflow-apis/src/openmetadata/operations/deploy.py b/openmetadata-airflow-apis/src/openmetadata/operations/deploy.py index 66564f37927..83792181c4e 100644 --- a/openmetadata-airflow-apis/src/openmetadata/operations/deploy.py +++ b/openmetadata-airflow-apis/src/openmetadata/operations/deploy.py @@ -26,6 +26,7 @@ from openmetadata.api.config import ( ) from openmetadata.api.response import ApiResponse from openmetadata.api.utils import import_path +from openmetadata.helpers import clean_dag_id from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import ( IngestionPipeline, @@ -53,6 +54,7 @@ class DagDeployer: self.ingestion_pipeline = ingestion_pipeline self.dag_bag = dag_bag + self.dag_id = clean_dag_id(self.ingestion_pipeline.name.__root__) def store_airflow_pipeline_config( self, dag_config_file_path: Path @@ -74,9 +76,7 @@ class DagDeployer: the rendered strings """ - dag_py_file = ( - Path(AIRFLOW_DAGS_FOLDER) / f"{self.ingestion_pipeline.name.__root__}.py" - ) + dag_py_file = Path(AIRFLOW_DAGS_FOLDER) / f"{self.dag_id}.py" # Open the template and render with open(DAG_RUNNER_TEMPLATE, "r") as f: @@ -113,28 +113,22 @@ class DagDeployer: logging.info("dagbag size {}".format(self.dag_bag.size())) found_dags = self.dag_bag.process_file(dag_py_file) logging.info("processed dags {}".format(found_dags)) - dag = self.dag_bag.get_dag( - self.ingestion_pipeline.name.__root__, session=session - ) + dag = self.dag_bag.get_dag(self.dag_id, session=session) SerializedDagModel.write_dag(dag) dag.sync_to_db(session=session) dag_model = ( - session.query(DagModel) - .filter(DagModel.dag_id == self.ingestion_pipeline.name.__root__) - .first() + session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first() ) logging.info("dag_model:" + str(dag_model)) return ApiResponse.success( - { - "message": f"Workflow [{self.ingestion_pipeline.name.__root__}] has been created" - } + {"message": f"Workflow [{self.dag_id}] has been created"} ) except Exception as exc: logging.info(f"Failed to serialize the dag {exc}") return ApiResponse.server_error( { - "message": f"Workflow [{self.ingestion_pipeline.name.__root__}] failed to refresh due to [{exc}] " + "message": f"Workflow [{self.dag_id}] failed to refresh due to [{exc}] " + f"- {traceback.format_exc()}" } ) @@ -143,10 +137,7 @@ class DagDeployer: """ Run all methods to deploy the DAG """ - dag_config_file_path = ( - Path(DAG_GENERATED_CONFIGS) - / f"{self.ingestion_pipeline.name.__root__}.json" - ) + dag_config_file_path = Path(DAG_GENERATED_CONFIGS) / f"{self.dag_id}.json" logging.info(f"Config file under {dag_config_file_path}") dag_runner_config = self.store_airflow_pipeline_config(dag_config_file_path) diff --git a/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py b/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py index 2b8f78bb7b2..a2a06f7b524 100644 --- a/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py +++ b/openmetadata-airflow-apis/src/openmetadata/workflows/ingestion/common.py @@ -17,6 +17,7 @@ from typing import Callable, Optional, Union import airflow from airflow import DAG +from openmetadata.helpers import clean_dag_id from metadata.generated.schema.entity.services.dashboardService import DashboardService from metadata.generated.schema.entity.services.databaseService import DatabaseService @@ -170,7 +171,7 @@ def build_dag_configs(ingestion_pipeline: IngestionPipeline) -> dict: :return: dict to use as kwargs """ return { - "dag_id": ingestion_pipeline.name.__root__, + "dag_id": clean_dag_id(ingestion_pipeline.name.__root__), "description": ingestion_pipeline.description, "start_date": ingestion_pipeline.airflowConfig.startDate.__root__ if ingestion_pipeline.airflowConfig.startDate diff --git a/openmetadata-airflow-apis/tests/unit/ingestion_pipeline/test_workflow_creation.py b/openmetadata-airflow-apis/tests/unit/ingestion_pipeline/test_workflow_creation.py index c2f1fb53d2f..751637b31dd 100644 --- a/openmetadata-airflow-apis/tests/unit/ingestion_pipeline/test_workflow_creation.py +++ b/openmetadata-airflow-apis/tests/unit/ingestion_pipeline/test_workflow_creation.py @@ -16,6 +16,7 @@ import json import uuid from unittest import TestCase +from openmetadata.helpers import clean_dag_id from openmetadata.workflows.ingestion.metadata import build_metadata_workflow_config from openmetadata.workflows.ingestion.profiler import build_profiler_workflow_config from openmetadata.workflows.ingestion.usage import build_usage_workflow_config @@ -123,6 +124,15 @@ class OMetaServiceTest(TestCase): hard_delete=True, ) + def test_clean_dag_id(self): + """ + Validate dag_id clean + """ + self.assertEqual(clean_dag_id("hello"), "hello") + self.assertEqual(clean_dag_id("hello(world)"), "hello_world_") + self.assertEqual(clean_dag_id("hello-world"), "hello-world") + self.assertEqual(clean_dag_id("%%&^++hello__"), "_hello__") + def test_ingestion_workflow(self): """ Validate that the ingestionPipeline can be parsed