# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pkgutil
import traceback
from pathlib import Path
from typing import Dict

from airflow import DAG, settings
from airflow.models import DagModel
from jinja2 import Template
from openmetadata_managed_apis.api.config import (
    AIRFLOW_DAGS_FOLDER,
    DAG_GENERATED_CONFIGS,
    PLUGIN_NAME,
)
from openmetadata_managed_apis.api.response import ApiResponse
from openmetadata_managed_apis.api.utils import (
    clean_dag_id,
    get_dagbag,
    import_path,
    scan_dags_job_background,
)
from openmetadata_managed_apis.utils.logger import operations_logger

from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
    IngestionPipeline,
)
from metadata.ingestion.models.encoders import show_secrets_encoder

logger = operations_logger()


class DeployDagException(Exception):
    """
    Error when deploying the DAG
    """


class DagDeployer:
    """
    Helper class to store the DAG config and deploy it to Airflow
    """

    def __init__(self, ingestion_pipeline: IngestionPipeline):
        logger.info(
            f"Received the following Airflow Configuration: {ingestion_pipeline.airflowConfig}"
        )

        self.ingestion_pipeline = ingestion_pipeline
        self.dag_id = clean_dag_id(self.ingestion_pipeline.name.__root__)

    def store_airflow_pipeline_config(
        self, dag_config_file_path: Path
    ) -> Dict[str, str]:
        """
        Store the airflow pipeline config in a JSON file and
        return the Jinja rendering context pointing to it.
        """
        logger.info(f"Saving file to {dag_config_file_path}")
        with open(dag_config_file_path, "w") as outfile:
            outfile.write(self.ingestion_pipeline.json(encoder=show_secrets_encoder))

        return {"workflow_config_file": str(dag_config_file_path)}

    def store_and_validate_dag_file(self, dag_runner_config: Dict[str, str]) -> str:
        """
        Render and store the Python file that generates the DAG,
        validate that it can be imported, and return its path.
        """
        dag_py_file = Path(AIRFLOW_DAGS_FOLDER) / f"{self.dag_id}.py"

        # Open the template and render
        raw_template = pkgutil.get_data(PLUGIN_NAME, "resources/dag_runner.j2").decode()
        template = Template(raw_template)
        rendered_dag = template.render(dag_runner_config)

        # Create the DAGs path if it does not exist
        if not dag_py_file.parent.is_dir():
            dag_py_file.parent.mkdir(parents=True, exist_ok=True)

        with open(dag_py_file, "w") as f:
            f.write(rendered_dag)

        # Validate that the generated file can be imported as a DAG module
        try:
            dag_file = import_path(str(dag_py_file))
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.error(f"Failed to import dag_file [{dag_py_file}]: {exc}")
            raise exc

        if dag_file is None:
            raise DeployDagException(f"Failed to import dag_file [{dag_py_file}]")

        return str(dag_py_file)

    def refresh_session_dag(self, dag_py_file: str):
        """
        Get the stored Python DAG file, update the Airflow DagBag
        and sync it to the DB.

        In Airflow 2.3.3 we also need to trigger the Scheduler job
        to make sure that all the pieces are properly picked up.
""" # Refresh dag into session with settings.Session() as session: try: dag_bag = get_dagbag() logger.info("dagbag size {}".format(dag_bag.size())) found_dags = dag_bag.process_file(dag_py_file) logger.info("processed dags {}".format(found_dags)) dag: DAG = dag_bag.get_dag(self.dag_id, session=session) # Sync to DB dag.sync_to_db(session=session) dag_model = ( session.query(DagModel) .filter(DagModel.dag_id == self.dag_id) .first() ) logger.info("dag_model:" + str(dag_model)) # Scheduler Job to scan dags scan_dags_job_background() return ApiResponse.success( {"message": f"Workflow [{self.dag_id}] has been created"} ) except Exception as exc: msg = f"Workflow [{self.dag_id}] failed to refresh due to [{exc}]" logger.debug(traceback.format_exc()) logger.error(msg) return ApiResponse.server_error({f"message": msg}) def deploy(self): """ Run all methods to deploy the DAG """ dag_config_file_path = Path(DAG_GENERATED_CONFIGS) / f"{self.dag_id}.json" logger.info(f"Config file under {dag_config_file_path}") dag_runner_config = self.store_airflow_pipeline_config(dag_config_file_path) dag_py_file = self.store_and_validate_dag_file(dag_runner_config) response = self.refresh_session_dag(dag_py_file) return response