mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-10-30 18:17:53 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			160 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			160 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #  Copyright 2021 Collate
 | |
| #  Licensed under the Apache License, Version 2.0 (the "License");
 | |
| #  you may not use this file except in compliance with the License.
 | |
| #  You may obtain a copy of the License at
 | |
| #  http://www.apache.org/licenses/LICENSE-2.0
 | |
| #  Unless required by applicable law or agreed to in writing, software
 | |
| #  distributed under the License is distributed on an "AS IS" BASIS,
 | |
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| #  See the License for the specific language governing permissions and
 | |
| #  limitations under the License.
 | |
| """
 | |
| Test status callback
 | |
| """
 | |
| from datetime import datetime, timezone
 | |
| from unittest import TestCase
 | |
| 
 | |
| from pydantic import BaseModel
 | |
| 
 | |
| from airflow_provider_openmetadata.lineage.status import add_status, get_dag_status
 | |
| from metadata.generated.schema.entity.data.pipeline import (
 | |
|     Pipeline,
 | |
|     StatusType,
 | |
|     TaskStatus,
 | |
| )
 | |
| from metadata.generated.schema.entity.services.pipelineService import PipelineService
 | |
| 
 | |
| from ..integration_base import (
 | |
|     generate_name,
 | |
|     get_create_entity,
 | |
|     get_create_service,
 | |
|     get_test_dag,
 | |
|     int_admin_ometa,
 | |
| )
 | |
| 
 | |
| 
 | |
class MockDagRun(BaseModel):
    """
    Minimal stand-in for an Airflow DagRun.

    The test only supplies the execution date, so that is
    the single field modelled here.
    """

    # Timestamp handed to the status callback as the run's execution date
    execution_date: datetime
 | |
| 
 | |
| 
 | |
class MockTaskInstance(BaseModel):
    """
    Minimal stand-in for an Airflow TaskInstance.

    Holds just the fields that the test fills in before
    placing the instance in the callback context.
    """

    # Identification and state of the task, as plain strings
    task_id: str
    state: str

    # Start and end timestamps of the task
    start_date: datetime
    end_date: datetime

    # Location of the task logs
    log_url: str
 | |
| 
 | |
| 
 | |
class TestStatusCallback(TestCase):
    """
    Integration tests for the status callback helpers:
    `get_dag_status` and `add_status`.
    """

    # Shared client and randomly generated names, created once
    # when the class body is evaluated.
    metadata = int_admin_ometa()
    service_name = generate_name()
    pipeline_name = generate_name()

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create the Pipeline Service and the Pipeline Entity
        that the tests work against.
        """
        service_request = get_create_service(
            entity=PipelineService,
            name=cls.service_name,
        )
        cls.metadata.create_or_update(service_request)

        pipeline_request = get_create_entity(
            entity=Pipeline,
            name=cls.pipeline_name,
            reference=cls.service_name.__root__,
        )
        cls.pipeline: Pipeline = cls.metadata.create_or_update(pipeline_request)

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Hard-delete the Pipeline Service, recursively, so the
        entities created in setUpClass are removed with it.
        """
        service = cls.metadata.get_by_name(
            entity=PipelineService,
            fqn=cls.service_name.__root__,
        )
        service_id = str(service.id.__root__)

        cls.metadata.delete(
            entity=PipelineService,
            entity_id=service_id,
            recursive=True,
            hard_delete=True,
        )

    def test_get_dag_status(self):
        """Validate the DAG status computed from a list of task statuses"""

        task_names = ("task1", "task2")

        def build_status(*states):
            # Pair each reported state with its task, in declaration order
            return [
                TaskStatus(name=task, executionStatus=state)
                for task, state in zip(task_names, states)
            ]

        scenarios = (
            # Not all tasks have reported a status: DAG is pending
            (StatusType.Pending, (StatusType.Successful,)),
            # One task has failed: DAG is failed
            (StatusType.Failed, (StatusType.Successful, StatusType.Failed)),
            # Every task is successful: DAG is successful
            (StatusType.Successful, (StatusType.Successful, StatusType.Successful)),
        )

        for expected, states in scenarios:
            # Pass a fresh list of task names on each call
            self.assertEqual(
                expected,
                get_dag_status(list(task_names), build_status(*states)),
            )

    def test_add_status(self):
        """The callback stores the status in the Pipeline Entity"""

        timestamp = datetime.now(timezone.utc)
        test_dag = get_test_dag(self.pipeline_name.__root__)

        # The first task of the DAG plays the role of the operator
        # being processed by the callback
        first_operator = test_dag.tasks[0]

        # Build the callback context out of the mocked Airflow objects
        context = {
            "dag": test_dag,
            "dag_run": MockDagRun(execution_date=timestamp),
            "task_instance": MockTaskInstance(
                task_id=first_operator.task_id,
                state="success",
                start_date=timestamp,
                end_date=timestamp,
                log_url="https://example.com",
            ),
        }

        add_status(
            operator=first_operator,
            pipeline=self.pipeline,
            metadata=self.metadata,
            context=context,
        )

        refreshed: Pipeline = self.metadata.get_by_name(
            entity=Pipeline,
            fqn=self.pipeline.fullyQualifiedName,
            fields=["pipelineStatus"],
        )
        stored_status = refreshed.pipelineStatus

        # Only one task has reported its status, so the DAG is still Pending
        self.assertEqual(StatusType.Pending, stored_status.executionStatus)
        self.assertEqual(
            StatusType.Successful,
            stored_status.taskStatus[0].executionStatus,
        )
 | 
