#  Copyright 2021 Collate
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
Airflow source to extract metadata from OM UI
"""
import traceback
from datetime import datetime
from typing import Any, Iterable, List, Optional, cast

from airflow.models import BaseOperator, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import SerializedDAG
from pydantic import BaseModel
from sqlalchemy.orm import Session

from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.pipeline import (
    PipelineStatus,
    StatusType,
    Task,
    TaskStatus,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
    OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.connections.pipeline.airflowConnection import (
    AirflowConnection,
)
from metadata.generated.schema.metadataIngestion.workflow import (
    Source as WorkflowSource,
)
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.source.pipeline.pipeline_service import PipelineServiceSource
from metadata.utils.connections import create_and_bind_session, test_connection
from metadata.utils.helpers import datetime_to_ts
from metadata.utils.logger import ingestion_logger

logger = ingestion_logger()

STATUS_MAP = {
    "success": StatusType.Successful.value,
    "failed": StatusType.Failed.value,
    "queued": StatusType.Pending.value,
}
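
# The log_template_id column only exists on DagRun in newer Airflow releases
# (Airflow 2.3+); skipping it keeps the status query compatible with older versions.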
IGNORE_DAG_RUN_COL = {"log_template_id"}


class OMSerializedDagDetails(BaseModel):
    """
    Custom model we get from the Airflow db
    as a scoped SELECT from SerializedDagModel
    """

    dag_id: str
    data: Any
    fileloc: str

    # We don't have a validator for SerializedDag
    class Config:
        arbitrary_types_allowed = True


class OMTaskInstance(BaseModel):
    """
    Custom model we get from the Airflow db
    as a scoped SELECT from TaskInstance
    """

    task_id: str
    state: str
    start_date: Optional[datetime]
    end_date: Optional[datetime]


class AirflowSource(PipelineServiceSource):
    """
    Implements the necessary methods to extract
    Pipeline metadata from Airflow's metadata db
    """

    config: WorkflowSource

    def __init__(
        self,
        config: WorkflowSource,
        metadata_config: OpenMetadataConnection,
    ):
        super().__init__(config, metadata_config)
        self._session = None

    @classmethod
    def create(cls, config_dict, metadata_config: OpenMetadataConnection):
        config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
        connection: AirflowConnection = config.serviceConnection.__root__.config
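        # Validate the inner connection type up front so a misconfigured
        # service fails fast with a clear error rather than deep in the workflow.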
        if not isinstance(connection, AirflowConnection):
            raise InvalidSourceException(
                f"Expected AirflowConnection, but got {connection}"
            )
        return cls(config, metadata_config)

    @property
    def session(self) -> Session:
        """
        Return the SQLAlchemy session from the engine
        """
        if not self._session:
            self._session = create_and_bind_session(self.connection)

        return self._session

    def get_pipeline_status(self, dag_id: str) -> List[DagRun]:
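        """
        Return the latest DagRuns for the given dag_id, ordered by
        execution date and capped at the configured numberOfStatus.
        """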
        dag_run_list = (
            self.session.query(
                *[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
            )
            .filter(DagRun.dag_id == dag_id)
            .order_by(DagRun.execution_date.desc())
            .limit(self.config.serviceConnection.__root__.config.numberOfStatus)
            .all()
        )

        dag_run_dict = [dict(elem) for elem in dag_run_list]

        # Build DagRun manually to not fall into new/old columns from
        # different Airflow versions
        return [
            DagRun(
                dag_id=elem.get("dag_id"),
                run_id=elem.get("run_id"),
                queued_at=elem.get("queued_at"),
                execution_date=elem.get("execution_date"),
                start_date=elem.get("start_date"),
                state=elem.get("state"),
            )
            for elem in dag_run_dict
        ]

    def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]:
        """
        We are building our own scoped TaskInstance
        class to only focus on core properties required
        by the metadata ingestion.

        This makes the versioning more flexible on which Airflow
        sources we support.
        """
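        # Note: TaskInstance.run_id is only available on recent Airflow
        # versions (2.2+); older releases keyed task instances by execution_date.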
        task_instance_list = (
            self.session.query(
                TaskInstance.task_id,
                TaskInstance.state,
                TaskInstance.start_date,
                TaskInstance.end_date,
                TaskInstance.run_id,
            )
            .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
            .all()
        )

        task_instance_dict = [dict(elem) for elem in task_instance_list]

        return [
            OMTaskInstance(
                task_id=elem.get("task_id"),
                state=elem.get("state"),
                start_date=elem.get("start_date"),
                end_date=elem.get("end_date"),
            )
            for elem in task_instance_dict
        ]

    def yield_pipeline_status(
        self, pipeline_details: OMSerializedDagDetails
    ) -> Iterable[OMetaPipelineStatus]:
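        """
        Fetch the DagRuns for the pipeline and map each run and
        its task instances to an OpenMetadata PipelineStatus.
        """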
        dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)

        for dag_run in dag_run_list:
            tasks = self.get_task_instances(
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
            )

            task_statuses = [
                TaskStatus(
                    name=task.task_id,
                    executionStatus=STATUS_MAP.get(
                        task.state, StatusType.Pending.value
                    ),
                    startTime=datetime_to_ts(task.start_date),
                    endTime=datetime_to_ts(
                        task.end_date
                    ),  # Might be None for running tasks
                )  # Log link might not be present in all Airflow versions
                for task in tasks
            ]

            pipeline_status = PipelineStatus(
                taskStatus=task_statuses,
                executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value),
                timestamp=dag_run.execution_date.timestamp(),
            )
            yield OMetaPipelineStatus(
                pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
                pipeline_status=pipeline_status,
            )

    def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]:
        """
        List all DAGs from the metadata db.

        We are using the SerializedDagModel as it helps
        us retrieve all the task and inlets/outlets information
        """
        for serialized_dag in self.session.query(
            SerializedDagModel.dag_id,
            SerializedDagModel._data,
            SerializedDagModel.fileloc,
        ).all():
            yield OMSerializedDagDetails(
                dag_id=serialized_dag[0],
                data=serialized_dag[1],
                fileloc=serialized_dag[2],
            )

    def get_pipeline_name(self, pipeline_details: OMSerializedDagDetails) -> str:
        """
        Get Pipeline Name
        """
        return pipeline_details.dag_id

    @staticmethod
    def get_tasks_from_dag(dag: SerializedDAG) -> List[Task]:
        """
        Obtain the tasks from a SerializedDAG
        :param dag: SerializedDAG
        :return: List of tasks
        """
        return [
            Task(
                name=task.task_id,
                description=task.doc_md,
                # Just the suffix
                taskUrl=f"/taskinstance/list/?flt1_dag_id_equals={dag.dag_id}&_flt_3_task_id={task.task_id}",
                downstreamTasks=list(task.downstream_task_ids),
                taskType=task.task_type,
                startDate=task.start_date.isoformat() if task.start_date else None,
                endDate=task.end_date.isoformat() if task.end_date else None,
            )
            for task in cast(Iterable[BaseOperator], dag.tasks)
        ]

    @staticmethod
    def _build_dag(data: Any) -> SerializedDAG:
        """
        Use the queried data to fetch the DAG
        :param data: from SQA query
        :return: SerializedDAG
        """
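        # Depending on the Airflow version and database backend, the serialized
        # _data column may come back as an already-parsed dict or as a raw
        # JSON string, so we handle both cases.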
        if isinstance(data, dict):
            return SerializedDAG.from_dict(data)

        return SerializedDAG.from_json(data)

    def yield_pipeline(
        self, pipeline_details: OMSerializedDagDetails
    ) -> Iterable[CreatePipelineRequest]:
        """
        Convert a DAG into a Pipeline Entity
        :param pipeline_details: SerializedDAG from airflow metadata DB
        :return: Create Pipeline request with tasks
        """
        dag: SerializedDAG = self._build_dag(pipeline_details.data)
        yield CreatePipelineRequest(
            name=pipeline_details.dag_id,
            description=dag.description,
            pipelineUrl=f"/tree?dag_id={dag.dag_id}",  # Just the suffix
            concurrency=dag.concurrency,
            pipelineLocation=pipeline_details.fileloc,
            startDate=dag.start_date.isoformat() if dag.start_date else None,
            tasks=self.get_tasks_from_dag(dag),
            service=EntityReference(
                id=self.context.pipeline_service.id.__root__, type="pipelineService"
            ),
        )

    @staticmethod
    def parse_xlets(xlet: List[Any]) -> Optional[List[str]]:
        """
        Parse airflow xlets for 2.1.4. E.g.,

        [{'__var': {'tables': ['sample_data.ecommerce_db.shopify.fact_order']},
          '__type': 'dict'}]

        :param xlet: airflow v2 xlet dict
        :return: table FQN list or None
        """
        if len(xlet) and isinstance(xlet[0], dict):
            tables = xlet[0].get("__var").get("tables")
            if tables and isinstance(tables, list):
                return tables

        return None

    def get_inlets(self, task: BaseOperator) -> Optional[List[str]]:
        """
        Get inlets from serialised operator
        :param task: SerializedBaseOperator
        :return: maybe an inlet list
        """
        inlets = task.get_inlet_defs()
        try:
            return self.parse_xlets(inlets)
        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error trying to parse inlets - {err}")
            return None

    def get_outlets(self, task: BaseOperator) -> Optional[List[str]]:
        """
        Get outlets from serialised operator
        :param task: SerializedBaseOperator
        :return: maybe an outlet list
        """
        outlets = task.get_outlet_defs()
        try:
            return self.parse_xlets(outlets)
        except Exception as err:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error trying to parse outlets - {err}")
            return None

    def yield_pipeline_lineage_details(
        self, pipeline_details: OMSerializedDagDetails
    ) -> Optional[Iterable[AddLineageRequest]]:
        """
        Parse xlets and add lineage between Pipelines and Tables
        :param pipeline_details: SerializedDAG from airflow metadata DB
        :return: Lineage from inlets and outlets
        """
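        # Inlets produce Table -> Pipeline edges, while outlets produce
        # Pipeline -> Table edges.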
        dag: SerializedDAG = self._build_dag(pipeline_details.data)

        for task in dag.tasks:
            for table_fqn in self.get_inlets(task) or []:
                table_entity: Table = self.metadata.get_by_name(
                    entity=Table, fqn=table_fqn
                )
                if table_entity:
                    yield AddLineageRequest(
                        edge=EntitiesEdge(
                            fromEntity=EntityReference(
                                id=table_entity.id, type="table"
                            ),
                            toEntity=EntityReference(
                                id=self.context.pipeline.id.__root__, type="pipeline"
                            ),
                        )
                    )
                else:
                    logger.warning(
                        f"Could not find Table [{table_fqn}] from "
                        f"[{self.context.pipeline.fullyQualifiedName.__root__}] inlets"
                    )

            for table_fqn in self.get_outlets(task) or []:
                table_entity: Table = self.metadata.get_by_name(
                    entity=Table, fqn=table_fqn
                )
                if table_entity:
                    yield AddLineageRequest(
                        edge=EntitiesEdge(
                            fromEntity=EntityReference(
                                id=self.context.pipeline.id.__root__, type="pipeline"
                            ),
                            toEntity=EntityReference(id=table_entity.id, type="table"),
                        )
                    )
                else:
                    logger.warning(
                        f"Could not find Table [{table_fqn}] from "
                        f"[{self.context.pipeline.fullyQualifiedName.__root__}] outlets"
                    )

    def close(self):
        self.session.close()