Add exception handling (#7217)

Pere Miquel Brull 2022-09-05 18:50:22 +02:00 committed by GitHub
parent e44c8dacfe
commit e08bd1b1d2


@@ -15,10 +15,12 @@ import traceback
 from datetime import datetime
 from typing import Any, Iterable, List, Optional, cast
 
+import sqlalchemy
 from airflow.models import BaseOperator, DagRun, TaskInstance
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.serialization.serialized_objects import SerializedDAG
 from pydantic import BaseModel
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
 
 from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
@@ -56,8 +58,6 @@ STATUS_MAP = {
     "queued": StatusType.Pending.value,
 }
 
-IGNORE_DAG_RUN_COL = {"log_template_id"}
-
 
 class OMSerializedDagDetails(BaseModel):
     """
@@ -126,7 +126,12 @@ class AirflowSource(PipelineServiceSource):
         dag_run_list = (
             self.session.query(
-                *[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
+                DagRun.dag_id,
+                DagRun.run_id,
+                DagRun.queued_at,
+                DagRun.execution_date,
+                DagRun.start_date,
+                DagRun.state,
             )
             .filter(DagRun.dag_id == dag_id)
             .order_by(DagRun.execution_date.desc())
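Note: the dynamic `*[c for c in DagRun.__table__.c ...]` expansion is replaced with an explicit column list because the `dag_run` table schema changes across Airflow releases (`log_template_id`, the column the old denylist guarded against, only exists in newer versions). Selecting only columns that every supported release defines removes the need to maintain a denylist at all. A minimal standalone sketch of the same idea, where `latest_runs` and the `session` argument are illustrative rather than part of this module:

    from airflow.models import DagRun

    def latest_runs(session, dag_id: str, limit: int = 10):
        # Query only columns present in every supported Airflow release,
        # instead of expanding DagRun.__table__.c and denylisting newcomers.
        return (
            session.query(
                DagRun.dag_id,
                DagRun.run_id,
                DagRun.execution_date,
                DagRun.state,
            )
            .filter(DagRun.dag_id == dag_id)
            .order_by(DagRun.execution_date.desc())
            .limit(limit)
            .all()
        )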
@@ -150,7 +155,9 @@ class AirflowSource(PipelineServiceSource):
             for elem in dag_run_dict
         ]
 
-    def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]:
+    def get_task_instances(
+        self, dag_id: str, run_id: str, execution_date: datetime
+    ) -> List[OMTaskInstance]:
         """
         We are building our own scoped TaskInstance
         class to only focus on core properties required
@@ -159,19 +166,31 @@ class AirflowSource(PipelineServiceSource):
         This makes the versioning more flexible on which Airflow
         sources we support.
         """
-        task_instance_list = (
-            self.session.query(
-                TaskInstance.task_id,
-                TaskInstance.state,
-                TaskInstance.start_date,
-                TaskInstance.end_date,
-                TaskInstance.run_id,
-            )
-            .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
-            .all()
-        )
+        task_instance_list = None
+
+        try:
+            task_instance_list = (
+                self.session.query(
+                    TaskInstance.task_id,
+                    TaskInstance.state,
+                    TaskInstance.start_date,
+                    TaskInstance.end_date,
+                    TaskInstance.run_id,
+                )
+                .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
+                .all()
+            )
+        except Exception as exc:  # pylint: disable=broad-except
+            # Using a broad Exception here as the backend can come in many flavours (pymysql, pyodbc...)
+            # And we don't want to force all imports
+            logger.debug(traceback.format_exc())
+            logger.warning(
+                f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - {exc}."
+            )
 
-        task_instance_dict = [dict(elem) for elem in task_instance_list]
+        task_instance_dict = (
+            [dict(elem) for elem in task_instance_list] if task_instance_list else []
+        )
 
         return [
             OMTaskInstance(
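The broad `except Exception` above is intentional: `TaskInstance.run_id` only exists as a queryable column from Airflow 2.2 onwards, and the error raised on older schemas depends on the database driver in use (pymysql, pyodbc, psycopg2...), so catching a single type such as `OperationalError` would not cover every backend. A hypothetical standalone sketch of the same degrade-gracefully pattern (`fetch_or_empty` is not part of this module):

    import logging
    import traceback

    logger = logging.getLogger(__name__)

    def fetch_or_empty(query_fn):
        # Run a query callable; on any backend-specific error, log the
        # stack trace at debug level and fall back to an empty result.
        try:
            return query_fn() or []
        except Exception:  # broad on purpose: the DB driver decides the exception type
            logger.debug(traceback.format_exc())
            logger.warning("Query failed, possibly an older Airflow schema. Returning no rows.")
            return []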
@@ -186,36 +205,46 @@ class AirflowSource(PipelineServiceSource):
     def yield_pipeline_status(
         self, pipeline_details: SerializedDAG
     ) -> OMetaPipelineStatus:
-        dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)
-
-        for dag_run in dag_run_list:
-            tasks = self.get_task_instances(
-                dag_id=dag_run.dag_id,
-                run_id=dag_run.run_id,
-            )
-
-            task_statuses = [
-                TaskStatus(
-                    name=task.task_id,
-                    executionStatus=STATUS_MAP.get(
-                        task.state, StatusType.Pending.value
-                    ),
-                    startTime=datetime_to_ts(task.start_date),
-                    endTime=datetime_to_ts(
-                        task.end_date
-                    ),  # Might be None for running tasks
-                )  # Log link might not be present in all Airflow versions
-                for task in tasks
-            ]
-
-            pipeline_status = PipelineStatus(
-                taskStatus=task_statuses,
-                executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value),
-                timestamp=dag_run.execution_date.timestamp(),
-            )
-            yield OMetaPipelineStatus(
-                pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
-                pipeline_status=pipeline_status,
-            )
+        try:
+            dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)
+
+            for dag_run in dag_run_list:
+                tasks = self.get_task_instances(
+                    dag_id=dag_run.dag_id,
+                    run_id=dag_run.run_id,
+                    execution_date=dag_run.execution_date,  # Used for Airflow 2.1.4 query fallback
+                )
+
+                task_statuses = [
+                    TaskStatus(
+                        name=task.task_id,
+                        executionStatus=STATUS_MAP.get(
+                            task.state, StatusType.Pending.value
+                        ),
+                        startTime=datetime_to_ts(task.start_date),
+                        endTime=datetime_to_ts(
+                            task.end_date
+                        ),  # Might be None for running tasks
+                    )  # Log link might not be present in all Airflow versions
+                    for task in tasks
+                ]
+
+                pipeline_status = PipelineStatus(
+                    taskStatus=task_statuses,
+                    executionStatus=STATUS_MAP.get(
+                        dag_run.state, StatusType.Pending.value
+                    ),
+                    timestamp=dag_run.execution_date.timestamp(),
+                )
+                yield OMetaPipelineStatus(
+                    pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
+                    pipeline_status=pipeline_status,
+                )
+        except Exception as exc:
+            logger.debug(traceback.format_exc())
+            logger.warning(
+                f"Wild error trying to extract status from DAG {pipeline_details.dag_id} - {exc}."
+                " Skipping status ingestion."
+            )
 
     def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]:
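Since `yield_pipeline_status` is invoked once per DAG, wrapping the whole generator body means an unexpected failure skips status ingestion for that one DAG while the workflow continues with the next pipeline. A condensed sketch of the pattern, assuming hypothetical `get_runs` and `build_status` helpers and reusing the module's `logger` and `traceback` imports:

    def yield_statuses(dag_id: str):
        # Per-DAG generator: an error here drops this DAG's statuses
        # but does not abort the caller's loop over all DAGs.
        try:
            for run in get_runs(dag_id):         # hypothetical helper
                yield build_status(dag_id, run)  # hypothetical helper
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Wild error extracting status from {dag_id} - {exc}. Skipping.")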
@@ -225,9 +254,16 @@ class AirflowSource(PipelineServiceSource):
         We are using the SerializedDagModel as it helps
         us retrieve all the task and inlets/outlets information
         """
 
+        json_data_column = (
+            SerializedDagModel._data  # For 2.3.0 onwards
+            if hasattr(SerializedDagModel, "_data")
+            else SerializedDagModel.data  # For 2.2.5 and 2.1.4
+        )
+
         for serialized_dag in self.session.query(
             SerializedDagModel.dag_id,
-            SerializedDagModel._data,
+            json_data_column,
             SerializedDagModel.fileloc,
         ).all():
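`SerializedDagModel` stores the serialized DAG JSON under the `_data` column from Airflow 2.3.0 onwards, while 2.2.5 and 2.1.4 expose it as `data`; in 2.3+ `data` still exists as a Python property but not as a queryable column, which is why the check targets `_data` first. Picking the column once with `hasattr` keeps a single query working across all three versions. The same version-tolerant pattern in isolation, where `Model`, its column names, and `session` are illustrative stand-ins:

    # Select whichever column this library version actually defines.
    json_column = Model.new_name if hasattr(Model, "new_name") else Model.old_name
    rows = session.query(Model.id, json_column).all()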