Clean dag_id if ingestion pipeline name has weird characters (#6028)

This commit is contained in:
Pere Miquel Brull 2022-07-13 14:43:35 +02:00 committed by GitHub
parent 048801067a
commit dc4579c564
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 38 deletions

View File

@ -32,6 +32,7 @@ from openmetadata.api.config import (
) )
from openmetadata.api.response import ApiResponse from openmetadata.api.response import ApiResponse
from openmetadata.api.utils import jwt_token_secure from openmetadata.api.utils import jwt_token_secure
from openmetadata.helpers import clean_dag_id
from openmetadata.operations.delete import delete_dag_id from openmetadata.operations.delete import delete_dag_id
from openmetadata.operations.deploy import DagDeployer from openmetadata.operations.deploy import DagDeployer
from openmetadata.operations.last_dag_logs import last_dag_logs from openmetadata.operations.last_dag_logs import last_dag_logs
@ -222,10 +223,12 @@ class REST_API(AppBuilderBaseView):
""" """
request_json = request.get_json() request_json = request.get_json()
dag_id = request_json.get("workflow_name") raw_dag_id = request_json.get("workflow_name")
if not dag_id: if not raw_dag_id:
return ApiResponse.bad_request("workflow_name should be informed") return ApiResponse.bad_request("workflow_name should be informed")
dag_id = clean_dag_id(raw_dag_id)
try: try:
run_id = request_json.get("run_id") run_id = request_json.get("run_id")
response = trigger(dag_id, run_id) response = trigger(dag_id, run_id)
@ -243,13 +246,12 @@ class REST_API(AppBuilderBaseView):
""" """
Check the status of a DAG runs Check the status of a DAG runs
""" """
dag_id: str = self.get_request_arg(request, "dag_id") raw_dag_id: str = self.get_request_arg(request, "dag_id")
if not dag_id: if not raw_dag_id:
return ApiResponse.error( return ApiResponse.bad_request("Missing dag_id argument in the request")
status=ApiResponse.STATUS_BAD_REQUEST,
error=f"Missing dag_id argument in the request", dag_id = clean_dag_id(raw_dag_id)
)
try: try:
return status(dag_id) return status(dag_id)
@ -270,11 +272,13 @@ class REST_API(AppBuilderBaseView):
"workflow_name": "my_ingestion_pipeline3" "workflow_name": "my_ingestion_pipeline3"
} }
""" """
dag_id: str = self.get_request_arg(request, "dag_id") raw_dag_id: str = self.get_request_arg(request, "dag_id")
if not dag_id: if not raw_dag_id:
return ApiResponse.bad_request("workflow_name should be informed") return ApiResponse.bad_request("workflow_name should be informed")
dag_id = clean_dag_id(raw_dag_id)
try: try:
return delete_dag_id(dag_id) return delete_dag_id(dag_id)
@ -289,13 +293,12 @@ class REST_API(AppBuilderBaseView):
""" """
Retrieve all logs from the task instances of a last DAG run Retrieve all logs from the task instances of a last DAG run
""" """
dag_id: str = self.get_request_arg(request, "dag_id") raw_dag_id: str = self.get_request_arg(request, "dag_id")
if not dag_id: if not raw_dag_id:
return ApiResponse.error( return ApiResponse.bad_request("Missing dag_id parameter in the request")
status=ApiResponse.STATUS_BAD_REQUEST,
error=f"Missing dag_id argument in the request", dag_id = clean_dag_id(raw_dag_id)
)
try: try:
return last_dag_logs(dag_id) return last_dag_logs(dag_id)
@ -314,10 +317,12 @@ class REST_API(AppBuilderBaseView):
""" """
request_json = request.get_json() request_json = request.get_json()
dag_id = request_json.get("dag_id") raw_dag_id = request_json.get("dag_id")
if not dag_id: if not raw_dag_id:
return ApiResponse.bad_request(f"Missing dag_id argument in the request") return ApiResponse.bad_request(f"Missing dag_id argument in the request")
dag_id = clean_dag_id(raw_dag_id)
try: try:
return enable_dag(dag_id) return enable_dag(dag_id)
@ -335,10 +340,12 @@ class REST_API(AppBuilderBaseView):
""" """
request_json = request.get_json() request_json = request.get_json()
dag_id = request_json.get("dag_id") raw_dag_id = request_json.get("dag_id")
if not dag_id: if not raw_dag_id:
return ApiResponse.bad_request(f"Missing dag_id argument in the request") return ApiResponse.bad_request(f"Missing dag_id argument in the request")
dag_id = clean_dag_id(raw_dag_id)
try: try:
return disable_dag(dag_id) return disable_dag(dag_id)

View File

@ -0,0 +1,23 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper functions
"""
import re

# Airflow only supports dag_ids made of alphanumerics, dashes and underscores.
# Compile the cleanup pattern once at import time instead of on every call.
_DAG_ID_INVALID_CHARS = re.compile(r"[^0-9a-zA-Z_-]+")


def clean_dag_id(raw_dag_id: str) -> str:
    """
    Given a string we want to use as a dag_id, clean it up:
    every run of characters that is not alphanumeric, a dash
    or an underscore is collapsed into a single underscore,
    since Airflow rejects any other character in a DAG name.

    :param raw_dag_id: candidate dag_id, e.g. an ingestion pipeline name
    :return: sanitized dag_id safe to register in Airflow
    """
    return _DAG_ID_INVALID_CHARS.sub("_", raw_dag_id)

View File

@ -26,6 +26,7 @@ from openmetadata.api.config import (
) )
from openmetadata.api.response import ApiResponse from openmetadata.api.response import ApiResponse
from openmetadata.api.utils import import_path from openmetadata.api.utils import import_path
from openmetadata.helpers import clean_dag_id
from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import ( from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
IngestionPipeline, IngestionPipeline,
@ -53,6 +54,7 @@ class DagDeployer:
self.ingestion_pipeline = ingestion_pipeline self.ingestion_pipeline = ingestion_pipeline
self.dag_bag = dag_bag self.dag_bag = dag_bag
self.dag_id = clean_dag_id(self.ingestion_pipeline.name.__root__)
def store_airflow_pipeline_config( def store_airflow_pipeline_config(
self, dag_config_file_path: Path self, dag_config_file_path: Path
@ -74,9 +76,7 @@ class DagDeployer:
the rendered strings the rendered strings
""" """
dag_py_file = ( dag_py_file = Path(AIRFLOW_DAGS_FOLDER) / f"{self.dag_id}.py"
Path(AIRFLOW_DAGS_FOLDER) / f"{self.ingestion_pipeline.name.__root__}.py"
)
# Open the template and render # Open the template and render
with open(DAG_RUNNER_TEMPLATE, "r") as f: with open(DAG_RUNNER_TEMPLATE, "r") as f:
@ -113,28 +113,22 @@ class DagDeployer:
logging.info("dagbag size {}".format(self.dag_bag.size())) logging.info("dagbag size {}".format(self.dag_bag.size()))
found_dags = self.dag_bag.process_file(dag_py_file) found_dags = self.dag_bag.process_file(dag_py_file)
logging.info("processed dags {}".format(found_dags)) logging.info("processed dags {}".format(found_dags))
dag = self.dag_bag.get_dag( dag = self.dag_bag.get_dag(self.dag_id, session=session)
self.ingestion_pipeline.name.__root__, session=session
)
SerializedDagModel.write_dag(dag) SerializedDagModel.write_dag(dag)
dag.sync_to_db(session=session) dag.sync_to_db(session=session)
dag_model = ( dag_model = (
session.query(DagModel) session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first()
.filter(DagModel.dag_id == self.ingestion_pipeline.name.__root__)
.first()
) )
logging.info("dag_model:" + str(dag_model)) logging.info("dag_model:" + str(dag_model))
return ApiResponse.success( return ApiResponse.success(
{ {"message": f"Workflow [{self.dag_id}] has been created"}
"message": f"Workflow [{self.ingestion_pipeline.name.__root__}] has been created"
}
) )
except Exception as exc: except Exception as exc:
logging.info(f"Failed to serialize the dag {exc}") logging.info(f"Failed to serialize the dag {exc}")
return ApiResponse.server_error( return ApiResponse.server_error(
{ {
"message": f"Workflow [{self.ingestion_pipeline.name.__root__}] failed to refresh due to [{exc}] " "message": f"Workflow [{self.dag_id}] failed to refresh due to [{exc}] "
+ f"- {traceback.format_exc()}" + f"- {traceback.format_exc()}"
} }
) )
@ -143,10 +137,7 @@ class DagDeployer:
""" """
Run all methods to deploy the DAG Run all methods to deploy the DAG
""" """
dag_config_file_path = ( dag_config_file_path = Path(DAG_GENERATED_CONFIGS) / f"{self.dag_id}.json"
Path(DAG_GENERATED_CONFIGS)
/ f"{self.ingestion_pipeline.name.__root__}.json"
)
logging.info(f"Config file under {dag_config_file_path}") logging.info(f"Config file under {dag_config_file_path}")
dag_runner_config = self.store_airflow_pipeline_config(dag_config_file_path) dag_runner_config = self.store_airflow_pipeline_config(dag_config_file_path)

View File

@ -17,6 +17,7 @@ from typing import Callable, Optional, Union
import airflow import airflow
from airflow import DAG from airflow import DAG
from openmetadata.helpers import clean_dag_id
from metadata.generated.schema.entity.services.dashboardService import DashboardService from metadata.generated.schema.entity.services.dashboardService import DashboardService
from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.services.databaseService import DatabaseService
@ -170,7 +171,7 @@ def build_dag_configs(ingestion_pipeline: IngestionPipeline) -> dict:
:return: dict to use as kwargs :return: dict to use as kwargs
""" """
return { return {
"dag_id": ingestion_pipeline.name.__root__, "dag_id": clean_dag_id(ingestion_pipeline.name.__root__),
"description": ingestion_pipeline.description, "description": ingestion_pipeline.description,
"start_date": ingestion_pipeline.airflowConfig.startDate.__root__ "start_date": ingestion_pipeline.airflowConfig.startDate.__root__
if ingestion_pipeline.airflowConfig.startDate if ingestion_pipeline.airflowConfig.startDate

View File

@ -16,6 +16,7 @@ import json
import uuid import uuid
from unittest import TestCase from unittest import TestCase
from openmetadata.helpers import clean_dag_id
from openmetadata.workflows.ingestion.metadata import build_metadata_workflow_config from openmetadata.workflows.ingestion.metadata import build_metadata_workflow_config
from openmetadata.workflows.ingestion.profiler import build_profiler_workflow_config from openmetadata.workflows.ingestion.profiler import build_profiler_workflow_config
from openmetadata.workflows.ingestion.usage import build_usage_workflow_config from openmetadata.workflows.ingestion.usage import build_usage_workflow_config
@ -123,6 +124,15 @@ class OMetaServiceTest(TestCase):
hard_delete=True, hard_delete=True,
) )
def test_clean_dag_id(self):
    """
    Validate dag_id clean
    """
    expected_by_raw = {
        "hello": "hello",
        "hello(world)": "hello_world_",
        "hello-world": "hello-world",
        "%%&^++hello__": "_hello__",
    }
    for raw_name, expected in expected_by_raw.items():
        self.assertEqual(clean_dag_id(raw_name), expected)
def test_ingestion_workflow(self): def test_ingestion_workflow(self):
""" """
Validate that the ingestionPipeline can be parsed Validate that the ingestionPipeline can be parsed