mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-11-04 04:29:13 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			160 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			160 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#  Copyright 2025 Collate
 | 
						|
#  Licensed under the Collate Community License, Version 1.0 (the "License");
 | 
						|
#  you may not use this file except in compliance with the License.
 | 
						|
#  You may obtain a copy of the License at
 | 
						|
#  https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
 | 
						|
#  Unless required by applicable law or agreed to in writing, software
 | 
						|
#  distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
#  See the License for the specific language governing permissions and
 | 
						|
#  limitations under the License.
 | 
						|
"""
 | 
						|
Test status callback
 | 
						|
"""
 | 
						|
from datetime import datetime, timezone
 | 
						|
from unittest import TestCase
 | 
						|
 | 
						|
from pydantic import BaseModel
 | 
						|
 | 
						|
from _openmetadata_testutils.ometa import int_admin_ometa
 | 
						|
from airflow_provider_openmetadata.lineage.status import add_status, get_dag_status
 | 
						|
from metadata.generated.schema.entity.data.pipeline import (
 | 
						|
    Pipeline,
 | 
						|
    StatusType,
 | 
						|
    TaskStatus,
 | 
						|
)
 | 
						|
from metadata.generated.schema.entity.services.pipelineService import PipelineService
 | 
						|
 | 
						|
from ..integration_base import (
 | 
						|
    generate_name,
 | 
						|
    get_create_entity,
 | 
						|
    get_create_service,
 | 
						|
    get_test_dag,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
class MockDagRun(BaseModel):
    """
    Minimal stand-in for a DagRun: only the execution date is carried
    """

    # The only DagRun attribute this mock exposes
    execution_date: datetime
 | 
						|
 | 
						|
 | 
						|
class MockTaskInstance(BaseModel):
    """
    Minimal stand-in for a TaskInstance with the fields used by the test
    """

    # Identifier of the task this instance belongs to
    task_id: str
    # State string of the task run, e.g. "success"
    state: str
    # Start and end timestamps of the task run
    start_date: datetime
    end_date: datetime
    # URL pointing at the task logs
    log_url: str
 | 
						|
 | 
						|
 | 
						|
class TestStatusCallback(TestCase):
    """
    Integration tests for the status callback helpers
    (`get_dag_status` and `add_status`)
    """

    # Shared client plus the names of the service / pipeline created for the run
    metadata = int_admin_ometa()
    service_name = generate_name()
    pipeline_name = generate_name()

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create the Pipeline Service and the Pipeline Entity used by the tests
        """
        service_request = get_create_service(
            entity=PipelineService,
            name=cls.service_name,
        )
        cls.metadata.create_or_update(service_request)

        pipeline_request = get_create_entity(
            entity=Pipeline,
            name=cls.pipeline_name,
            reference=cls.service_name.root,
        )
        cls.pipeline: Pipeline = cls.metadata.create_or_update(pipeline_request)

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Hard-delete the Pipeline Service recursively
        """
        service = cls.metadata.get_by_name(
            entity=PipelineService, fqn=cls.service_name.root
        )
        service_id = str(service.id.root)

        cls.metadata.delete(
            entity=PipelineService,
            entity_id=service_id,
            recursive=True,
            hard_delete=True,
        )

    def test_get_dag_status(self):
        """Check the DAG status computed from the reported task statuses"""

        # Each scenario: (expected DAG status, reported (task, status) pairs)
        scenarios = (
            # Not every task has reported yet: the DAG is marked as pending
            (
                StatusType.Pending,
                (("task1", StatusType.Successful),),
            ),
            # One task failed: the DAG is flagged as failed
            (
                StatusType.Failed,
                (
                    ("task1", StatusType.Successful),
                    ("task2", StatusType.Failed),
                ),
            ),
            # Every task succeeded: the DAG is marked as successful
            (
                StatusType.Successful,
                (
                    ("task1", StatusType.Successful),
                    ("task2", StatusType.Successful),
                ),
            ),
        )

        for expected_status, reported in scenarios:
            # Build fresh inputs for every scenario
            dag_tasks = ["task1", "task2"]
            reported_status = [
                TaskStatus(name=task_name, executionStatus=task_state)
                for task_name, task_state in reported
            ]
            self.assertEqual(
                expected_status, get_dag_status(dag_tasks, reported_status)
            )

    def test_add_status(self):
        """The status is stored in the Pipeline Entity after the callback"""

        timestamp = datetime.now(timezone.utc)
        test_dag = get_test_dag(self.pipeline_name.root)

        # The first task plays the role of the operator handled by the callback
        first_operator = test_dag.tasks[0]

        # Mocked DagRun: only the execution_date is picked up
        mocked_run = MockDagRun(execution_date=timestamp)

        # Mocked TaskInstance for the operator above
        mocked_instance = MockTaskInstance(
            task_id=first_operator.task_id,
            state="success",
            start_date=timestamp,
            end_date=timestamp,
            log_url="https://example.com",
        )

        callback_context = {
            "dag": test_dag,
            "dag_run": mocked_run,
            "task_instance": mocked_instance,
        }

        add_status(
            operator=first_operator,
            pipeline=self.pipeline,
            metadata=self.metadata,
            context=callback_context,
        )

        refreshed: Pipeline = self.metadata.get_by_name(
            entity=Pipeline,
            fqn=self.pipeline.fullyQualifiedName,
            fields=["pipelineStatus"],
        )
        stored_status = refreshed.pipelineStatus

        # Only one task has reported, so the DAG status is still Pending
        self.assertEqual(StatusType.Pending, stored_status.executionStatus)
        self.assertEqual(
            StatusType.Successful,
            stored_status.taskStatus[0].executionStatus,
        )
 |