#14320 - Fix Airflow Callback datetime conversion (#14487)

* #14320 - don't cast twice

* #14320 - Fix Airflow Callback datetime conversion

* import
Pere Miquel Brull 2023-12-22 15:43:41 +01:00 committed by GitHub
parent a8e53d2b5f
commit 7e8e4a7e68
5 changed files with 314 additions and 34 deletions

View File

@@ -113,7 +113,7 @@ def add_status(
     ]
     updated_status = PipelineStatus(
-        timestamp=datetime_to_ts(execution_date),
+        timestamp=execution_date,
         executionStatus=get_dag_status(
             all_tasks=dag.task_ids,
             task_status=updated_task_status,
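
The hunk above drops a second datetime conversion: per the commit message ("don't cast twice"), execution_date reaches this point already converted to an epoch-milliseconds timestamp, so running it through datetime_to_ts again breaks. A minimal sketch of the failure mode, assuming the helper simply turns a datetime into epoch milliseconds:

from datetime import datetime, timezone

def datetime_to_ts(dt: datetime) -> int:
    # Assumed helper behavior: datetime -> epoch milliseconds
    return int(dt.timestamp() * 1000)

execution_date = datetime_to_ts(datetime.now(timezone.utc))  # first cast, done upstream
# A second datetime_to_ts(execution_date) would raise AttributeError, since an
# int has no .timestamp() -- hence PipelineStatus now receives the value as-is.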

View File

@@ -17,6 +17,7 @@ from unittest.mock import patch
 from airflow import DAG
 from airflow.operators.bash import BashOperator
+from integration.integration_base import int_admin_ometa
 from airflow_provider_openmetadata.lineage.runner import AirflowLineageRunner
 from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
@@ -34,19 +35,12 @@ from metadata.generated.schema.entity.services.connections.database.common.basic
 from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
     MysqlConnection,
 )
-from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
-    OpenMetadataConnection,
-)
 from metadata.generated.schema.entity.services.databaseService import (
     DatabaseConnection,
     DatabaseService,
     DatabaseServiceType,
 )
 from metadata.generated.schema.entity.services.pipelineService import PipelineService
-from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
-    OpenMetadataJWTClientConfig,
-)
-from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.pipeline.airflow.lineage_parser import (
     OMEntity,
     get_xlets_from_dag,
@@ -55,7 +49,6 @@ from metadata.ingestion.source.pipeline.airflow.lineage_parser import (
 SLEEP = "sleep 1"
 PIPELINE_SERVICE_NAME = "test-lineage-runner"
 DB_SERVICE_NAME = "test-service-lineage-runner"
-OM_JWT = "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
 class TestAirflowLineageRunner(TestCase):
@@ -63,14 +56,7 @@
     Validate AirflowLineageRunner
     """
-    server_config = OpenMetadataConnection(
-        hostPort="http://localhost:8585/api",
-        authProvider="openmetadata",
-        securityConfig=OpenMetadataJWTClientConfig(jwtToken=OM_JWT),
-    )
-    metadata = OpenMetadata(server_config)
-    assert metadata.health_check()
+    metadata = int_admin_ometa()

     service = CreateDatabaseServiceRequest(
         name=DB_SERVICE_NAME,

View File

@@ -0,0 +1,159 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test status callback
"""
from datetime import datetime, timezone
from unittest import TestCase
from pydantic import BaseModel
from airflow_provider_openmetadata.lineage.status import add_status, get_dag_status
from metadata.generated.schema.entity.data.pipeline import (
Pipeline,
StatusType,
TaskStatus,
)
from metadata.generated.schema.entity.services.pipelineService import PipelineService
from ..integration_base import (
generate_name,
get_create_entity,
get_create_service,
get_test_dag,
int_admin_ometa,
)
class MockDagRun(BaseModel):
execution_date: datetime
class MockTaskInstance(BaseModel):
task_id: str
state: str
start_date: datetime
end_date: datetime
log_url: str
class TestStatusCallback(TestCase):
"""
Test Status Callback
"""
metadata = int_admin_ometa()
service_name = generate_name()
pipeline_name = generate_name()
@classmethod
def setUpClass(cls) -> None:
"""
Prepare ingredients: Pipeline Entity
"""
create_service = get_create_service(
entity=PipelineService, name=cls.service_name
)
cls.metadata.create_or_update(create_service)
create_pipeline = get_create_entity(
entity=Pipeline, name=cls.pipeline_name, reference=cls.service_name.__root__
)
cls.pipeline: Pipeline = cls.metadata.create_or_update(create_pipeline)
@classmethod
def tearDownClass(cls) -> None:
"""
Clean up
"""
service_id = str(
cls.metadata.get_by_name(
entity=PipelineService, fqn=cls.service_name.__root__
).id.__root__
)
cls.metadata.delete(
entity=PipelineService,
entity_id=service_id,
recursive=True,
hard_delete=True,
)
def test_get_dag_status(self):
"""Check the logic when passing DAG status"""
        # If some task statuses are still missing, the DAG is marked as Pending
all_tasks = ["task1", "task2"]
task_status = [TaskStatus(name="task1", executionStatus=StatusType.Successful)]
self.assertEqual(StatusType.Pending, get_dag_status(all_tasks, task_status))
        # If any task has failed, the DAG is flagged as Failed
all_tasks = ["task1", "task2"]
task_status = [
TaskStatus(name="task1", executionStatus=StatusType.Successful),
TaskStatus(name="task2", executionStatus=StatusType.Failed),
]
self.assertEqual(StatusType.Failed, get_dag_status(all_tasks, task_status))
# If all tasks are successful, DAG is marked as successful
all_tasks = ["task1", "task2"]
task_status = [
TaskStatus(name="task1", executionStatus=StatusType.Successful),
TaskStatus(name="task2", executionStatus=StatusType.Successful),
]
self.assertEqual(StatusType.Successful, get_dag_status(all_tasks, task_status))
def test_add_status(self):
"""Status gets properly added to the Pipeline Entity"""
now = datetime.now(timezone.utc)
dag = get_test_dag(self.pipeline_name.__root__)
        # Use the first task as the operator we are processing in the callback
operator = dag.tasks[0]
# Patching a DagRun since we only pick up the execution_date
dag_run = MockDagRun(execution_date=now)
# Patching a TaskInstance
task_instance = MockTaskInstance(
task_id=operator.task_id,
state="success",
start_date=now,
end_date=now,
log_url="https://example.com",
)
context = {"dag": dag, "dag_run": dag_run, "task_instance": task_instance}
add_status(
operator=operator,
pipeline=self.pipeline,
metadata=self.metadata,
context=context,
)
updated_pipeline: Pipeline = self.metadata.get_by_name(
entity=Pipeline,
fqn=self.pipeline.fullyQualifiedName,
fields=["pipelineStatus"],
)
# DAG status is Pending since we only have the status of a single task
self.assertEqual(
StatusType.Pending, updated_pipeline.pipelineStatus.executionStatus
)
self.assertEqual(
StatusType.Successful,
updated_pipeline.pipelineStatus.taskStatus[0].executionStatus,
)
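
The MockDagRun/MockTaskInstance models above stand in for the real objects Airflow puts in the callback context ("dag", "dag_run", "task_instance"). For orientation, a hedged sketch of how add_status could be hooked up as a task-level on_success_callback; the pipeline FQN here is hypothetical, and the client comes from the same int_admin_ometa helper the test uses:

from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

from airflow_provider_openmetadata.lineage.status import add_status
from metadata.generated.schema.entity.data.pipeline import Pipeline

from integration.integration_base import int_admin_ometa

metadata = int_admin_ometa()
# Hypothetical fully qualified name; resolve the Pipeline entity once up front
pipeline = metadata.get_by_name(entity=Pipeline, fqn="my-service.my-pipeline")

def on_success(context):
    # Airflow populates the context with the live task instance; its .task
    # attribute is the operator, matching the test's `operator` argument.
    add_status(
        operator=context["task_instance"].task,
        pipeline=pipeline,
        metadata=metadata,
        context=context,
    )

with DAG("status_callback_demo", start_date=datetime(2021, 1, 1)) as dag:
    BashOperator(task_id="task1", bash_command="date", on_success_callback=on_success)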

View File

@@ -0,0 +1,149 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenMetadata base class for tests
"""
import uuid
from datetime import datetime
from typing import Any, Optional, Type
from airflow import DAG
from airflow.operators.bash import BashOperator
from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
from metadata.generated.schema.api.services.createPipelineService import (
CreatePipelineServiceRequest,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline, Task
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
AuthProvider,
OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.connections.pipeline.airflowConnection import (
AirflowConnection,
)
from metadata.generated.schema.entity.services.connections.pipeline.backendConnection import (
BackendConnection,
)
from metadata.generated.schema.entity.services.pipelineService import (
PipelineConnection,
PipelineService,
PipelineServiceType,
)
from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
OpenMetadataJWTClientConfig,
)
from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName
from metadata.ingestion.models.custom_pydantic import CustomSecretStr
from metadata.ingestion.ometa.ometa_api import C, OpenMetadata, T
from metadata.utils.dispatch import class_register
OM_JWT = "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
def int_admin_ometa(url: str = "http://localhost:8585/api") -> OpenMetadata:
"""Initialize the ometa connection with default admin:admin creds"""
server_config = OpenMetadataConnection(
hostPort=url,
authProvider=AuthProvider.openmetadata,
securityConfig=OpenMetadataJWTClientConfig(jwtToken=CustomSecretStr(OM_JWT)),
)
metadata = OpenMetadata(server_config)
assert metadata.health_check()
return metadata
def generate_name() -> EntityName:
"""Generate a random for the asset"""
return EntityName(__root__=str(uuid.uuid4()))
create_service_registry = class_register()
def get_create_service(entity: Type[T], name: Optional[EntityName] = None) -> C:
"""Create a vanilla service based on the input type"""
func = create_service_registry.registry.get(entity.__name__)
if not func:
raise ValueError(
f"Create Service for type {entity.__name__} has not yet been implemented. Add it on `integration_base.py`"
)
if not name:
name = generate_name()
return func(name)
@create_service_registry.add(PipelineService)
def _(name: EntityName) -> C:
"""Prepare a Create service request"""
return CreatePipelineServiceRequest(
name=name,
serviceType=PipelineServiceType.Airflow,
connection=PipelineConnection(
config=AirflowConnection(
hostPort="http://localhost:8080",
connection=BackendConnection(),
),
),
)
create_entity_registry = class_register()
def get_create_entity(
entity: Type[T], reference: Any, name: Optional[EntityName] = None
) -> C:
"""Create a vanilla entity based on the input type"""
func = create_entity_registry.registry.get(entity.__name__)
if not func:
raise ValueError(
f"Create Service for type {entity.__name__} has not yet been implemented. Add it on `integration_base.py`"
)
if not name:
name = generate_name()
return func(reference, name)
@create_entity_registry.add(Pipeline)
def _(reference: FullyQualifiedEntityName, name: EntityName) -> C:
return CreatePipelineRequest(
name=name,
service=reference,
tasks=[
Task(name="task1"),
Task(name="task2", downstreamTasks=["task1"]),
Task(name="task3", downstreamTasks=["task2"]),
Task(name="task4", downstreamTasks=["task2"]),
],
)
def get_test_dag(name: str) -> DAG:
"""Get a DAG with the tasks created in the CreatePipelineRequest"""
with DAG(name, start_date=datetime(2021, 1, 1)) as dag:
tasks = [
BashOperator(
task_id=task_id,
bash_command="date",
)
for task_id in ("task1", "task2", "task3", "task4")
]
tasks[0] >> tasks[1] >> [tasks[2], tasks[3]]
return dag
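
Taken together, the helpers compose end to end. A hedged usage sketch, mirroring what the status-callback test above does (assumes an OpenMetadata server on localhost:8585 accepting the default admin JWT):

from metadata.generated.schema.entity.data.pipeline import Pipeline
from metadata.generated.schema.entity.services.pipelineService import PipelineService

metadata = int_admin_ometa()          # admin client, health-checked on creation
service_name = generate_name()        # random EntityName, keeps test runs isolated
metadata.create_or_update(get_create_service(entity=PipelineService, name=service_name))
pipeline = metadata.create_or_update(
    get_create_entity(entity=Pipeline, reference=service_name.__root__, name=generate_name())
)
dag = get_test_dag(pipeline.name.__root__)   # DAG with task1..task4, wired like the entity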

View File

@@ -55,23 +55,18 @@ from metadata.generated.schema.entity.services.connections.database.common.basic
 from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
     MysqlConnection,
 )
-from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
-    OpenMetadataConnection,
-)
 from metadata.generated.schema.entity.services.databaseService import (
     DatabaseConnection,
     DatabaseService,
     DatabaseServiceType,
 )
 from metadata.generated.schema.entity.teams.user import User
-from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
-    OpenMetadataJWTClientConfig,
-)
 from metadata.generated.schema.type.basic import FullyQualifiedEntityName, SqlQuery
 from metadata.generated.schema.type.entityReference import EntityReference
 from metadata.generated.schema.type.usageRequest import UsageRequest
 from metadata.ingestion.ometa.client import REST
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
+from ..integration_base import int_admin_ometa

 BAD_RESPONSE = {
     "data": [
@@ -128,16 +123,7 @@ class OMetaTableTest(TestCase):
     service_entity_id = None
-    server_config = OpenMetadataConnection(
-        hostPort="http://localhost:8585/api",
-        authProvider="openmetadata",
-        securityConfig=OpenMetadataJWTClientConfig(
-            jwtToken="eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
-        ),
-    )
-    metadata = OpenMetadata(server_config)
-    assert metadata.health_check()
+    metadata = int_admin_ometa()

     user: User = metadata.create_or_update(
         data=CreateUserRequest(name="random-user", email="random@user.com"),