mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-10-25 07:42:40 +00:00 
			
		
		
		
	 a3bff293a4
			
		
	
	
		a3bff293a4
		
			
		
	
	
	
	
		
			
			* Bump Airflow version * Support Airflow 2.6 * Lint * Bump airflow version * Fix airflow 2.6 lineage * Fix airflow 2.6 lineage
		
			
				
	
	
		
			165 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			165 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #  Copyright 2021 Collate
 | |
| #  Licensed under the Apache License, Version 2.0 (the "License");
 | |
| #  you may not use this file except in compliance with the License.
 | |
| #  You may obtain a copy of the License at
 | |
| #  http://www.apache.org/licenses/LICENSE-2.0
 | |
| #  Unless required by applicable law or agreed to in writing, software
 | |
| #  distributed under the License is distributed on an "AS IS" BASIS,
 | |
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| #  See the License for the specific language governing permissions and
 | |
| #  limitations under the License.
 | |
| 
 | |
| import importlib
 | |
| import os
 | |
| import re
 | |
| import sys
 | |
| import traceback
 | |
| from multiprocessing import Process
 | |
| from typing import Optional
 | |
| 
 | |
| from airflow import settings
 | |
| from airflow.models import DagBag
 | |
| from airflow.version import version as airflow_version
 | |
| from flask import request
 | |
| from openmetadata_managed_apis.utils.logger import api_logger
 | |
| 
 | |
| logger = api_logger()
 | |
| 
 | |
| 
 | |
class MissingArgException(Exception):
    """Raised when the incoming request data fails validation, e.g. a required argument is missing."""
 | |
| 
 | |
| 
 | |
def import_path(path):
    """
    Load the Python source file at `path` as a module and register it.

    The module name is the file's base name with every dash replaced by an
    underscore (any extension stays part of the name). The module is stored
    in `sys.modules` under that name only after its code ran successfully.

    :param path: filesystem path of the source file to load
    :return: the executed module object
    """
    # A bare `import importlib` does not guarantee that the `util` and
    # `machinery` submodules are loaded, so import them explicitly here.
    import importlib.machinery
    import importlib.util

    module_name = os.path.basename(path).replace("-", "_")
    loader = importlib.machinery.SourceFileLoader(module_name, path)
    spec = importlib.util.spec_from_loader(module_name, loader)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    sys.modules[module_name] = module
    return module
 | |
| 
 | |
| 
 | |
def clean_dag_id(raw_dag_id: Optional[str]) -> Optional[str]:
    """
    Normalise a string so it can be used as an Airflow dag_id.

    Each run of characters other than ASCII letters, digits, dashes or
    underscores is collapsed into a single underscore. A missing or empty
    input yields None.
    """
    if not raw_dag_id:
        return None

    return re.sub("[^0-9a-zA-Z-_]+", "_", raw_dag_id)
 | |
| 
 | |
| 
 | |
def get_request_arg(req, arg, raise_missing: bool = True) -> Optional[str]:
    """
    Read `arg` from the flask request `req`.
    E.g., GET api/v1/endpoint?key=value

    The query-string arguments are checked first; the form data is used
    as a fallback when the query string has no (truthy) value.

    When `raise_missing` is set and neither source provides a value,
    a MissingArgException is raised. Otherwise the value found (possibly
    None) is returned.
    """
    value = req.args.get(arg)
    if not value:
        value = req.form.get(arg)

    if raise_missing and not value:
        raise MissingArgException(f"Missing {arg} from request {req} argument")

    return value
 | |
| 
 | |
| 
 | |
def get_arg_dag_id() -> Optional[str]:
    """
    Read the mandatory `dag_id` argument of the current flask request
    and return it cleaned up via `clean_dag_id`.
    """
    return clean_dag_id(get_request_arg(request, "dag_id"))
 | |
| 
 | |
| 
 | |
def get_arg_only_queued() -> Optional[str]:
    """
    Read the optional `only_queued` argument of the current flask request.

    Nothing is raised when the argument is absent.
    """
    only_queued = get_request_arg(request, "only_queued", raise_missing=False)
    return only_queued
 | |
| 
 | |
| 
 | |
def get_request_dag_id() -> Optional[str]:
    """
    Read `dag_id` from the JSON body of the current flask request
    and return it cleaned up via `clean_dag_id`.

    :raises MissingArgException: if the body is not a JSON object or
        carries no (non-empty) `dag_id`
    """
    payload = request.get_json()

    # The parsed body can be None (e.g. JSON `null`) or a non-object such as
    # a list; report that as a missing argument instead of failing with an
    # AttributeError on `.get`.
    raw_dag_id = payload.get("dag_id") if isinstance(payload, dict) else None

    if not raw_dag_id:
        raise MissingArgException("Missing dag_id from request JSON")

    return clean_dag_id(raw_dag_id)
 | |
| 
 | |
| 
 | |
def get_dagbag():
    """
    Build a DagBag for the DAGs folder configured in the Airflow settings,
    with reading from the database enabled, then collect the DAGs from the
    folder and from the database before returning it.
    """
    bag = DagBag(read_dags_from_db=True, dag_folder=settings.DAGS_FOLDER)
    bag.collect_dags()
    bag.collect_dags_from_db()
    return bag
 | |
| 
 | |
| 
 | |
class ScanDagsTask(Process):
    """
    Process that runs one Airflow scheduler job and then kills it.

    The scheduler API to use is picked from the installed Airflow version:
    the `Job` / `SchedulerJobRunner` pair from 2.6 on, `SchedulerJob` before.
    """

    def run(self):
        """
        Run the scheduler job matching the installed Airflow version and
        kill it afterwards. Any exception raised by `kill()` is logged
        and not propagated.
        """
        if self._supports_job_runner(airflow_version):
            scheduler_job = self._run_new_scheduler_job()
        else:
            scheduler_job = self._run_old_scheduler_job()
        try:
            scheduler_job.kill()
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.info(f"Rescan Complete: Killed Job: {exc}")

    @staticmethod
    def _supports_job_runner(version: str) -> bool:
        """
        Tell whether `version` is Airflow 2.6 or later.

        The major and minor components are compared numerically: a plain
        string comparison would rank "2.10.0" below "2.6".

        :param version: Airflow version string, e.g. "2.6.3"
        :return: True if (major, minor) >= (2, 6)
        """
        match = re.match(r"(\d+)(?:\.(\d+))?", version)
        if match is None:
            # Not a dotted numeric version: fall back to plain string
            # comparison, the only check available for such values.
            return version >= "2.6"
        major = int(match.group(1))
        minor = int(match.group(2) or 0)
        return (major, minor) >= (2, 6)

    @staticmethod
    def _run_new_scheduler_job() -> "Job":
        """
        Run the new scheduler job from Airflow 2.6

        :return: the Job that was run
        """
        # Imported lazily: these modules are only used on the 2.6+ code path.
        from airflow.jobs.job import Job, run_job
        from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

        scheduler_job = Job()
        job_runner = SchedulerJobRunner(
            job=scheduler_job,
            num_runs=1,
        )
        scheduler_job.heartrate = 0

        # pylint: disable=protected-access
        run_job(scheduler_job, execute_callable=job_runner._execute)

        return scheduler_job

    @staticmethod
    def _run_old_scheduler_job() -> "SchedulerJob":
        """
        Run the old scheduler job before 2.6

        :return: the SchedulerJob that was run
        """
        # Imported lazily: this module is only used on the pre-2.6 code path.
        from airflow.jobs.scheduler_job import SchedulerJob

        scheduler_job = SchedulerJob(num_times_parse_dags=1)
        scheduler_job.heartrate = 0
        scheduler_job.run()

        return scheduler_job
 | |
| 
 | |
| 
 | |
def scan_dags_job_background():
    """
    Start a ScanDagsTask in a separate process so that the
    scheduler scan does not block the API call
    """
    ScanDagsTask().start()
 |