Fix #3437 & #3186 - Airflow lineage Task Group & Tests (#3460)

This commit is contained in:
Pere Miquel Brull 2022-03-21 18:29:49 +01:00 committed by GitHub
parent b865d85d85
commit 548a0ab722
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 10 deletions

View File

@ -151,6 +151,9 @@ def create_or_update_pipeline( # pylint: disable=too-many-locals
)
# Check if the pipeline already exists
operator.log.info(
f"Checking if the pipeline {airflow_service_entity.name}.{dag.dag_id} exists. If not, we will create it."
)
current_pipeline: Pipeline = client.get_by_name(
entity=Pipeline,
fqdn=f"{airflow_service_entity.name}.{dag.dag_id}",
@ -184,8 +187,7 @@ def create_or_update_pipeline( # pylint: disable=too-many-locals
# Clean pipeline
try:
operator.log.info("Cleaning pipeline tasks...")
dag_tasks = [Task(name=name) for name in dag.task_group.children.keys()]
updated_pipeline = client.clean_pipeline_tasks(updated_pipeline, dag_tasks)
updated_pipeline = client.clean_pipeline_tasks(updated_pipeline, dag.task_ids)
except Exception as exc: # pylint: disable=broad-except
operator.log.warning(f"Error cleaning pipeline tasks {exc}")
@ -266,7 +268,7 @@ def add_status(
updated_status = PipelineStatus(
executionDate=execution_date,
executionStatus=get_dag_status(
all_tasks=list(dag.task_group.children.keys()),
all_tasks=dag.task_ids,
task_status=updated_task_status,
),
taskStatus=updated_task_status,

View File

@ -23,6 +23,7 @@ from metadata.generated.schema.entity.data.pipeline import (
Task,
)
from metadata.ingestion.ometa.client import REST
from metadata.utils.constants import DOT
logger = logging.getLogger(__name__)
@ -93,7 +94,7 @@ class OMetaPipelineMixin:
return self.create_or_update(updated_pipeline)
def clean_pipeline_tasks(self, pipeline: Pipeline, tasks: List[Task]) -> Pipeline:
def clean_pipeline_tasks(self, pipeline: Pipeline, task_ids: List[str]) -> Pipeline:
"""
Given a list of tasks, remove from the
Pipeline Entity those that are not received
@ -104,8 +105,6 @@ class OMetaPipelineMixin:
remove the task B from the entity
"""
names = {task.name for task in tasks}
updated_pipeline = CreatePipelineRequest(
name=pipeline.name,
displayName=pipeline.displayName,
@ -115,7 +114,11 @@ class OMetaPipelineMixin:
pipelineLocation=pipeline.pipelineLocation,
startDate=pipeline.startDate,
service=pipeline.service,
tasks=[task for task in pipeline.tasks if task.name in names],
tasks=[
task
for task in pipeline.tasks
if task.name.replace(DOT, ".") in task_ids
],
owner=pipeline.owner,
tags=pipeline.tags,
)

View File

@ -0,0 +1,16 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Define constants useful for the metadata ingestion
"""
DOT = "_DOT_"

View File

@ -18,7 +18,10 @@ from unittest import TestCase
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
from airflow.models import TaskInstance
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator
from airflow.utils.task_group import TaskGroup
from airflow_provider_openmetadata.lineage.openmetadata import (
OpenMetadataLineageBackend,
@ -57,6 +60,8 @@ class AirflowLineageTest(TestCase):
)
service_type = "databaseService"
backend = OpenMetadataLineageBackend()
@classmethod
def setUpClass(cls) -> None:
"""
@ -137,12 +142,19 @@ class AirflowLineageTest(TestCase):
Test end to end
"""
backend = OpenMetadataLineageBackend()
backend.send_lineage(
self.backend.send_lineage(
operator=self.dag.get_task("task1"),
context={
"dag": self.dag,
"task": self.dag.get_task("task1"),
"task_instance": TaskInstance(
task=self.dag.get_task("task1"),
execution_date=datetime.strptime(
"2022-03-15T08:13:45", "%Y-%m-%dT%H:%M:%S"
),
run_id="scheduled__2022-03-15T08:13:45.967068+00:00",
state="running",
),
},
)
@ -156,3 +168,51 @@ class AirflowLineageTest(TestCase):
nodes = {node["id"] for node in lineage["nodes"]}
self.assertIn(str(self.table.id.__root__), nodes)
def test_lineage_task_group(self):
"""
Test end to end for task groups
"""
with DAG(
"task_group_lineage",
description="A lineage test DAG",
schedule_interval=timedelta(days=1),
start_date=datetime(2021, 1, 1),
) as dag:
t0 = DummyOperator(task_id="start")
# Start Task Group definition
with TaskGroup(group_id="group1") as tg1:
t1 = DummyOperator(task_id="task1")
t2 = DummyOperator(task_id="task2")
t1 >> t2
# End Task Group definition
t3 = DummyOperator(task_id="end")
# Set Task Group's (tg1) dependencies
t0 >> tg1 >> t3
self.backend.send_lineage(
operator=dag.get_task("group1.task1"),
context={
"dag": dag,
"task": dag.get_task("group1.task1"),
"task_instance": TaskInstance(
task=dag.get_task("group1.task1"),
execution_date=datetime.strptime(
"2022-03-15T08:13:45", "%Y-%m-%dT%H:%M:%S"
),
run_id="scheduled__2022-03-15T08:13:45.967068+00:00",
state="running",
),
},
)
pipeline = self.metadata.get_by_name(
entity=Pipeline, fqdn="local_airflow_3.task_group_lineage", fields=["tasks"]
)
self.assertIsNotNone(pipeline)
self.assertIn("group1_DOT_task1", {task.name for task in pipeline.tasks})

View File

@ -342,7 +342,7 @@ class OMetaPipelineTest(TestCase):
pipeline = self.metadata.create_or_update(data=create_pipeline)
updated_pipeline = self.metadata.clean_pipeline_tasks(
pipeline=pipeline, tasks=[Task(name="task3"), Task(name="task4")]
pipeline=pipeline, task_ids=["task3", "task4"]
)
assert len(updated_pipeline.tasks) == 2