Fix #2735 - Simplify Airflow properties extraction (#2749)

* Simplify lineage properties extraction

* Add network name

* Format
Pere Miquel Brull 2022-02-14 16:53:42 +01:00 committed by GitHub
parent 562d6b39ef
commit 76f4ccd590
4 changed files with 38 additions and 228 deletions
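
In short, the first bullet replaces the serializer-based property extraction with direct attribute access. The old code pushed the DAG and every operator through Airflow's SerializedDAG / SerializedBaseOperator and then filtered the result against the _ALLOWED_FLOW_KEYS / _ALLOWED_TASK_KEYS allow-lists; the new code simply reads the attributes it needs off the live objects. A minimal sketch of the idea (illustrative only; the helper names old_task_name and new_task_name are made up):

from airflow.serialization.serialized_objects import SerializedBaseOperator


def old_task_name(operator) -> str:
    # Before: serialize the operator, then keep only allow-listed keys
    props = dict(SerializedBaseOperator.serialize_operator(operator).items())
    allowed = {key: value for key, value in props.items() if key in {"task_id"}}
    return allowed["task_id"]


def new_task_name(operator) -> str:
    # After: the attribute already lives on the operator, no serializer round-trip needed
    return operator.task_id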


@@ -91,6 +91,7 @@ services:
networks:
local_app_net:
name: ometa_network
ipam:
driver: default
config:


@@ -14,7 +14,7 @@ OpenMetadata Airflow Lineage Backend
"""
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from airflow.configuration import conf
@@ -39,43 +39,14 @@ from metadata.generated.schema.entity.services.pipelineService import (
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.helpers import convert_epoch_to_iso
from metadata.utils.helpers import datetime_to_ts
if TYPE_CHECKING:
from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
_ALLOWED_TASK_KEYS = {
"_downstream_task_ids",
"_inlets",
"_outlets",
"_task_type",
"_task_module",
"depends_on_past",
"email",
"label",
"execution_timeout",
"end_date",
"start_date",
"sla",
"sql",
"task_id",
"trigger_rule",
"wait_for_downstream",
}
_ALLOWED_FLOW_KEYS = {
"_access_control",
"_concurrency",
"_default_view",
"catchup",
"fileloc",
"is_paused_upon_creation",
"start_date",
"tags",
"timezone",
"_task_group", # We can get children information from here
}
_STATUS_MAP = {
"running": StatusType.Pending,
@@ -99,24 +70,6 @@ def is_airflow_version_1() -> bool:
return True
def get_properties(
obj: Union["DAG", "BaseOperator"], serializer: Callable, allowed_keys: Set[str]
) -> Dict[str, str]:
"""
Given either a DAG or a BaseOperator, obtain its allowed properties
:param obj: DAG or BaseOperator object
:return: properties dict
"""
props: Dict[str, str] = dict(serializer(obj).items())
for key in obj.get_serialized_fields():
if key not in props:
props[key] = getattr(obj, key)
return {key: value for (key, value) in props.items() if key in allowed_keys}
def get_xlets(
operator: "BaseOperator", xlet_mode: str = "_inlets"
) -> Union[Optional[List[str]], Any]:
@@ -144,50 +97,8 @@ def get_xlets(
return None
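
get_xlets itself is untouched: it still reads the declared inlets/outlets off the operator and returns the table names, or None when nothing was declared. A usage sketch, assuming inlets/outlets are passed as a dict with a "tables" key (the convention this backend expects; DAG and table names here are made up):

from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

from airflow_provider_openmetadata.lineage.utils import get_xlets

with DAG("lineage_example", start_date=datetime(2021, 1, 1)) as dag:
    ingest = BashOperator(
        task_id="ingest",
        bash_command="echo ingest",
        inlets={"tables": ["service.database.schema.source_table"]},
        outlets={"tables": ["service.database.schema.target_table"]},
    )

# Fully qualified table names declared on the task, or None if nothing was set
source_tables = get_xlets(ingest, "_inlets")
sink_tables = get_xlets(ingest, "_outlets")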
# pylint: disable=too-many-arguments
def iso_dag_start_date(props: Dict[str, Any]) -> Optional[str]:
"""
Given a properties dict, return the start_date
as an iso string if start_date is informed
:param props: properties dict
:return: iso start_date or None
"""
# DAG start date comes as `float`
if props.get("start_date"):
return convert_epoch_to_iso(int(float(props["start_date"])))
return None
def iso_task_start_end_date(
props: Dict[str, Any]
) -> Tuple[Optional[str], Optional[str]]:
"""
Given the attributes of a Task Instance, return
the task start date and task end date as
ISO format
:param props: task instance attributes
:return: task start and end date
"""
task_start_date = (
convert_epoch_to_iso(int(props["start_date"].timestamp()))
if props.get("start_date")
else None
)
task_end_date = (
convert_epoch_to_iso(int(props["end_date"].timestamp()))
if props.get("end_date")
else None
)
return task_start_date, task_end_date
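
Both helpers become unnecessary: the serialized DAG exposed start_date as an epoch float (hence convert_epoch_to_iso), but the live DAG and TaskInstance objects carry real datetimes, so .isoformat() is enough. A small self-contained comparison of the two paths (the date is made up; note the suffix changes from "Z" to "+00:00"):

from datetime import datetime, timezone

start = datetime(2021, 1, 1, tzinfo=timezone.utc)

# Old path: epoch seconds from the serializer, converted back to an ISO string
epoch = start.timestamp()
old_iso = datetime.utcfromtimestamp(int(epoch)).isoformat() + "Z"
print(old_iso)  # 2021-01-01T00:00:00Z

# New path: the datetime is already there
new_iso = start.isoformat()
print(new_iso)  # 2021-01-01T00:00:00+00:00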
def create_or_update_pipeline( # pylint: disable=too-many-locals
dag_properties: Dict[str, Any],
task_properties: Dict[str, Any],
task_instance: "TaskInstance",
operator: "BaseOperator",
dag: "DAG",
airflow_service_entity: PipelineService,
@@ -201,8 +112,8 @@ def create_or_update_pipeline( # pylint: disable=too-many-locals
- Append the task being processed
- Clean deleted tasks based on the DAG children information
:param dag_properties: attributes of the dag object
:param task_properties: attributes of the task object
:param task_instance: task run being processed
:param dag_run: DAG run being processed
:param operator: task being examined by lineage
:param dag: airflow dag
:param airflow_service_entity: PipelineService
@@ -215,20 +126,25 @@ def create_or_update_pipeline( # pylint: disable=too-many-locals
f"{pipeline_service_url}/taskinstance/list/"
+ f"?flt1_dag_id_equals={dag.dag_id}&_flt_3_task_id={operator.task_id}"
)
dag_start_date = iso_dag_start_date(dag_properties)
task_start_date, task_end_date = iso_task_start_end_date(task_properties)
dag_start_date = dag.start_date.isoformat() if dag.start_date else None
task_start_date = (
task_instance.start_date.isoformat() if task_instance.start_date else None
)
task_end_date = (
task_instance.end_date.isoformat() if task_instance.end_date else None
)
downstream_tasks = []
if task_properties.get("_downstream_task_ids"):
downstream_tasks = task_properties["_downstream_task_ids"]
if getattr(operator, "downstream_task_ids", None):
downstream_tasks = operator.downstream_task_ids
operator.log.info(f"downstream tasks {downstream_tasks}")
task = Task(
name=task_properties["task_id"],
displayName=task_properties.get("label"), # v1.10.15 does not have label
name=operator.task_id,
displayName=operator.task_id,
taskUrl=task_url,
taskType=task_properties["_task_type"],
taskType=operator.task_type,
startDate=task_start_date,
endDate=task_end_date,
downstreamTasks=downstream_tasks,
@@ -268,8 +184,7 @@ def create_or_update_pipeline( # pylint: disable=too-many-locals
# Clean pipeline
try:
operator.log.info("Cleaning pipeline tasks...")
children = dag_properties.get("_task_group").get("children")
dag_tasks = [Task(name=name) for name in children.keys()]
dag_tasks = [Task(name=name) for name in dag.task_group.children.keys()]
updated_pipeline = client.clean_pipeline_tasks(updated_pipeline, dag_tasks)
except Exception as exc: # pylint: disable=broad-except
operator.log.warning(f"Error cleaning pipeline tasks {exc}")
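
The attributes swapped in above (operator.task_type, operator.downstream_task_ids, dag.task_group.children) are plain Airflow 2.x properties. One caveat worth noting: task_group.children holds only the direct children of the root group, so with nested TaskGroups its keys are group ids rather than the fully qualified task ids that end up in task statuses. A small standalone illustration (DAG and task names are made up):

from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.utils.task_group import TaskGroup

with DAG("flat_vs_grouped", start_date=datetime(2021, 1, 1)) as dag:
    load = BashOperator(task_id="load", bash_command="echo load")
    with TaskGroup("transform"):
        step = BashOperator(task_id="step", bash_command="echo step")
    load >> step

print(load.task_type)                        # BashOperator (the operator class name)
print(load.downstream_task_ids)              # {'transform.step'}
print(list(dag.task_group.children.keys()))  # ['load', 'transform']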
@@ -277,34 +192,7 @@ def create_or_update_pipeline( # pylint: disable=too-many-locals
return updated_pipeline
def get_context_properties(
operator: "BaseOperator", dag: "DAG"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Prepare DAG and Task properties based on attributes
and serializers
"""
# Move this import to avoid circular import error when airflow parses the config
# pylint: disable=import-outside-toplevel
from airflow.serialization.serialized_objects import (
SerializedBaseOperator,
SerializedDAG,
)
dag_properties = get_properties(
dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS
)
task_properties = get_properties(
operator, SerializedBaseOperator.serialize_operator, _ALLOWED_TASK_KEYS
)
operator.log.info(f"Task Properties {task_properties}")
operator.log.info(f"DAG properties {dag_properties}")
return dag_properties, task_properties
def get_dag_status(dag_properties: Dict[str, Any], task_status: List[TaskStatus]):
def get_dag_status(all_tasks: List[str], task_status: List[TaskStatus]):
"""
Based on the task information and the total DAG tasks, cook the
DAG status.
@@ -312,20 +200,18 @@ def get_dag_status(dag_properties: Dict[str, Any], task_status: List[TaskStatus]
gets flagged as "running" during the callbacks.
"""
children = dag_properties.get("_task_group").get("children")
if len(children) < len(task_status):
if len(all_tasks) < len(task_status):
raise ValueError(
"We have more status than children:"
+ f"children {children} vs. status {task_status}"
+ f"children {all_tasks} vs. status {task_status}"
)
# We are still processing tasks...
if len(children) > len(task_status):
if len(all_tasks) > len(task_status):
return StatusType.Pending
# Check for any failure if all tasks have been processed
if len(children) == len(task_status) and StatusType.Failed in {
if len(all_tasks) == len(task_status) and StatusType.Failed in {
task.executionStatus for task in task_status
}:
return StatusType.Failed
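
get_dag_status now takes the flat list of task ids instead of the serialized DAG properties; the counting logic itself is unchanged. A usage sketch, assuming TaskStatus and StatusType come from the generated pipeline entity schema as in the imports above (task names are made up):

from metadata.generated.schema.entity.data.pipeline import StatusType, TaskStatus

from airflow_provider_openmetadata.lineage.utils import get_dag_status

all_tasks = ["extract", "load"]

# Only one of the two tasks has reported back: the DAG is still Pending
partial = [TaskStatus(name="extract", executionStatus=StatusType.Pending)]
assert get_dag_status(all_tasks=all_tasks, task_status=partial) == StatusType.Pending

# Every task has reported and one of them failed: the DAG is Failed
complete = [
    TaskStatus(name="extract", executionStatus=StatusType.Pending),
    TaskStatus(name="load", executionStatus=StatusType.Failed),
]
assert get_dag_status(all_tasks=all_tasks, task_status=complete) == StatusType.Failed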
@@ -344,10 +230,11 @@ def add_status(
"""
dag: "DAG" = context["dag"]
dag_properties, task_properties = get_context_properties(operator, dag)
dag_run: "DagRun" = context["dag_run"]
task_instance: "TaskInstance" = context["task_instance"]
# Let this fail if we cannot properly extract & cast the start_date
execution_date = int(dag_properties.get("start_date"))
execution_date = datetime_to_ts(dag_run.execution_date)
operator.log.info(f"Logging pipeline status for execution {execution_date}")
# Check if we already have a pipelineStatus for
@@ -363,14 +250,14 @@ def add_status(
task_status = [
task
for task in pipeline_status[0].taskStatus
if task.name != task_properties["task_id"]
if task.name != task_instance.task_id
]
# Prepare the new task status information based on the tasks already
# visited and the current task
updated_task_status = [
TaskStatus(
name=task_properties["task_id"],
name=task_instance.task_id,
executionStatus=_STATUS_MAP.get(context["task_instance"].state),
),
*task_status,
@@ -379,7 +266,8 @@ def add_status(
updated_status = PipelineStatus(
executionDate=execution_date,
executionStatus=get_dag_status(
dag_properties=dag_properties, task_status=updated_task_status
all_tasks=list(dag.task_group.children.keys()),
task_status=updated_task_status,
),
taskStatus=updated_task_status,
)
@@ -412,7 +300,7 @@ def parse_lineage(
operator.log.info("Parsing Lineage for OpenMetadata")
dag: "DAG" = context["dag"]
dag_properties, task_properties = get_context_properties(operator, dag)
task_instance: "TaskInstance" = context["task_instance"]
try:
@@ -420,8 +308,7 @@ def parse_lineage(
operator, client, config
)
pipeline = create_or_update_pipeline(
dag_properties=dag_properties,
task_properties=task_properties,
task_instance=task_instance,
operator=operator,
dag=dag,
airflow_service_entity=airflow_service_entity,


@@ -12,8 +12,6 @@
from datetime import datetime, timedelta
from typing import List
from pydantic import SecretStr
from metadata.generated.schema.api.services.createDashboardService import (
CreateDashboardServiceRequest,
)
@@ -175,12 +173,6 @@ def get_database_service_or_create_v2(service_json, metadata_config) -> Database
return created_service
def convert_epoch_to_iso(seconds_since_epoch):
dt = datetime.utcfromtimestamp(seconds_since_epoch)
iso_format = dt.isoformat() + "Z"
return iso_format
def datetime_to_ts(date: datetime) -> int:
"""
Convert a given date to a timestamp as an Int
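
With its last caller gone, convert_epoch_to_iso is dropped; the lineage backend now leans on the pre-existing datetime_to_ts shown above. A rough sketch of the conversion it performs, under the assumption that the returned integer is epoch seconds (if the helper actually returns milliseconds, only the scale differs):

from datetime import datetime, timezone

execution_date = datetime(2022, 2, 14, 15, 53, 42, tzinfo=timezone.utc)

# Equivalent of datetime_to_ts(execution_date) under the epoch-seconds assumption
ts = int(execution_date.timestamp())
print(ts)  # 1644854022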


@@ -19,21 +19,12 @@ from unittest import TestCase
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.serialization.serialized_objects import (
SerializedBaseOperator,
SerializedDAG,
)
from airflow_provider_openmetadata.lineage.openmetadata import (
OpenMetadataLineageBackend,
)
from airflow_provider_openmetadata.lineage.utils import (
_ALLOWED_FLOW_KEYS,
_ALLOWED_TASK_KEYS,
get_properties,
get_xlets,
iso_dag_start_date,
iso_task_start_end_date,
)
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createTable import CreateTableRequest
@@ -41,7 +32,7 @@ from metadata.generated.schema.api.services.createDatabaseService import (
CreateDatabaseServiceRequest,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline
from metadata.generated.schema.entity.data.table import Column, DataType, Table
from metadata.generated.schema.entity.data.table import Column, DataType
from metadata.generated.schema.entity.services.databaseService import (
DatabaseConnection,
DatabaseServiceType,
@@ -143,67 +134,6 @@ class AirflowLineageTest(TestCase):
self.assertIsNone(get_xlets(self.dag.get_task("task3"), "_inlets"))
self.assertIsNone(get_xlets(self.dag.get_task("task3"), "_outlets"))
def test_properties(self):
"""
check props
"""
dag_props = get_properties(
self.dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS
)
self.assertTrue(set(dag_props.keys()).issubset(_ALLOWED_FLOW_KEYS))
task1_props = get_properties(
self.dag.get_task("task1"),
SerializedBaseOperator.serialize_operator,
_ALLOWED_TASK_KEYS,
)
self.assertTrue(set(task1_props.keys()).issubset(_ALLOWED_TASK_KEYS))
task2_props = get_properties(
self.dag.get_task("task2"),
SerializedBaseOperator.serialize_operator,
_ALLOWED_TASK_KEYS,
)
self.assertTrue(set(task2_props.keys()).issubset(_ALLOWED_TASK_KEYS))
task3_props = get_properties(
self.dag.get_task("task3"),
SerializedBaseOperator.serialize_operator,
_ALLOWED_TASK_KEYS,
)
self.assertTrue(set(task3_props.keys()).issubset(_ALLOWED_TASK_KEYS))
def test_times(self):
"""
Check the ISO date extraction for DAG and Tasks instances
"""
dag_props = get_properties(
self.dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS
)
dag_date = iso_dag_start_date(dag_props)
self.assertEqual("2021-01-01T00:00:00Z", dag_date)
# Remove the start_time
dag_props.pop("start_date")
dag_none_date = iso_dag_start_date(dag_props)
self.assertIsNone(dag_none_date)
# By default we'll get the start_date for the task,
# so we can check its value, but the end date
# might not come as in this case.
# Check that we can have those values as None
task1_props = get_properties(
self.dag.get_task("task1"),
SerializedBaseOperator.serialize_operator,
_ALLOWED_TASK_KEYS,
)
task_start_date, task_end_date = iso_task_start_end_date(task1_props)
self.assertEqual("2021-01-01T00:00:00Z", task_start_date)
self.assertIsNone(task_end_date)
def test_lineage(self):
"""
Test end to end
@@ -219,11 +149,11 @@ class AirflowLineageTest(TestCase):
)
self.assertIsNotNone(
self.metadata.get_by_name(entity=Pipeline, fqdn="airflow.lineage")
self.metadata.get_by_name(entity=Pipeline, fqdn="local_airflow_3.lineage")
)
lineage = self.metadata.get_lineage_by_name(
entity=Pipeline, fqdn="airflow.lineage"
entity=Pipeline, fqdn="local_airflow_3.lineage"
)
nodes = {node["id"] for node in lineage["nodes"]}