Add exception handling (#7217)

Pere Miquel Brull 2022-09-05 18:50:22 +02:00 committed by GitHub
parent e44c8dacfe
commit e08bd1b1d2


@@ -15,10 +15,12 @@ import traceback
 from datetime import datetime
 from typing import Any, Iterable, List, Optional, cast
 
+import sqlalchemy
 from airflow.models import BaseOperator, DagRun, TaskInstance
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.serialization.serialized_objects import SerializedDAG
 from pydantic import BaseModel
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
 
 from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
@@ -56,8 +58,6 @@ STATUS_MAP = {
     "queued": StatusType.Pending.value,
 }
 
-IGNORE_DAG_RUN_COL = {"log_template_id"}
-
 
 class OMSerializedDagDetails(BaseModel):
     """
@@ -126,7 +126,12 @@ class AirflowSource(PipelineServiceSource):
         dag_run_list = (
             self.session.query(
-                *[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
+                DagRun.dag_id,
+                DagRun.run_id,
+                DagRun.queued_at,
+                DagRun.execution_date,
+                DagRun.start_date,
+                DagRun.state,
             )
             .filter(DagRun.dag_id == dag_id)
             .order_by(DagRun.execution_date.desc())
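
Note: the explicit column list replaces the earlier `DagRun.__table__.c` expansion (and with it the `IGNORE_DAG_RUN_COL` set removed above). Expanding every mapped column breaks against older Airflow metadata databases, since the model mirrors the newest schema: columns such as `log_template_id` (added in Airflow 2.3) do not exist there, and the query fails at the database level. A minimal sketch of the safer pattern, assuming a SQLAlchemy `session` bound to the Airflow metadata database:

```python
# Sketch only: query columns present in every supported Airflow version
# instead of expanding DagRun.__table__.c, which mirrors the installed
# model and may reference columns missing from an older database schema.
from airflow.models import DagRun
from sqlalchemy.orm import Session


def latest_dag_runs(session: Session, dag_id: str, limit: int = 10):
    return (
        session.query(
            DagRun.dag_id,  # stable across Airflow 2.1 / 2.2 / 2.3
            DagRun.run_id,
            DagRun.state,
            DagRun.execution_date,
        )
        .filter(DagRun.dag_id == dag_id)
        .order_by(DagRun.execution_date.desc())
        .limit(limit)
        .all()
    )
```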
@@ -150,7 +155,9 @@ class AirflowSource(PipelineServiceSource):
             for elem in dag_run_dict
         ]
 
-    def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]:
+    def get_task_instances(
+        self, dag_id: str, run_id: str, execution_date: datetime
+    ) -> List[OMTaskInstance]:
         """
         We are building our own scoped TaskInstance
         class to only focus on core properties required
@@ -159,19 +166,31 @@ class AirflowSource(PipelineServiceSource):
         This makes the versioning more flexible on which Airflow
         sources we support.
         """
-        task_instance_list = (
-            self.session.query(
-                TaskInstance.task_id,
-                TaskInstance.state,
-                TaskInstance.start_date,
-                TaskInstance.end_date,
-                TaskInstance.run_id,
-            )
-            .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
-            .all()
-        )
+        task_instance_list = None
+        try:
+            task_instance_list = (
+                self.session.query(
+                    TaskInstance.task_id,
+                    TaskInstance.state,
+                    TaskInstance.start_date,
+                    TaskInstance.end_date,
+                    TaskInstance.run_id,
+                )
+                .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
+                .all()
+            )
+        except Exception as exc:  # pylint: disable=broad-except
+            # Using a broad Exception here as the backend can come in many flavours (pymysql, pyodbc...)
+            # And we don't want to force all imports
+            logger.debug(traceback.format_exc())
+            logger.warning(
+                f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - {exc}."
+            )
 
-        task_instance_dict = [dict(elem) for elem in task_instance_list]
+        task_instance_dict = (
+            [dict(elem) for elem in task_instance_list] if task_instance_list else []
+        )
 
         return [
             OMTaskInstance(
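
The new `execution_date` parameter is flagged by the caller as the Airflow 2.1.4 query fallback: `TaskInstance.run_id` only exists from Airflow 2.2 onwards, while 2.1.x keys task instances by `execution_date`. The fallback query itself is not part of this hunk; a hypothetical sketch of what one can look like (names and shape assumed, not taken from this commit):

```python
# Hypothetical fallback for Airflow 2.1.x, where TaskInstance has an
# execution_date column instead of run_id. Not the commit's code.
from datetime import datetime

from airflow.models import TaskInstance
from sqlalchemy.orm import Session


def task_instances_by_execution_date(
    session: Session, dag_id: str, execution_date: datetime
):
    return (
        session.query(
            TaskInstance.task_id,
            TaskInstance.state,
            TaskInstance.start_date,
            TaskInstance.end_date,
        )
        .filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.execution_date == execution_date,
        )
        .all()
    )
```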
@@ -186,36 +205,46 @@ class AirflowSource(PipelineServiceSource):
     def yield_pipeline_status(
         self, pipeline_details: SerializedDAG
     ) -> OMetaPipelineStatus:
-        dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)
+        try:
+            dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)
 
-        for dag_run in dag_run_list:
-            tasks = self.get_task_instances(
-                dag_id=dag_run.dag_id,
-                run_id=dag_run.run_id,
-            )
+            for dag_run in dag_run_list:
+                tasks = self.get_task_instances(
+                    dag_id=dag_run.dag_id,
+                    run_id=dag_run.run_id,
+                    execution_date=dag_run.execution_date,  # Used for Airflow 2.1.4 query fallback
+                )
 
-            task_statuses = [
-                TaskStatus(
-                    name=task.task_id,
-                    executionStatus=STATUS_MAP.get(
-                        task.state, StatusType.Pending.value
-                    ),
-                    startTime=datetime_to_ts(task.start_date),
-                    endTime=datetime_to_ts(
-                        task.end_date
-                    ),  # Might be None for running tasks
-                )  # Log link might not be present in all Airflow versions
-                for task in tasks
-            ]
+                task_statuses = [
+                    TaskStatus(
+                        name=task.task_id,
+                        executionStatus=STATUS_MAP.get(
+                            task.state, StatusType.Pending.value
+                        ),
+                        startTime=datetime_to_ts(task.start_date),
+                        endTime=datetime_to_ts(
+                            task.end_date
+                        ),  # Might be None for running tasks
+                    )  # Log link might not be present in all Airflow versions
+                    for task in tasks
+                ]
 
-            pipeline_status = PipelineStatus(
-                taskStatus=task_statuses,
-                executionStatus=STATUS_MAP.get(
-                    dag_run.state, StatusType.Pending.value
-                ),
-            )
-            yield OMetaPipelineStatus(
-                pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
-                pipeline_status=pipeline_status,
-                timestamp=dag_run.execution_date.timestamp(),
-            )
+                pipeline_status = PipelineStatus(
+                    taskStatus=task_statuses,
+                    executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value),
+                    timestamp=dag_run.execution_date.timestamp(),
+                )
+                yield OMetaPipelineStatus(
+                    pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
+                    pipeline_status=pipeline_status,
+                )
+        except Exception as exc:
+            logger.debug(traceback.format_exc())
+            logger.warning(
+                f"Wild error trying to extract status from DAG {pipeline_details.dag_id} - {exc}."
+                " Skipping status ingestion."
+            )
 
     def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]:
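
The wrapping try/except applies the same pattern as `get_task_instances`: the full traceback goes to debug level, a readable warning is emitted, and ingestion continues with the next DAG instead of aborting the whole run. A minimal standalone sketch of that pattern (generic names, for illustration only):

```python
# Minimal sketch of the log-and-skip pattern used above: one bad record
# is logged and skipped, and the generator keeps yielding the rest.
import logging
import traceback

logger = logging.getLogger(__name__)


def yield_safely(records, process):
    for record in records:
        try:
            yield process(record)
        except Exception as exc:  # pylint: disable=broad-except
            logger.debug(traceback.format_exc())
            logger.warning(f"Error processing {record} - {exc}. Skipping it.")
```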
@@ -225,9 +254,16 @@ class AirflowSource(PipelineServiceSource):
         We are using the SerializedDagModel as it helps
         us retrieve all the task and inlets/outlets information
         """
+        json_data_column = (
+            SerializedDagModel._data  # For 2.3.0 onwards
+            if hasattr(SerializedDagModel, "_data")
+            else SerializedDagModel.data  # For 2.2.5 and 2.1.4
+        )
+
         for serialized_dag in self.session.query(
             SerializedDagModel.dag_id,
-            SerializedDagModel._data,
+            json_data_column,
             SerializedDagModel.fileloc,
         ).all():
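
Airflow 2.3.0 renamed the serialized-DAG JSON column: the model now maps it as `_data` (exposing `data` as a property), while 2.2.5 and 2.1.4 map `data` directly, so the `hasattr` check selects whichever attribute the installed version provides. Because the selected value can surface as an already-parsed dict or as a raw JSON string depending on version and driver, a small normalization helper (illustrative, not part of the commit) keeps downstream parsing uniform:

```python
# Illustrative normalization for the version-dependent JSON column:
# some Airflow/database combinations return an already-parsed dict,
# others a raw JSON string.
import json
from typing import Any, Dict


def load_serialized_dag(raw: Any) -> Dict[str, Any]:
    if isinstance(raw, dict):
        return raw
    return json.loads(raw)
```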