diff --git a/ingestion/src/airflow_provider_openmetadata/lineage/status.py b/ingestion/src/airflow_provider_openmetadata/lineage/status.py index 81a7b96cecc..d235cf95548 100644 --- a/ingestion/src/airflow_provider_openmetadata/lineage/status.py +++ b/ingestion/src/airflow_provider_openmetadata/lineage/status.py @@ -113,7 +113,7 @@ def add_status( ] updated_status = PipelineStatus( - timestamp=datetime_to_ts(execution_date), + timestamp=execution_date, executionStatus=get_dag_status( all_tasks=dag.task_ids, task_status=updated_task_status, diff --git a/ingestion/tests/integration/airflow/test_lineage_runner.py b/ingestion/tests/integration/airflow/test_lineage_runner.py index fbe1ecd547e..6a7d3e5a924 100644 --- a/ingestion/tests/integration/airflow/test_lineage_runner.py +++ b/ingestion/tests/integration/airflow/test_lineage_runner.py @@ -17,6 +17,7 @@ from unittest.mock import patch from airflow import DAG from airflow.operators.bash import BashOperator +from integration.integration_base import int_admin_ometa from airflow_provider_openmetadata.lineage.runner import AirflowLineageRunner from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest @@ -34,19 +35,12 @@ from metadata.generated.schema.entity.services.connections.database.common.basic from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( MysqlConnection, ) -from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( - OpenMetadataConnection, -) from metadata.generated.schema.entity.services.databaseService import ( DatabaseConnection, DatabaseService, DatabaseServiceType, ) from metadata.generated.schema.entity.services.pipelineService import PipelineService -from metadata.generated.schema.security.client.openMetadataJWTClientConfig import ( - OpenMetadataJWTClientConfig, -) -from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.pipeline.airflow.lineage_parser import 
( OMEntity, get_xlets_from_dag, @@ -55,7 +49,6 @@ from metadata.ingestion.source.pipeline.airflow.lineage_parser import ( SLEEP = "sleep 1" PIPELINE_SERVICE_NAME = "test-lineage-runner" DB_SERVICE_NAME = "test-service-lineage-runner" -OM_JWT = "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg" class TestAirflowLineageRuner(TestCase): @@ -63,14 +56,7 @@ class TestAirflowLineageRuner(TestCase): Validate AirflowLineageRunner """ - server_config = OpenMetadataConnection( - hostPort="http://localhost:8585/api", - authProvider="openmetadata", - securityConfig=OpenMetadataJWTClientConfig(jwtToken=OM_JWT), - ) - metadata = OpenMetadata(server_config) - - assert metadata.health_check() + metadata = int_admin_ometa() service = CreateDatabaseServiceRequest( name=DB_SERVICE_NAME, diff --git a/ingestion/tests/integration/airflow/test_status_callback.py b/ingestion/tests/integration/airflow/test_status_callback.py new file mode 100644 index 00000000000..5e36032f16d --- /dev/null +++ b/ingestion/tests/integration/airflow/test_status_callback.py @@ -0,0 +1,159 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test status callback +""" +from datetime import datetime, timezone +from unittest import TestCase + +from pydantic import BaseModel + +from airflow_provider_openmetadata.lineage.status import add_status, get_dag_status +from metadata.generated.schema.entity.data.pipeline import ( + Pipeline, + StatusType, + TaskStatus, +) +from metadata.generated.schema.entity.services.pipelineService import PipelineService + +from ..integration_base import ( + generate_name, + get_create_entity, + get_create_service, + get_test_dag, + int_admin_ometa, +) + + +class MockDagRun(BaseModel): + execution_date: datetime + + +class MockTaskInstance(BaseModel): + task_id: str + state: str + start_date: datetime + end_date: datetime + log_url: str + + +class TestStatusCallback(TestCase): + """ + Test Status Callback + """ + + metadata = int_admin_ometa() + service_name = generate_name() + pipeline_name = generate_name() + + @classmethod + def setUpClass(cls) -> None: + """ + Prepare ingredients: Pipeline Entity + """ + create_service = get_create_service( + entity=PipelineService, name=cls.service_name + ) + cls.metadata.create_or_update(create_service) + + create_pipeline = get_create_entity( + entity=Pipeline, name=cls.pipeline_name, reference=cls.service_name.__root__ + ) + cls.pipeline: Pipeline = cls.metadata.create_or_update(create_pipeline) + + @classmethod + def tearDownClass(cls) -> None: + """ + Clean up + """ + + service_id = str( + cls.metadata.get_by_name( + entity=PipelineService, fqn=cls.service_name.__root__ + ).id.__root__ + ) + + cls.metadata.delete( + 
entity=PipelineService, + entity_id=service_id, + recursive=True, + hard_delete=True, + ) + + def test_get_dag_status(self): + """Check the logic when passing DAG status""" + + # If some tasks have not reported a status yet, the DAG is marked as pending + all_tasks = ["task1", "task2"] + task_status = [TaskStatus(name="task1", executionStatus=StatusType.Successful)] + self.assertEqual(StatusType.Pending, get_dag_status(all_tasks, task_status)) + + # If any task has failed, the DAG is flagged as failed + all_tasks = ["task1", "task2"] + task_status = [ + TaskStatus(name="task1", executionStatus=StatusType.Successful), + TaskStatus(name="task2", executionStatus=StatusType.Failed), + ] + self.assertEqual(StatusType.Failed, get_dag_status(all_tasks, task_status)) + + # If all tasks are successful, DAG is marked as successful + all_tasks = ["task1", "task2"] + task_status = [ + TaskStatus(name="task1", executionStatus=StatusType.Successful), + TaskStatus(name="task2", executionStatus=StatusType.Successful), + ] + self.assertEqual(StatusType.Successful, get_dag_status(all_tasks, task_status)) + + def test_add_status(self): + """Status gets properly added to the Pipeline Entity""" + + now = datetime.now(timezone.utc) + + dag = get_test_dag(self.pipeline_name.__root__) + + # Use the first task as the operator we are processing in the callback + operator = dag.tasks[0] + + # Patching a DagRun since we only pick up the execution_date + dag_run = MockDagRun(execution_date=now) + + # Patching a TaskInstance + task_instance = MockTaskInstance( + task_id=operator.task_id, + state="success", + start_date=now, + end_date=now, + log_url="https://example.com", + ) + + context = {"dag": dag, "dag_run": dag_run, "task_instance": task_instance} + + add_status( + operator=operator, + pipeline=self.pipeline, + metadata=self.metadata, + context=context, + ) + + updated_pipeline: Pipeline = self.metadata.get_by_name( + entity=Pipeline, + fqn=self.pipeline.fullyQualifiedName, + fields=["pipelineStatus"], + ) + + # 
DAG status is Pending since we only have the status of a single task + self.assertEqual( + StatusType.Pending, updated_pipeline.pipelineStatus.executionStatus + ) + self.assertEqual( + StatusType.Successful, + updated_pipeline.pipelineStatus.taskStatus[0].executionStatus, + ) diff --git a/ingestion/tests/integration/integration_base.py b/ingestion/tests/integration/integration_base.py new file mode 100644 index 00000000000..15e861f79cf --- /dev/null +++ b/ingestion/tests/integration/integration_base.py @@ -0,0 +1,149 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +OpenMetadata base class for tests +""" +import uuid +from datetime import datetime +from typing import Any, Optional, Type + +from airflow import DAG +from airflow.operators.bash import BashOperator + +from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest +from metadata.generated.schema.api.services.createPipelineService import ( + CreatePipelineServiceRequest, +) +from metadata.generated.schema.entity.data.pipeline import Pipeline, Task +from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( + AuthProvider, + OpenMetadataConnection, +) +from metadata.generated.schema.entity.services.connections.pipeline.airflowConnection import ( + AirflowConnection, +) +from metadata.generated.schema.entity.services.connections.pipeline.backendConnection import ( + BackendConnection, +) +from metadata.generated.schema.entity.services.pipelineService import ( + PipelineConnection, + PipelineService, + PipelineServiceType, +) +from metadata.generated.schema.security.client.openMetadataJWTClientConfig import ( + OpenMetadataJWTClientConfig, +) +from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName +from metadata.ingestion.models.custom_pydantic import CustomSecretStr +from metadata.ingestion.ometa.ometa_api import C, OpenMetadata, T +from metadata.utils.dispatch import class_register + +OM_JWT = "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg" + + +def 
int_admin_ometa(url: str = "http://localhost:8585/api") -> OpenMetadata: + """Initialize the ometa connection with default admin:admin creds""" + server_config = OpenMetadataConnection( + hostPort=url, + authProvider=AuthProvider.openmetadata, + securityConfig=OpenMetadataJWTClientConfig(jwtToken=CustomSecretStr(OM_JWT)), + ) + metadata = OpenMetadata(server_config) + assert metadata.health_check() + + return metadata + + +def generate_name() -> EntityName: + """Generate a random name for the asset""" + return EntityName(__root__=str(uuid.uuid4())) + + +create_service_registry = class_register() + + +def get_create_service(entity: Type[T], name: Optional[EntityName] = None) -> C: + """Create a vanilla service based on the input type""" + func = create_service_registry.registry.get(entity.__name__) + if not func: + raise ValueError( + f"Create Service for type {entity.__name__} has not yet been implemented. Add it on `integration_base.py`" + ) + + if not name: + name = generate_name() + + return func(name) + + +@create_service_registry.add(PipelineService) +def _(name: EntityName) -> C: + """Prepare a Create service request""" + return CreatePipelineServiceRequest( + name=name, + serviceType=PipelineServiceType.Airflow, + connection=PipelineConnection( + config=AirflowConnection( + hostPort="http://localhost:8080", + connection=BackendConnection(), + ), + ), + ) + + +create_entity_registry = class_register() + + +def get_create_entity( + entity: Type[T], reference: Any, name: Optional[EntityName] = None +) -> C: + """Create a vanilla entity based on the input type""" + func = create_entity_registry.registry.get(entity.__name__) + if not func: + raise ValueError( + f"Create Entity for type {entity.__name__} has not yet been implemented. 
Add it on `integration_base.py`" + ) + + if not name: + name = generate_name() + + return func(reference, name) + + +@create_entity_registry.add(Pipeline) +def _(reference: FullyQualifiedEntityName, name: EntityName) -> C: + return CreatePipelineRequest( + name=name, + service=reference, + tasks=[ + Task(name="task1"), + Task(name="task2", downstreamTasks=["task1"]), + Task(name="task3", downstreamTasks=["task2"]), + Task(name="task4", downstreamTasks=["task2"]), + ], + ) + + +def get_test_dag(name: str) -> DAG: + """Get a DAG with the tasks created in the CreatePipelineRequest""" + with DAG(name, start_date=datetime(2021, 1, 1)) as dag: + + tasks = [ + BashOperator( + task_id=task_id, + bash_command="date", + ) + for task_id in ("task1", "task2", "task3", "task4") + ] + + tasks[0] >> tasks[1] >> [tasks[2], tasks[3]] + + return dag diff --git a/ingestion/tests/integration/ometa/test_ometa_table_api.py b/ingestion/tests/integration/ometa/test_ometa_table_api.py index 34b18e9db8b..802d2336d38 100644 --- a/ingestion/tests/integration/ometa/test_ometa_table_api.py +++ b/ingestion/tests/integration/ometa/test_ometa_table_api.py @@ -55,23 +55,18 @@ from metadata.generated.schema.entity.services.connections.database.common.basic from metadata.generated.schema.entity.services.connections.database.mysqlConnection import ( MysqlConnection, ) -from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( - OpenMetadataConnection, -) from metadata.generated.schema.entity.services.databaseService import ( DatabaseConnection, DatabaseService, DatabaseServiceType, ) from metadata.generated.schema.entity.teams.user import User -from metadata.generated.schema.security.client.openMetadataJWTClientConfig import ( - OpenMetadataJWTClientConfig, -) from metadata.generated.schema.type.basic import FullyQualifiedEntityName, SqlQuery from metadata.generated.schema.type.entityReference import EntityReference from 
metadata.generated.schema.type.usageRequest import UsageRequest from metadata.ingestion.ometa.client import REST -from metadata.ingestion.ometa.ometa_api import OpenMetadata + +from ..integration_base import int_admin_ometa BAD_RESPONSE = { "data": [ @@ -128,16 +123,7 @@ class OMetaTableTest(TestCase): service_entity_id = None - server_config = OpenMetadataConnection( - hostPort="http://localhost:8585/api", - authProvider="openmetadata", - securityConfig=OpenMetadataJWTClientConfig( - jwtToken="eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg" - ), - ) - metadata = OpenMetadata(server_config) - - assert metadata.health_check() + metadata = int_admin_ometa() user: User = metadata.create_or_update( data=CreateUserRequest(name="random-user", email="random@user.com"),