# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""
|
|
|
|
Test status callback
|
|
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
from unittest import TestCase
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
2024-06-25 07:51:22 +02:00
|
|
|
from _openmetadata_testutils.ometa import int_admin_ometa
|
2023-12-22 15:43:41 +01:00
|
|
|
from airflow_provider_openmetadata.lineage.status import add_status, get_dag_status
|
|
|
|
from metadata.generated.schema.entity.data.pipeline import (
|
|
|
|
Pipeline,
|
|
|
|
StatusType,
|
|
|
|
TaskStatus,
|
|
|
|
)
|
|
|
|
from metadata.generated.schema.entity.services.pipelineService import PipelineService
|
|
|
|
|
|
|
|
from ..integration_base import (
|
|
|
|
generate_name,
|
|
|
|
get_create_entity,
|
|
|
|
get_create_service,
|
|
|
|
get_test_dag,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class MockDagRun(BaseModel):
    """
    Lightweight stand-in for an Airflow DagRun used in the callback context.
    """

    # The only DagRun attribute this test populates for the status callback
    execution_date: datetime
|
|
|
|
|
|
|
|
|
|
|
|
class MockTaskInstance(BaseModel):
    """
    Lightweight stand-in for an Airflow TaskInstance used in the callback context.
    """

    # Identifier of the task the instance belongs to
    task_id: str
    # Airflow-style state string, e.g. "success" as used by the tests
    state: str
    # Run window of the task instance
    start_date: datetime
    end_date: datetime
    # Link to the task logs
    log_url: str
|
|
|
|
|
|
|
|
|
|
|
|
class TestStatusCallback(TestCase):
    """
    Integration checks for the Airflow status callback: aggregating task
    statuses into a DAG status, and writing a task run onto a Pipeline entity.
    """

    # Shared OpenMetadata client and entity names, evaluated once when the
    # class is defined (names come from the integration_base generate_name helper)
    metadata = int_admin_ometa()
    service_name = generate_name()
    pipeline_name = generate_name()

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create the PipelineService and the Pipeline entity the tests write to.
        """
        cls.metadata.create_or_update(
            get_create_service(entity=PipelineService, name=cls.service_name)
        )

        pipeline_request = get_create_entity(
            entity=Pipeline, name=cls.pipeline_name, reference=cls.service_name.root
        )
        cls.pipeline: Pipeline = cls.metadata.create_or_update(pipeline_request)

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Hard-delete the PipelineService, recursively removing what it owns.
        """
        service_entity = cls.metadata.get_by_name(
            entity=PipelineService, fqn=cls.service_name.root
        )
        cls.metadata.delete(
            entity=PipelineService,
            entity_id=str(service_entity.id.root),
            recursive=True,
            hard_delete=True,
        )

    def test_get_dag_status(self):
        """get_dag_status folds per-task statuses into one DAG status"""
        task_names = ["task1", "task2"]

        scenarios = (
            # Not every task has reported yet: the DAG stays pending
            (
                StatusType.Pending,
                [TaskStatus(name="task1", executionStatus=StatusType.Successful)],
            ),
            # A single failed task flags the whole DAG as failed
            (
                StatusType.Failed,
                [
                    TaskStatus(name="task1", executionStatus=StatusType.Successful),
                    TaskStatus(name="task2", executionStatus=StatusType.Failed),
                ],
            ),
            # Every task succeeded: the DAG is successful
            (
                StatusType.Successful,
                [
                    TaskStatus(name="task1", executionStatus=StatusType.Successful),
                    TaskStatus(name="task2", executionStatus=StatusType.Successful),
                ],
            ),
        )

        for expected, reported in scenarios:
            # Hand over a fresh list each time, like the per-case literals did
            self.assertEqual(expected, get_dag_status(list(task_names), reported))

    def test_add_status(self):
        """add_status records the task run on the Pipeline entity"""
        run_time = datetime.now(timezone.utc)
        test_dag = get_test_dag(self.pipeline_name.root)

        # Drive the callback as if it had been fired by the DAG's first task
        first_task = test_dag.tasks[0]

        add_status(
            operator=first_task,
            pipeline=self.pipeline,
            metadata=self.metadata,
            context={
                "dag": test_dag,
                # Only the execution_date of the DagRun is needed
                "dag_run": MockDagRun(execution_date=run_time),
                "task_instance": MockTaskInstance(
                    task_id=first_task.task_id,
                    state="success",
                    start_date=run_time,
                    end_date=run_time,
                    log_url="https://example.com",
                ),
            },
        )

        refreshed: Pipeline = self.metadata.get_by_name(
            entity=Pipeline,
            fqn=self.pipeline.fullyQualifiedName,
            fields=["pipelineStatus"],
        )
        latest_status = refreshed.pipelineStatus

        # Only one task has reported so far, so the DAG as a whole is Pending
        self.assertEqual(StatusType.Pending, latest_status.executionStatus)
        self.assertEqual(
            StatusType.Successful, latest_status.taskStatus[0].executionStatus
        )
|