Fix #2316 - Append Tasks & Add Status in Airflow (#2738)

* Add license

* Add date to timestamp helper

* Prepare pipeline status operation

* Update and clean tasks from client

* Update tasks from client

* Check if pipeline is empty

* Keep all current pipeline info

* Append and clean tasks

* format

* Add status information

* Test pipelineStatus update

* Update task on clear

* Log status on callback

* Update lineage and status docs

* Update lineage docs

* Format

* Logic to handle DAG status

* Lint and format

* Update lineage tests
Pere Miquel Brull 2022-02-13 17:51:25 +01:00 committed by GitHub
parent 06ed718235
commit 1b2ea54d4f
10 changed files with 642 additions and 117 deletions

View File

@ -122,9 +122,89 @@ In order to still get the metadata information from the workflow, we can configu
Import it with
```python
from airflow_provider_openmetadata.lineage.callback import lineage_callback
from airflow_provider_openmetadata.lineage.callback import failure_callback
```
and use it as the value of the `on_failure_callback` property.
This can be set both at DAG and Task level, giving us flexibility on how (and whether) we want to handle lineage on failure.
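For instance, a minimal sketch attaching it at the Task level (assuming Airflow 2; the `callback_example` DAG and `explode` task are illustrative, and the DAG-level variant via `default_args` is shown under Best Practices below):
```python
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

from airflow_provider_openmetadata.lineage.callback import failure_callback

with DAG(
    dag_id="callback_example",  # illustrative DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    # Task-level callback: if this task fails, lineage & status are pushed
    explode = BashOperator(
        task_id="explode",
        bash_command="exit 1",
        on_failure_callback=failure_callback,
    )
```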
## Pipeline Status
Another property that we can check on each Pipeline instance is `pipelineStatus`. You can check the status and
the current tasks with a REST query such as:
```bash
http GET http://localhost:8585/api/v1/pipelines/name/<FQDN>\?fields\=tasks,pipelineStatus
```
The pipeline status property looks like:
```json
"pipelineStatus": [
{
"executionDate": 1609459200,
"executionStatus": "Failed",
"taskStatus": [
{
"executionStatus": "Successful",
"name": "sleep"
},
{
"executionStatus": "Failed",
"name": "explode"
},
{
"executionStatus": "Successful",
"name": "print_date"
}
]
},
...
]
```
Note that this is a list of all the statuses recorded for a specific Pipeline instance. It can help us keep track
of our executions and monitor our processes' reliability KPIs.
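The same information can also be fetched programmatically with the Python client used throughout this change. A minimal sketch, assuming a local server without authentication and an illustrative `airflow.example_dag` fully qualified name:
```python
from metadata.generated.schema.entity.data.pipeline import Pipeline
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig

# Assumed local, unauthenticated server; adjust to your deployment
metadata = OpenMetadata(MetadataServerConfig(api_endpoint="http://localhost:8585/api"))

# Illustrative FQDN: <airflow service name>.<dag_id>
pipeline = metadata.get_by_name(
    entity=Pipeline,
    fqdn="airflow.example_dag",
    fields=["tasks", "pipelineStatus"],
)

if pipeline and pipeline.pipelineStatus:
    for status in pipeline.pipelineStatus:
        print(status.executionDate.__root__, status.executionStatus)
```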
To properly extract the status data we again need to rely on the failure and success callbacks. This is because
during the Lineage Backend execution the tasks are still flagged as `running`; it is not until we reach a callback
that we can use the Task Instance information to work with the statuses.
The `failure_callback` computes both the lineage and the status of failed tasks. For successful ones, we can import
```python
from airflow_provider_openmetadata.lineage.callback import success_callback
```
and pass it as the value for the `on_success_callback` property.
Note that:
- We will mark the DAG status as **successful** only if all the tasks of a given execution are successful.
- Clearing a task/DAG will update the existing `pipelineStatus` element for that specific `executionDate`, as sketched below.
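As a sketch of the clearing behavior (reusing the `metadata` client and `pipeline` entity from the earlier snippet; the task name is illustrative), sending a new `PipelineStatus` for an `executionDate` that already exists updates that entry instead of appending a second one:
```python
from datetime import datetime

from metadata.generated.schema.entity.data.pipeline import (
    PipelineStatus,
    StatusType,
    TaskStatus,
)
from metadata.utils.helpers import datetime_to_ts

execution_ts = datetime_to_ts(datetime(2021, 3, 7))

# First execution: the "explode" task failed
metadata.add_pipeline_status(
    pipeline=pipeline,
    status=PipelineStatus(
        executionDate=execution_ts,
        executionStatus=StatusType.Failed,
        taskStatus=[TaskStatus(name="explode", executionStatus=StatusType.Failed)],
    ),
)

# After clearing and re-running "explode" successfully, the callback sends the
# same executionDate again: the existing entry is updated, not duplicated
metadata.add_pipeline_status(
    pipeline=pipeline,
    status=PipelineStatus(
        executionDate=execution_ts,
        executionStatus=StatusType.Successful,
        taskStatus=[TaskStatus(name="explode", executionStatus=StatusType.Successful)],
    ),
)
```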
## Best Practices
It is encouraged to use a set of default arguments for all our DAGs. There we can set configurations
such as `catchup`, `email` or `email_on_failure`, which are usually handled at project level.
Using this `default_args` dict to configure the success and failure callbacks globally is the most convenient
way to make sure we won't miss any status information. E.g.,
```python
from airflow import DAG
from airflow_provider_openmetadata.lineage.callback import success_callback, failure_callback
default_args = {
"on_failure_callback": failure_callback,
"on_success_callback": success_callback,
}
with DAG(
...,
default_args=default_args,
) as dag:
...
```

View File

@ -20,16 +20,19 @@ from airflow_provider_openmetadata.lineage.config import (
get_metadata_config,
)
from airflow_provider_openmetadata.lineage.utils import (
add_status,
get_xlets,
parse_lineage_to_openmetadata,
parse_lineage,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline
from metadata.generated.schema.entity.services.pipelineService import PipelineService
from metadata.ingestion.ometa.ometa_api import OpenMetadata
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
def lineage_callback(context: Dict[str, str]) -> None:
def failure_callback(context: Dict[str, str]) -> None:
"""
Add this function to the args of your DAG or Task
as the value of `on_failure_callback` to track
@ -44,12 +47,60 @@ def lineage_callback(context: Dict[str, str]) -> None:
operator: "BaseOperator" = context["task"]
operator.log.info("Parsing lineage & pipeline status on failure...")
op_inlets = get_xlets(operator, "_inlets")
op_outlets = get_xlets(operator, "_outlets")
parse_lineage_to_openmetadata(
# Get the pipeline created or updated during the lineage
pipeline = parse_lineage(
config, context, operator, op_inlets, op_outlets, client
)
add_status(
operator=operator,
pipeline=pipeline,
client=client,
context=context,
)
except Exception as exc: # pylint: disable=broad-except
logging.error("Lineage Callback exception %s", exc)
def success_callback(context: Dict[str, str]) -> None:
"""
Add this function to the args of your DAG or Task
as the value of `on_success_callback` to track
task status on task success
:param context: Airflow runtime context
"""
try:
config = get_lineage_config()
metadata_config = get_metadata_config(config)
client = OpenMetadata(metadata_config)
operator: "BaseOperator" = context["task"]
dag: "DAG" = context["dag"]
operator.log.info("Updating pipeline status on success...")
airflow_service_entity = client.get_by_name(
entity=PipelineService, fqdn=config.airflow_service_name
)
pipeline: Pipeline = client.get_by_name(
entity=Pipeline,
fqdn=f"{airflow_service_entity.name}.{dag.dag_id}",
)
add_status(
operator=operator,
pipeline=pipeline,
client=client,
context=context,
)
except Exception as exc: # pylint: disable=broad-except
logging.error("Lineage Callback exception %s", exc)

View File

@ -22,65 +22,13 @@ from airflow_provider_openmetadata.lineage.config import (
get_lineage_config,
get_metadata_config,
)
from airflow_provider_openmetadata.lineage.utils import (
get_xlets,
parse_lineage_to_openmetadata,
)
from airflow_provider_openmetadata.lineage.utils import get_xlets, parse_lineage
from metadata.ingestion.ometa.ometa_api import OpenMetadata
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
allowed_task_keys = [
"_downstream_task_ids",
"_inlets",
"_outlets",
"_task_type",
"_task_module",
"depends_on_past",
"email",
"label",
"execution_timeout",
"end_date",
"start_date",
"sla",
"sql",
"task_id",
"trigger_rule",
"wait_for_downstream",
]
allowed_flow_keys = [
"_access_control",
"_concurrency",
"_default_view",
"catchup",
"fileloc",
"is_paused_upon_creation",
"start_date",
"tags",
"timezone",
]
# pylint: disable=import-outside-toplevel, unused-import
def is_airflow_version_1() -> bool:
"""
Manage airflow submodule import based airflow version
Returns
bool
"""
try:
from airflow.hooks.base import BaseHook
return False
except ModuleNotFoundError:
from airflow.hooks.base_hook import BaseHook
return True
# pylint: disable=too-few-public-methods
class OpenMetadataLineageBackend(LineageBackend):
"""
@ -110,8 +58,8 @@ class OpenMetadataLineageBackend(LineageBackend):
_ = get_lineage_config()
# pylint: disable=protected-access
@staticmethod
def send_lineage(
self,
operator: "BaseOperator",
inlets: Optional[List] = None,
outlets: Optional[List] = None,
@ -137,9 +85,7 @@ class OpenMetadataLineageBackend(LineageBackend):
op_inlets = get_xlets(operator, "_inlets")
op_outlets = get_xlets(operator, "_outlets")
parse_lineage_to_openmetadata(
config, context, operator, op_inlets, op_outlets, client
)
parse_lineage(config, context, operator, op_inlets, op_outlets, client)
except Exception as exc: # pylint: disable=broad-except
operator.log.error(traceback.format_exc())
operator.log.error(exc)

View File

@ -24,7 +24,13 @@ from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.api.services.createPipelineService import (
CreatePipelineServiceRequest,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline, Task
from metadata.generated.schema.entity.data.pipeline import (
Pipeline,
PipelineStatus,
StatusType,
Task,
TaskStatus,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.pipelineService import (
PipelineService,
@ -39,7 +45,7 @@ if TYPE_CHECKING:
from airflow import DAG
from airflow.models.baseoperator import BaseOperator
ALLOWED_TASK_KEYS = {
_ALLOWED_TASK_KEYS = {
"_downstream_task_ids",
"_inlets",
"_outlets",
@ -58,7 +64,7 @@ ALLOWED_TASK_KEYS = {
"wait_for_downstream",
}
ALLOWED_FLOW_KEYS = {
_ALLOWED_FLOW_KEYS = {
"_access_control",
"_concurrency",
"_default_view",
@ -68,6 +74,13 @@ ALLOWED_FLOW_KEYS = {
"start_date",
"tags",
"timezone",
"_task_group", # We can get children information from here
}
_STATUS_MAP = {
"running": StatusType.Pending,
"success": StatusType.Successful,
"failed": StatusType.Failed,
}
@ -95,7 +108,7 @@ def get_properties(
:return: properties dict
"""
props: Dict[str, str] = {key: value for (key, value) in serializer(obj).items()}
props: Dict[str, str] = dict(serializer(obj).items())
for key in obj.get_serialized_fields():
if key not in props:
@ -172,16 +185,21 @@ def iso_task_start_end_date(
return task_start_date, task_end_date
def create_pipeline_entity(
dag_properties: Dict[str, str],
task_properties: Dict[str, str],
def create_or_update_pipeline( # pylint: disable=too-many-locals
dag_properties: Dict[str, Any],
task_properties: Dict[str, Any],
operator: "BaseOperator",
dag: "DAG",
airflow_service_entity: PipelineService,
client: OpenMetadata,
) -> Pipeline:
"""
Prepare the upsert the pipeline entity with the given task
Prepare the upsert of pipeline entity with the given task
We will:
- Create the pipeline Entity
- Append the task being processed
- Clean deleted tasks based on the DAG children information
:param dag_properties: attributes of the dag object
:param task_properties: attributes of the task object
@ -215,28 +233,170 @@ def create_pipeline_entity(
endDate=task_end_date,
downstreamTasks=downstream_tasks,
)
create_pipeline = CreatePipelineRequest(
# Check if the pipeline already exists
current_pipeline: Pipeline = client.get_by_name(
entity=Pipeline,
fqdn=f"{airflow_service_entity.name}.{dag.dag_id}",
fields=["tasks"],
)
# Create pipeline if not exists or update its properties
pipeline_request = CreatePipelineRequest(
name=dag.dag_id,
displayName=dag.dag_id,
description=dag.description,
pipelineUrl=dag_url,
concurrency=current_pipeline.concurrency if current_pipeline else None,
pipelineLocation=current_pipeline.pipelineLocation
if current_pipeline
else None,
startDate=dag_start_date,
tasks=[task], # TODO: should we GET + append?
tasks=current_pipeline.tasks
if current_pipeline
else None, # use the current tasks, if any
service=EntityReference(id=airflow_service_entity.id, type="pipelineService"),
owner=current_pipeline.owner if current_pipeline else None,
tags=current_pipeline.tags if current_pipeline else None,
)
pipeline = client.create_or_update(pipeline_request)
# Add the task we are processing in the lineage backend
operator.log.info("Adding tasks to pipeline...")
updated_pipeline = client.add_task_to_pipeline(pipeline, task)
# Clean pipeline
try:
operator.log.info("Cleaning pipeline tasks...")
children = dag_properties.get("_task_group").get("children")
dag_tasks = [Task(name=name) for name in children.keys()]
updated_pipeline = client.clean_pipeline_tasks(updated_pipeline, dag_tasks)
except Exception as exc: # pylint: disable=broad-except
operator.log.warning(f"Error cleaning pipeline tasks {exc}")
return updated_pipeline
def get_context_properties(
operator: "BaseOperator", dag: "DAG"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Prepare DAG and Task properties based on attributes
and serializers
"""
# Move this import to avoid circular import error when airflow parses the config
# pylint: disable=import-outside-toplevel
from airflow.serialization.serialized_objects import (
SerializedBaseOperator,
SerializedDAG,
)
return client.create_or_update(create_pipeline)
dag_properties = get_properties(
dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS
)
task_properties = get_properties(
operator, SerializedBaseOperator.serialize_operator, _ALLOWED_TASK_KEYS
)
operator.log.info(f"Task Properties {task_properties}")
operator.log.info(f"DAG properties {dag_properties}")
return dag_properties, task_properties
def get_dag_status(dag_properties: Dict[str, Any], task_status: List[TaskStatus]):
"""
Based on the task information and the total DAG tasks, cook the
DAG status.
We are not directly using `context["dag_run"]._state` as it always
gets flagged as "running" during the callbacks.
"""
children = dag_properties.get("_task_group").get("children")
if len(children) < len(task_status):
raise ValueError(
"We have more status than children:"
+ f"children {children} vs. status {task_status}"
)
# We are still processing tasks...
if len(children) > len(task_status):
return StatusType.Pending
# Check for any failure if all tasks have been processed
if len(children) == len(task_status) and StatusType.Failed in {
task.executionStatus for task in task_status
}:
return StatusType.Failed
return StatusType.Successful
def add_status(
operator: "BaseOperator",
pipeline: Pipeline,
client: OpenMetadata,
context: Dict,
) -> None:
"""
Add status information for this execution date
"""
dag: "DAG" = context["dag"]
dag_properties, task_properties = get_context_properties(operator, dag)
# Let this fail if we cannot properly extract & cast the start_date
execution_date = int(dag_properties.get("start_date"))
operator.log.info(f"Logging pipeline status for execution {execution_date}")
# Check if we already have a pipelineStatus for
# our execution_date that we should update
pipeline_status: List[PipelineStatus] = client.get_by_id(
entity=Pipeline, entity_id=pipeline.id, fields=["pipelineStatus"]
).pipelineStatus
task_status = []
# We will append based on the current registered status
if pipeline_status and pipeline_status[0].executionDate.__root__ == execution_date:
# If we are clearing a task, use the status of the new execution
task_status = [
task
for task in pipeline_status[0].taskStatus
if task.name != task_properties["task_id"]
]
# Prepare the new task status information based on the tasks already
# visited and the current task
updated_task_status = [
TaskStatus(
name=task_properties["task_id"],
executionStatus=_STATUS_MAP.get(context["task_instance"].state),
),
*task_status,
]
updated_status = PipelineStatus(
executionDate=execution_date,
executionStatus=get_dag_status(
dag_properties=dag_properties, task_status=updated_task_status
),
taskStatus=updated_task_status,
)
operator.log.info(f"Added status to DAG {updated_status}")
client.add_pipeline_status(pipeline=pipeline, status=updated_status)
# pylint: disable=too-many-arguments,too-many-locals
def parse_lineage_to_openmetadata(
def parse_lineage(
config: OpenMetadataLineageConfig,
context: Dict,
operator: "BaseOperator",
inlets: List,
outlets: List,
client: OpenMetadata,
) -> None:
) -> Optional[Pipeline]:
"""
Main logic to extract properties from DAG and the
triggered operator to ingest lineage data into
@ -249,36 +409,23 @@ def parse_lineage_to_openmetadata(
:param outlets: list of downstream tables
:param client: OpenMetadata client
"""
# Move this import to avoid circular import error when airflow parses the config
# pylint: disable=import-outside-toplevel
from airflow.serialization.serialized_objects import (
SerializedBaseOperator,
SerializedDAG,
)
operator.log.info("Parsing Lineage for OpenMetadata")
dag: "DAG" = context["dag"]
dag_properties = get_properties(dag, SerializedDAG.serialize_dag, ALLOWED_FLOW_KEYS)
task_properties = get_properties(
operator, SerializedBaseOperator.serialize_operator, ALLOWED_TASK_KEYS
)
operator.log.info(f"Task Properties {task_properties}")
operator.log.info(f"DAG properties {dag_properties}")
dag_properties, task_properties = get_context_properties(operator, dag)
try:
airflow_service_entity = get_or_create_pipeline_service(
operator, client, config
)
pipeline = create_pipeline_entity(
dag_properties,
task_properties,
operator,
dag,
airflow_service_entity,
client,
pipeline = create_or_update_pipeline(
dag_properties=dag_properties,
task_properties=task_properties,
operator=operator,
dag=dag,
airflow_service_entity=airflow_service_entity,
client=client,
)
operator.log.info("Parsing Lineage")
@ -291,27 +438,31 @@ def parse_lineage_to_openmetadata(
toEntity=EntityReference(id=pipeline.id, type="pipeline"),
)
)
operator.log.debug(f"from lineage {lineage}")
operator.log.debug(f"From lineage {lineage}")
client.add_lineage(lineage)
for table in outlets if outlets else []:
table_entity = client.get_by_name(entity=Table, fqdn=table)
operator.log.debug(f"to entity {table_entity}")
operator.log.debug(f"To entity {table_entity}")
lineage = AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(id=pipeline.id, type="pipeline"),
toEntity=EntityReference(id=table_entity.id, type="table"),
)
)
operator.log.debug(f"to lineage {lineage}")
operator.log.debug(f"To lineage {lineage}")
client.add_lineage(lineage)
return pipeline
except Exception as exc: # pylint: disable=broad-except
operator.log.error(
f"Failed to parse Airflow DAG task and publish to OpenMetadata due to {exc}"
)
operator.log.error(traceback.format_exc())
return None
def get_or_create_pipeline_service(
operator: "BaseOperator", client: OpenMetadata, config: OpenMetadataLineageConfig
@ -320,7 +471,7 @@ def get_or_create_pipeline_service(
Check if we already have the airflow instance as a PipelineService,
otherwise create it.
:param operator: task from which we extract the lienage
:param operator: task from which we extract the lineage
:param client: OpenMetadata API wrapper
:param config: lineage config
:return: PipelineService

View File

@ -1,3 +1,13 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mixin class containing Lineage specific methods

View File

@ -0,0 +1,123 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mixin class containing Pipeline specific methods
To be used by OpenMetadata class
"""
import logging
from typing import List
from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
from metadata.generated.schema.entity.data.pipeline import (
Pipeline,
PipelineStatus,
Task,
)
from metadata.ingestion.ometa.client import REST
logger = logging.getLogger(__name__)
class OMetaPipelineMixin:
"""
OpenMetadata API methods related to the Pipeline Entity
To be inherited by OpenMetadata
"""
client: REST
def add_pipeline_status(
self, pipeline: Pipeline, status: PipelineStatus
) -> Pipeline:
"""
Given a pipeline and a PipelineStatus, send it
to the Pipeline Entity
"""
resp = self.client.put(
f"{self.get_suffix(Pipeline)}/{pipeline.id.__root__}/status",
data=status.json(),
)
return Pipeline(**resp)
def add_task_to_pipeline(self, pipeline: Pipeline, *tasks: Task) -> Pipeline:
"""
The background logic for this method is that during
Airflow backend lineage, we compute one task at
a time.
Let's generalise a bit the approach by preparing
a method capable of updating a tuple of tasks
from the client.
Latest changes leave all the task management
to the client. Therefore, a Pipeline will only contain
the tasks sent in each PUT from the client.
"""
# Get the names of all incoming tasks
updated_tasks_names = {task.name for task in tasks}
# Check which tasks are currently in the pipeline but not being updated
not_updated_tasks = []
if pipeline.tasks:
not_updated_tasks = [
task for task in pipeline.tasks if task.name not in updated_tasks_names
]
# All tasks are the union of the incoming tasks & the not updated tasks
all_tasks = [*tasks, *not_updated_tasks]
updated_pipeline = CreatePipelineRequest(
name=pipeline.name,
displayName=pipeline.displayName,
description=pipeline.description,
pipelineUrl=pipeline.pipelineUrl,
concurrency=pipeline.concurrency,
pipelineLocation=pipeline.pipelineLocation,
startDate=pipeline.startDate,
service=pipeline.service,
tasks=all_tasks,
owner=pipeline.owner,
tags=pipeline.tags,
)
return self.create_or_update(updated_pipeline)
def clean_pipeline_tasks(self, pipeline: Pipeline, tasks: List[Task]) -> Pipeline:
"""
Given a list of tasks, remove from the
Pipeline Entity those that are not received
as an input.
e.g., if a Pipeline has tasks A, B, C,
but we only receive A & C, we will
remove the task B from the entity
"""
names = {task.name for task in tasks}
updated_pipeline = CreatePipelineRequest(
name=pipeline.name,
displayName=pipeline.displayName,
description=pipeline.description,
pipelineUrl=pipeline.pipelineUrl,
concurrency=pipeline.concurrency,
pipelineLocation=pipeline.pipelineLocation,
startDate=pipeline.startDate,
service=pipeline.service,
tasks=[task for task in pipeline.tasks if task.name in names],
owner=pipeline.owner,
tags=pipeline.tags,
)
return self.create_or_update(updated_pipeline)

View File

@ -48,6 +48,7 @@ from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
from metadata.ingestion.ometa.client import REST, APIError, ClientConfig
from metadata.ingestion.ometa.mixins.mlmodel_mixin import OMetaMlModelMixin
from metadata.ingestion.ometa.mixins.pipeline_mixin import OMetaPipelineMixin
from metadata.ingestion.ometa.mixins.table_mixin import OMetaTableMixin
from metadata.ingestion.ometa.mixins.tag_mixin import OMetaTagMixin
from metadata.ingestion.ometa.mixins.version_mixin import OMetaVersionMixin
@ -97,7 +98,12 @@ class EntityList(Generic[T], BaseModel):
class OpenMetadata(
OMetaMlModelMixin, OMetaTableMixin, OMetaVersionMixin, OMetaTagMixin, Generic[T, C]
OMetaPipelineMixin,
OMetaMlModelMixin,
OMetaTableMixin,
OMetaVersionMixin,
OMetaTagMixin,
Generic[T, C],
):
"""
Generic interface to the OpenMetadata API

View File

@ -179,3 +179,10 @@ def convert_epoch_to_iso(seconds_since_epoch):
dt = datetime.utcfromtimestamp(seconds_since_epoch)
iso_format = dt.isoformat() + "Z"
return iso_format
def datetime_to_ts(date: datetime) -> int:
"""
Convert a given date to a timestamp as an Int
"""
return int(date.timestamp())

View File

@ -25,13 +25,13 @@ from airflow.serialization.serialized_objects import (
)
from airflow_provider_openmetadata.lineage.openmetadata import (
ALLOWED_FLOW_KEYS,
ALLOWED_TASK_KEYS,
OpenMetadataLineageBackend,
get_properties,
get_xlets,
)
from airflow_provider_openmetadata.lineage.utils import (
_ALLOWED_FLOW_KEYS,
_ALLOWED_TASK_KEYS,
get_properties,
get_xlets,
iso_dag_start_date,
iso_task_start_end_date,
)
@ -83,7 +83,9 @@ class AirflowLineageTest(TestCase):
cls.create_db_entity = cls.metadata.create_or_update(data=cls.create_db)
cls.db_reference = EntityReference(id=cls.create_db_entity.id, name="test-db", type="database")
cls.db_reference = EntityReference(
id=cls.create_db_entity.id, name="test-db", type="database"
)
cls.create = CreateTableRequest(
name="lineage-test",
@ -147,37 +149,37 @@ class AirflowLineageTest(TestCase):
"""
dag_props = get_properties(
self.dag, SerializedDAG.serialize_dag, ALLOWED_FLOW_KEYS
self.dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS
)
self.assertTrue(set(dag_props.keys()).issubset(ALLOWED_FLOW_KEYS))
self.assertTrue(set(dag_props.keys()).issubset(_ALLOWED_FLOW_KEYS))
task1_props = get_properties(
self.dag.get_task("task1"),
SerializedBaseOperator.serialize_operator,
ALLOWED_TASK_KEYS,
_ALLOWED_TASK_KEYS,
)
self.assertTrue(set(task1_props.keys()).issubset(ALLOWED_TASK_KEYS))
self.assertTrue(set(task1_props.keys()).issubset(_ALLOWED_TASK_KEYS))
task2_props = get_properties(
self.dag.get_task("task2"),
SerializedBaseOperator.serialize_operator,
ALLOWED_TASK_KEYS,
_ALLOWED_TASK_KEYS,
)
self.assertTrue(set(task2_props.keys()).issubset(ALLOWED_TASK_KEYS))
self.assertTrue(set(task2_props.keys()).issubset(_ALLOWED_TASK_KEYS))
task3_props = get_properties(
self.dag.get_task("task3"),
SerializedBaseOperator.serialize_operator,
ALLOWED_TASK_KEYS,
_ALLOWED_TASK_KEYS,
)
self.assertTrue(set(task3_props.keys()).issubset(ALLOWED_TASK_KEYS))
self.assertTrue(set(task3_props.keys()).issubset(_ALLOWED_TASK_KEYS))
def test_times(self):
"""
Check the ISO date extraction for DAG and Tasks instances
"""
dag_props = get_properties(
self.dag, SerializedDAG.serialize_dag, ALLOWED_FLOW_KEYS
self.dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS
)
dag_date = iso_dag_start_date(dag_props)
@ -195,7 +197,7 @@ class AirflowLineageTest(TestCase):
task1_props = get_properties(
self.dag.get_task("task1"),
SerializedBaseOperator.serialize_operator,
ALLOWED_TASK_KEYS,
_ALLOWED_TASK_KEYS,
)
task_start_date, task_end_date = iso_task_start_end_date(task1_props)

View File

@ -14,13 +14,14 @@ OpenMetadata high-level API Pipeline test
"""
import uuid
from unittest import TestCase
from datetime import datetime
from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
from metadata.generated.schema.api.services.createPipelineService import (
CreatePipelineServiceRequest,
)
from metadata.generated.schema.api.teams.createUser import CreateUserRequest
from metadata.generated.schema.entity.data.pipeline import Pipeline
from metadata.generated.schema.entity.data.pipeline import (
    Pipeline,
    PipelineStatus,
    StatusType,
    Task,
    TaskStatus,
)
from metadata.generated.schema.entity.services.pipelineService import (
PipelineService,
PipelineServiceType,
@ -28,6 +29,7 @@ from metadata.generated.schema.entity.services.pipelineService import (
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.utils.helpers import datetime_to_ts
class OMetaPipelineTest(TestCase):
@ -196,6 +198,153 @@ class OMetaPipelineTest(TestCase):
None,
)
def test_add_status(self):
"""
We can add status data
"""
create_pipeline = CreatePipelineRequest(
name="pipeline-test",
service=EntityReference(id=self.service_entity.id, type=self.service_type),
tasks=[
Task(name="task1"),
Task(name="task2"),
]
)
pipeline = self.metadata.create_or_update(data=create_pipeline)
execution_ts = datetime_to_ts(datetime.strptime("2021-03-07", "%Y-%m-%d"))
updated = self.metadata.add_pipeline_status(
pipeline=pipeline,
status=PipelineStatus(
executionDate=execution_ts,
executionStatus=StatusType.Successful,
taskStatus=[
TaskStatus(name="task1", executionStatus=StatusType.Successful),
]
)
)
# We get a list of status
assert updated.pipelineStatus[0].executionDate.__root__ == execution_ts
assert len(updated.pipelineStatus[0].taskStatus) == 1
# Check that we can update a given status properly
updated = self.metadata.add_pipeline_status(
pipeline=pipeline,
status=PipelineStatus(
executionDate=execution_ts,
executionStatus=StatusType.Successful,
taskStatus=[
TaskStatus(name="task1", executionStatus=StatusType.Successful),
TaskStatus(name="task2", executionStatus=StatusType.Successful),
]
)
)
assert updated.pipelineStatus[0].executionDate.__root__ == execution_ts
assert len(updated.pipelineStatus[0].taskStatus) == 2
# Cleanup
self.metadata.delete(entity=Pipeline, entity_id=pipeline.id)
def test_add_tasks(self):
"""
Check the add task logic
"""
create_pipeline = CreatePipelineRequest(
name="pipeline-test",
service=EntityReference(id=self.service_entity.id, type=self.service_type),
tasks=[
Task(name="task1"),
Task(name="task2"),
]
)
pipeline = self.metadata.create_or_update(data=create_pipeline)
# Add new tasks
updated_pipeline = self.metadata.add_task_to_pipeline(
pipeline, Task(name="task3"),
)
assert len(updated_pipeline.tasks) == 3
# Update a task already added
updated_pipeline = self.metadata.add_task_to_pipeline(
pipeline, Task(name="task3", displayName="TaskDisplay"),
)
assert len(updated_pipeline.tasks) == 3
assert next(
iter(
task for task in updated_pipeline.tasks
if task.displayName == "TaskDisplay"
)
)
# Add more than one task at a time
new_tasks = [
Task(name="task3"),
Task(name="task4"),
]
updated_pipeline = self.metadata.add_task_to_pipeline(
pipeline, *new_tasks
)
assert len(updated_pipeline.tasks) == 4
# Cleanup
self.metadata.delete(entity=Pipeline, entity_id=pipeline.id)
def test_add_tasks_to_empty_pipeline(self):
"""
We can add tasks to a pipeline without tasks
"""
pipeline = self.metadata.create_or_update(data=self.create)
updated_pipeline = self.metadata.add_task_to_pipeline(
pipeline, Task(name="task", displayName="TaskDisplay"),
)
assert len(updated_pipeline.tasks) == 1
def test_clean_tasks(self):
"""
Check that we can remove Pipeline tasks
if they are not part of the list arg
"""
create_pipeline = CreatePipelineRequest(
name="pipeline-test",
service=EntityReference(id=self.service_entity.id, type=self.service_type),
tasks=[
Task(name="task1"),
Task(name="task2"),
Task(name="task3"),
Task(name="task4"),
]
)
pipeline = self.metadata.create_or_update(data=create_pipeline)
updated_pipeline = self.metadata.clean_pipeline_tasks(
pipeline=pipeline,
tasks=[
Task(name="task3"),
Task(name="task4")
]
)
assert len(updated_pipeline.tasks) == 2
assert {task.name for task in updated_pipeline.tasks} == {"task3", "task4"}
# Cleanup
self.metadata.delete(entity=Pipeline, entity_id=pipeline.id)
def test_list_versions(self):
"""
test list pipeline entity versions