#  Copyright 2022 Collate
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
Module containing the logic to retrieve all logs from the tasks of the last DAG run
"""
from functools import partial
from io import StringIO
from typing import List, Optional

from airflow.models import DagModel, TaskInstance
from airflow.utils.log.log_reader import TaskLogReader
from flask import Response
from openmetadata_managed_apis.api.response import ApiResponse

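# Metadata dict handed to TaskLogReader.read_log_stream. Setting
# "download_logs" to False tells Airflow's log reader that we are streaming
# the logs rather than preparing a one-shot download.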
LOG_METADATA = {
    "download_logs": False,
}
# Make chunks of 2M characters
CHUNK_SIZE = 2_000_000


def last_dag_logs(dag_id: str, task_id: str, after: Optional[int] = None) -> Response:
    """Validate that the DAG is registered by Airflow and has at least one run.

    If it exists, return all logs for each task instance of the last DAG run.

    Args:
        dag_id (str): DAG to look for
        task_id (str): Task to fetch logs from
        after (int): log stream cursor

    Return:
        Response with log and pagination
    """

    dag_model = DagModel.get_dagmodel(dag_id=dag_id)

    if not dag_model:
        return ApiResponse.not_found(f"DAG {dag_id} not found.")

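    # include_externally_triggered=True also counts runs triggered manually
    # or via the API, not just scheduled ones.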
    last_dag_run = dag_model.get_last_dagrun(include_externally_triggered=True)

    if not last_dag_run:
        return ApiResponse.not_found(f"No DAG run found for {dag_id}.")

    task_instances: List[TaskInstance] = last_dag_run.get_task_instances()

    if not task_instances:
        return ApiResponse.not_found(
            f"Cannot find any task instance for the last DagRun of {dag_id}."
        )

    raw_logs_str = None

    for task_instance in task_instances:
        # Only fetch the logs of the requested task
        if task_instance.task_id == task_id:
            # Use the private _try_number: the public try_number property
            # adds 1 for tasks that are not currently running
            try_number = task_instance._try_number  # pylint: disable=protected-access

            task_log_reader = TaskLogReader()
            if not task_log_reader.supports_read:
                return ApiResponse.server_error(
                    "Task Log Reader does not support reading logs."
                )

            # Even when the task generates a ton of logs, read_log_stream
            # yields a single element. The same happens when calling
            # task_log_reader.read_log_chunks, so we build our own chunks
            # and paginate based on CHUNK_SIZE instead.
            raw_logs_str = "".join(
                task_log_reader.read_log_stream(
                    ti=task_instance,
                    try_number=try_number,
                    metadata=LOG_METADATA,
                )
            )

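    # If more than one instance matched task_id (e.g., dynamically mapped
    # tasks), the logs of the last match win.
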
    if not raw_logs_str:
        return ApiResponse.bad_request(
            f"Can't fetch logs for DAG {dag_id} and Task {task_id}."
        )

    # Split the string into chunks of CHUNK_SIZE without
    # having to know the full length beforehand
    log_chunks = list(iter(partial(StringIO(raw_logs_str).read, CHUNK_SIZE), ""))

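    # The iter(callable, sentinel) call above keeps calling .read(CHUNK_SIZE)
    # until it returns the sentinel "" (EOF). For example, with a chunk size
    # of 4, "abcdefghij" would yield ["abcd", "efgh", "ij"].
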
    total = len(log_chunks)
    after_idx = int(after) if after is not None else 0

    if after_idx >= total:
        return ApiResponse.bad_request(
            f"After index {after} is out of bounds. Total pagination is {total} for DAG {dag_id} and Task {task_id}."
        )

    return ApiResponse.success(
        {
            task_id: log_chunks[after_idx],
            "total": total,
            # Only add the after if there are more pages
            **({"after": after_idx + 1} if after_idx < total - 1 else {}),
        }
    )
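

# Illustrative sketch only: how a caller might page through the response. The
# route and the `client` object below are assumptions, not the actual
# OpenMetadata plugin wiring.
#
#   params = {"dag_id": "my_dag", "task_id": "my_task"}
#   payload = client.get("/last_dag_logs", params=params).json()
#   chunks = [payload["my_task"]]
#   while "after" in payload:
#       payload = client.get(
#           "/last_dag_logs", params={**params, "after": payload["after"]}
#       ).json()
#       chunks.append(payload["my_task"])
#   full_log = "".join(chunks)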