mirror of
				https://github.com/open-metadata/OpenMetadata.git
				synced 2025-10-31 02:29:03 +00:00 
			
		
		
		
	
		
			
	
	
		
			160 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			160 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #  Copyright 2021 Collate | ||
|  | #  Licensed under the Apache License, Version 2.0 (the "License"); | ||
|  | #  you may not use this file except in compliance with the License. | ||
|  | #  You may obtain a copy of the License at | ||
|  | #  http://www.apache.org/licenses/LICENSE-2.0 | ||
|  | #  Unless required by applicable law or agreed to in writing, software | ||
|  | #  distributed under the License is distributed on an "AS IS" BASIS, | ||
|  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
|  | #  See the License for the specific language governing permissions and | ||
|  | #  limitations under the License. | ||
|  | """
 | ||
|  | Test status callback | ||
|  | """
 | ||
|  | from datetime import datetime, timezone | ||
|  | from unittest import TestCase | ||
|  | 
 | ||
|  | from pydantic import BaseModel | ||
|  | 
 | ||
|  | from airflow_provider_openmetadata.lineage.status import add_status, get_dag_status | ||
|  | from metadata.generated.schema.entity.data.pipeline import ( | ||
|  |     Pipeline, | ||
|  |     StatusType, | ||
|  |     TaskStatus, | ||
|  | ) | ||
|  | from metadata.generated.schema.entity.services.pipelineService import PipelineService | ||
|  | 
 | ||
|  | from ..integration_base import ( | ||
|  |     generate_name, | ||
|  |     get_create_entity, | ||
|  |     get_create_service, | ||
|  |     get_test_dag, | ||
|  |     int_admin_ometa, | ||
|  | ) | ||
|  | 
 | ||
|  | 
 | ||
class MockDagRun(BaseModel):
    """
    Minimal stand-in for an Airflow DagRun.

    Only ``execution_date`` is modelled; nothing else about the run is
    needed to drive the status callback in these tests.
    """

    # Timestamp attached to the mocked run
    execution_date: datetime
|  | 
 | ||
|  | 
 | ||
class MockTaskInstance(BaseModel):
    """
    Minimal stand-in for an Airflow TaskInstance.

    Carries just the attributes the tests populate when simulating a
    finished task: its id, its state, its timing window and a log link.
    """

    # Identifier of the operator/task this instance belongs to
    task_id: str
    # Airflow state string, e.g. "success"
    state: str
    # Execution window of the task
    start_date: datetime
    end_date: datetime
    # Link to the task logs
    log_url: str
|  | 
 | ||
|  | 
 | ||
class TestStatusCallback(TestCase):
    """
    Integration tests for the Airflow lineage status callback.

    A throwaway Pipeline Service and Pipeline entity are created once for
    the whole class and hard-deleted when the class finishes.
    """

    metadata = int_admin_ometa()
    service_name = generate_name()
    pipeline_name = generate_name()

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create the Pipeline Service and the Pipeline entity used by the tests.
        """
        service_request = get_create_service(
            entity=PipelineService, name=cls.service_name
        )
        cls.metadata.create_or_update(service_request)

        pipeline_request = get_create_entity(
            entity=Pipeline,
            name=cls.pipeline_name,
            reference=cls.service_name.__root__,
        )
        cls.pipeline: Pipeline = cls.metadata.create_or_update(pipeline_request)

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Hard-delete the Pipeline Service created in setUpClass.

        ``recursive=True`` is passed, presumably so that the child Pipeline
        entity is removed together with its service.
        """
        service = cls.metadata.get_by_name(
            entity=PipelineService, fqn=cls.service_name.__root__
        )
        service_id = str(service.id.__root__)

        cls.metadata.delete(
            entity=PipelineService,
            entity_id=service_id,
            recursive=True,
            hard_delete=True,
        )

    def test_get_dag_status(self):
        """Check the logic when passing DAG status"""

        # Each scenario: (expected DAG status, reported (task, status) pairs)
        scenarios = (
            # Only task1 has reported so far: the DAG is still pending
            (
                StatusType.Pending,
                (("task1", StatusType.Successful),),
            ),
            # A single failed task flags the whole DAG as failed
            (
                StatusType.Failed,
                (
                    ("task1", StatusType.Successful),
                    ("task2", StatusType.Failed),
                ),
            ),
            # Every task succeeded: the DAG is marked as successful
            (
                StatusType.Successful,
                (
                    ("task1", StatusType.Successful),
                    ("task2", StatusType.Successful),
                ),
            ),
        )

        for expected, reported in scenarios:
            all_tasks = ["task1", "task2"]
            task_status = [
                TaskStatus(name=task_name, executionStatus=execution)
                for task_name, execution in reported
            ]
            self.assertEqual(expected, get_dag_status(all_tasks, task_status))

    def test_add_status(self):
        """Status gets properly added to the Pipeline Entity"""

        now = datetime.now(timezone.utc)
        dag = get_test_dag(self.pipeline_name.__root__)

        # The first task of the DAG plays the operator whose callback fires
        operator = dag.tasks[0]

        # Fake Airflow context: the DagRun mock only carries execution_date,
        # the TaskInstance mock reports a successful run of `operator`
        context = {
            "dag": dag,
            "dag_run": MockDagRun(execution_date=now),
            "task_instance": MockTaskInstance(
                task_id=operator.task_id,
                state="success",
                start_date=now,
                end_date=now,
                log_url="https://example.com",
            ),
        }

        add_status(
            operator=operator,
            pipeline=self.pipeline,
            metadata=self.metadata,
            context=context,
        )

        updated_pipeline: Pipeline = self.metadata.get_by_name(
            entity=Pipeline,
            fqn=self.pipeline.fullyQualifiedName,
            fields=["pipelineStatus"],
        )
        pipeline_status = updated_pipeline.pipelineStatus

        # Only one task has reported, so the DAG as a whole is still Pending
        self.assertEqual(StatusType.Pending, pipeline_status.executionStatus)
        # ...while the task that did report is recorded as Successful
        self.assertEqual(
            StatusType.Successful,
            pipeline_status.taskStatus[0].executionStatus,
        )