Pere Miquel Brull ad95bdb9c6
Fix #6735 - Improve source validation for inner connections (#6738)
* Add inner connection validations and throw better error

* Use custom classes when retrieving metadata from Airflow to allow more versions
2022-08-16 18:47:50 +02:00


# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Airflow source to extract pipeline metadata from Airflow's metadata DB
"""

import traceback
from datetime import datetime
from typing import Any, Iterable, List, Optional, cast

from airflow.models import BaseOperator, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import SerializedDAG
from pydantic import BaseModel
from sqlalchemy.orm import Session

from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.pipeline import (
PipelineStatus,
StatusType,
Task,
TaskStatus,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.connections.pipeline.airflowConnection import (
AirflowConnection,
)
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.source.pipeline.pipeline_service import PipelineServiceSource
from metadata.utils.connections import create_and_bind_session, test_connection
from metadata.utils.helpers import datetime_to_ts
from metadata.utils.logger import ingestion_logger

logger = ingestion_logger()

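# Map Airflow DagRun / TaskInstance states to OpenMetadata execution statuses;
# any state not listed here falls back to Pending at the call sites.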
STATUS_MAP = {
"success": StatusType.Successful.value,
"failed": StatusType.Failed.value,
"queued": StatusType.Pending.value,
}
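
# DagRun columns to drop from the SELECT: log_template_id only exists in newer
# Airflow versions (added in 2.3), so selecting it explicitly would break older installs.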
IGNORE_DAG_RUN_COL = {"log_template_id"}


class OMSerializedDagDetails(BaseModel):
"""
Custom model we get from the Airflow db
as a scoped SELECT from SerializedDagModel
"""
dag_id: str
data: Any
fileloc: str

    # We don't have a pydantic validator for SerializedDAG objects
    class Config:
arbitrary_types_allowed = True


class OMTaskInstance(BaseModel):
"""
Custom model we get from the Airflow db
as a scoped SELECT from TaskInstance
"""
task_id: str
    state: Optional[str]  # May be None for task instances that have not run yet
start_date: Optional[datetime]
end_date: Optional[datetime]


class AirflowSource(PipelineServiceSource):
    """
    Implements the necessary methods to extract
    Pipeline metadata from Airflow's metadata DB
"""
config: WorkflowSource

    def __init__(
self,
config: WorkflowSource,
metadata_config: OpenMetadataConnection,
):
super().__init__(config, metadata_config)
self._session = None

    @classmethod
def create(cls, config_dict, metadata_config: OpenMetadataConnection):
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
connection: AirflowConnection = config.serviceConnection.__root__.config
if not isinstance(connection, AirflowConnection):
raise InvalidSourceException(
f"Expected AirflowConnection, but got {connection}"
)
return cls(config, metadata_config)

    @property
def session(self) -> Session:
"""
Return the SQLAlchemy session from the engine
"""
if not self._session:
self._session = create_and_bind_session(self.connection)
return self._session

    def get_pipeline_status(self, dag_id: str) -> List[DagRun]:
dag_run_list = (
self.session.query(
*[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
)
.filter(DagRun.dag_id == dag_id)
.order_by(DagRun.execution_date.desc())
.limit(self.config.serviceConnection.__root__.config.numberOfStatus)
.all()
)
dag_run_dict = [dict(elem) for elem in dag_run_list]
# Build DagRun manually to not fall into new/old columns from
# different Airflow versions
return [
DagRun(
dag_id=elem.get("dag_id"),
run_id=elem.get("run_id"),
queued_at=elem.get("queued_at"),
execution_date=elem.get("execution_date"),
start_date=elem.get("start_date"),
state=elem.get("state"),
)
for elem in dag_run_dict
]

    def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]:
"""
We are building our own scoped TaskInstance
class to only focus on core properties required
by the metadata ingestion.
This makes the versioning more flexible on which Airflow
sources we support.
"""
task_instance_list = (
self.session.query(
TaskInstance.task_id,
TaskInstance.state,
TaskInstance.start_date,
TaskInstance.end_date,
TaskInstance.run_id,
)
.filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
.all()
)
task_instance_dict = [dict(elem) for elem in task_instance_list]
return [
OMTaskInstance(
task_id=elem.get("task_id"),
state=elem.get("state"),
start_date=elem.get("start_date"),
end_date=elem.get("end_date"),
)
for elem in task_instance_dict
]

    def yield_pipeline_status(
self, pipeline_details: SerializedDAG
) -> OMetaPipelineStatus:
dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)
for dag_run in dag_run_list:
tasks = self.get_task_instances(
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
)
task_statuses = [
TaskStatus(
name=task.task_id,
executionStatus=STATUS_MAP.get(
task.state, StatusType.Pending.value
),
                    startTime=datetime_to_ts(task.start_date),
                    # end_date might be None for tasks that are still running
                    endTime=datetime_to_ts(task.end_date),
                )  # The log link is not present in all Airflow versions, so it is skipped
for task in tasks
]
pipeline_status = PipelineStatus(
taskStatus=task_statuses,
executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value),
timestamp=dag_run.execution_date.timestamp(),
)
yield OMetaPipelineStatus(
pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
pipeline_status=pipeline_status,
)

    def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]:
"""
List all DAGs from the metadata db.
We are using the SerializedDagModel as it helps
us retrieve all the task and inlets/outlets information
"""
for serialized_dag in self.session.query(
SerializedDagModel.dag_id,
SerializedDagModel._data,
SerializedDagModel.fileloc,
).all():
yield OMSerializedDagDetails(
dag_id=serialized_dag[0],
data=serialized_dag[1],
fileloc=serialized_dag[2],
)

    def get_pipeline_name(self, pipeline_details: SerializedDAG) -> str:
"""
Get Pipeline Name
"""
return pipeline_details.dag_id

    @staticmethod
def get_tasks_from_dag(dag: SerializedDAG) -> List[Task]:
"""
Obtain the tasks from a SerializedDAG
:param dag: SerializedDAG
:return: List of tasks
"""
return [
Task(
name=task.task_id,
description=task.doc_md,
# Just the suffix
taskUrl=f"/taskinstance/list/?flt1_dag_id_equals={dag.dag_id}&_flt_3_task_id={task.task_id}",
downstreamTasks=list(task.downstream_task_ids),
taskType=task.task_type,
startDate=task.start_date.isoformat() if task.start_date else None,
endDate=task.end_date.isoformat() if task.end_date else None,
)
for task in cast(Iterable[BaseOperator], dag.tasks)
]

    @staticmethod
def _build_dag(data: Any) -> SerializedDAG:
"""
Use the queried data to fetch the DAG
:param data: from SQA query
:return: SerializedDAG
"""
if isinstance(data, dict):
return SerializedDAG.from_dict(data)
return SerializedDAG.from_json(data)

    def yield_pipeline(
self, pipeline_details: OMSerializedDagDetails
) -> Iterable[CreatePipelineRequest]:
"""
Convert a DAG into a Pipeline Entity
:param pipeline_details: SerializedDAG from airflow metadata DB
:return: Create Pipeline request with tasks
"""
dag: SerializedDAG = self._build_dag(pipeline_details.data)
yield CreatePipelineRequest(
name=pipeline_details.dag_id,
description=dag.description,
pipelineUrl=f"/tree?dag_id={dag.dag_id}", # Just the suffix
concurrency=dag.concurrency,
pipelineLocation=pipeline_details.fileloc,
startDate=dag.start_date.isoformat() if dag.start_date else None,
tasks=self.get_tasks_from_dag(dag),
service=EntityReference(
id=self.context.pipeline_service.id.__root__, type="pipelineService"
),
)

    @staticmethod
def parse_xlets(xlet: List[Any]) -> Optional[List[str]]:
"""
Parse airflow xlets for 2.1.4. E.g.,
[{'__var': {'tables': ['sample_data.ecommerce_db.shopify.fact_order']},
'__type': 'dict'}]
:param xlet: airflow v2 xlet dict
:return: table FQN list or None
"""
if len(xlet) and isinstance(xlet[0], dict):
tables = xlet[0].get("__var").get("tables")
if tables and isinstance(tables, list):
return tables
return None

    def get_inlets(self, task: BaseOperator) -> Optional[List[str]]:
"""
Get inlets from serialised operator
:param task: SerializedBaseOperator
:return: maybe an inlet list
"""
inlets = task.get_inlet_defs()
try:
return self.parse_xlets(inlets)
except Exception as err:
logger.debug(traceback.format_exc())
            logger.warning(f"Error trying to parse inlets - {err}")
return None

    def get_outlets(self, task: BaseOperator) -> Optional[List[str]]:
        """
        Get outlets from serialised operator
        :param task: SerializedBaseOperator
        :return: maybe an outlet list
"""
outlets = task.get_outlet_defs()
try:
return self.parse_xlets(outlets)
except Exception as err:
logger.debug(traceback.format_exc())
            logger.warning(f"Error trying to parse outlets - {err}")
return None

    def yield_pipeline_lineage_details(
self, pipeline_details: OMSerializedDagDetails
) -> Optional[Iterable[AddLineageRequest]]:
"""
Parse xlets and add lineage between Pipelines and Tables
:param pipeline_details: SerializedDAG from airflow metadata DB
:return: Lineage from inlets and outlets
"""
dag: SerializedDAG = self._build_dag(pipeline_details.data)
for task in dag.tasks:
for table_fqn in self.get_inlets(task) or []:
table_entity: Table = self.metadata.get_by_name(
entity=Table, fqn=table_fqn
)
if table_entity:
yield AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(
id=table_entity.id, type="table"
),
toEntity=EntityReference(
id=self.context.pipeline.id.__root__, type="pipeline"
),
)
)
else:
                    logger.warning(
f"Could not find Table [{table_fqn}] from "
f"[{self.context.pipeline.fullyQualifiedName.__root__}] inlets"
)
for table_fqn in self.get_outlets(task) or []:
table_entity: Table = self.metadata.get_by_name(
entity=Table, fqn=table_fqn
)
if table_entity:
yield AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(
id=self.context.pipeline.id.__root__, type="pipeline"
),
toEntity=EntityReference(id=table_entity.id, type="table"),
)
)
else:
                    logger.warning(
f"Could not find Table [{table_fqn}] from "
f"[{self.context.pipeline.fullyQualifiedName.__root__}] outlets"
)

    def close(self):
self.session.close()
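

# Usage sketch (illustrative only, not part of the module): the ingestion
# workflow normally builds this source from a YAML/JSON config, roughly:
#
#   source:
#     type: airflow
#     serviceName: airflow_source
#     serviceConnection:
#       config:
#         type: Airflow
#         hostPort: http://localhost:8080
#         numberOfStatus: 10
#         connection: <backend DB connection, e.g. a MysqlConnection>
#
# which ends up calling AirflowSource.create(config_dict, metadata_config).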