From 1b2ea54d4fb0c1dd89121612c63b09da828c2933 Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Sun, 13 Feb 2022 17:51:25 +0100 Subject: [PATCH] Fix #2316 - Append Tasks & Add Status in Airflow (#2738) * Add license * Add date to timestamp helper * Prepare pipeline status operation * Update and clean tasks from client * Update tasks from client * Check if pipeline is empty * Keep all current pipeline info * Append and clean tasks * format * Add status information * Test pipelineStatus update * Update task on clear * Log status on callback * Update lineage and status docs * Update lineage docs * Format * Logic to handle DAG status * Lint and format * Update lineage tests --- docs/install/lineage/airflow-lineage.md | 82 ++++++- .../lineage/callback.py | 57 ++++- .../lineage/openmetadata.py | 60 +---- .../lineage/utils.py | 229 +++++++++++++++--- .../ingestion/ometa/mixins/mlmodel_mixin.py | 10 + .../ingestion/ometa/mixins/pipeline_mixin.py | 123 ++++++++++ .../src/metadata/ingestion/ometa/ometa_api.py | 8 +- ingestion/src/metadata/utils/helpers.py | 7 + .../lineage/airflow/test_airflow_lineage.py | 32 +-- .../ometa/test_ometa_pipeline_api.py | 151 +++++++++++- 10 files changed, 642 insertions(+), 117 deletions(-) create mode 100644 ingestion/src/metadata/ingestion/ometa/mixins/pipeline_mixin.py diff --git a/docs/install/lineage/airflow-lineage.md b/docs/install/lineage/airflow-lineage.md index 32c111b8bb7..e2e65616fcd 100644 --- a/docs/install/lineage/airflow-lineage.md +++ b/docs/install/lineage/airflow-lineage.md @@ -122,9 +122,89 @@ In order to still get the metadata information from the workflow, we can configu Import it with ```python -from airflow_provider_openmetadata.lineage.callback import lineage_callback +from airflow_provider_openmetadata.lineage.callback import failure_callback ``` and use it as an argument for `on_failure_callback` property. This can be set both at DAG and Task level, giving us flexibility on how (and if) we want to handle lineage on failure. + +## Pipeline Status + +Another property that we can check from each Pipeline instance is `pipelineStatus`. You could check status and +the current tasks using a REST query such as: + +```bash +http GET http://localhost:8585/api/v1/pipelines/name/\?fields\=tasks,pipelineStatus +``` + +The pipeline status property looks like: + +```json +"pipelineStatus": [ + { + "executionDate": 1609459200, + "executionStatus": "Failed", + "taskStatus": [ + { + "executionStatus": "Successful", + "name": "sleep" + }, + { + "executionStatus": "Failed", + "name": "explode" + }, + { + "executionStatus": "Successful", + "name": "print_date" + } + ] + }, + ... +] +``` + +Note that it is a list of all the statuses recorded for a specific Pipeline instance. This can help us keep track +of our executions and check our processes KPIs in terms of reliability. + +To properly extract the status data we need to again play with the failure and success callbacks. This is because +during the Lineage Backend execution, the tasks are still flagged as `running`. It is not until we reach to a callback +that we can properly use the Task Instance information to operate with the statuses. + +The `failure_callback` will both compute the lineage and status of failed tasks. For successful ones, we can import + +```python +from airflow_provider_openmetadata.lineage.callback import success_callback +``` + +and pass it as the value for the `on_success_callback` property. 
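Since the callbacks can be set at either the DAG or the Task level, here is a minimal sketch of wiring both callbacks on a single task (assuming an Airflow 2.x environment; the DAG id and the `BashOperator` task below are illustrative, not part of this change):

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

from airflow_provider_openmetadata.lineage.callback import (
    failure_callback,
    success_callback,
)

with DAG(
    dag_id="callback_example",  # illustrative DAG id
    start_date=datetime(2022, 2, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    BashOperator(
        task_id="print_date",
        bash_command="date",
        # Task-level wiring: lineage & failed-task status are pushed on failure,
        # successful-task status on success
        on_failure_callback=failure_callback,
        on_success_callback=success_callback,
    )
```

The Best Practices section below shows the equivalent DAG-level configuration through `default_args`.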
+ +Note that: + +- We will mark the DAG status as **successful** only if all the tasks of a given execution are successful. +- Clearing a task/DAG will update its previous `pipelineStatus` element of the specific `executionDate`. + +## Best Practices + +It is encouraged to use a set of default arguments for all our DAGs. In there we can set specific configurations +such as the `catchup`, `email` or `email_on_failure`, which are usually handled at project level. + +Using this default arguments `dict` to configure globally the success and failure callbacks for status information +is the most comfortable way to make sure we won't be missing any information. E.g., + +```python +from airflow import DAG + +from airflow_provider_openmetadata.lineage.callback import success_callback, failure_callback + +default_args = { + "on_failure_callback": failure_callback, + "on_success_callback": success_callback, +} + +with DAG( + ..., + default_args=default_args, +) as dag: + ... +``` diff --git a/ingestion/src/airflow_provider_openmetadata/lineage/callback.py b/ingestion/src/airflow_provider_openmetadata/lineage/callback.py index 6788382322a..9c5122dfac0 100644 --- a/ingestion/src/airflow_provider_openmetadata/lineage/callback.py +++ b/ingestion/src/airflow_provider_openmetadata/lineage/callback.py @@ -20,16 +20,19 @@ from airflow_provider_openmetadata.lineage.config import ( get_metadata_config, ) from airflow_provider_openmetadata.lineage.utils import ( + add_status, get_xlets, - parse_lineage_to_openmetadata, + parse_lineage, ) +from metadata.generated.schema.entity.data.pipeline import Pipeline +from metadata.generated.schema.entity.services.pipelineService import PipelineService from metadata.ingestion.ometa.ometa_api import OpenMetadata if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator -def lineage_callback(context: Dict[str, str]) -> None: +def failure_callback(context: Dict[str, str]) -> None: """ Add this function to the args of your DAG or Task as the value of `on_failure_callback` to track @@ -44,12 +47,60 @@ def lineage_callback(context: Dict[str, str]) -> None: operator: "BaseOperator" = context["task"] + operator.log.info("Parsing lineage & pipeline status on failure...") + op_inlets = get_xlets(operator, "_inlets") op_outlets = get_xlets(operator, "_outlets") - parse_lineage_to_openmetadata( + # Get the pipeline created or updated during the lineage + pipeline = parse_lineage( config, context, operator, op_inlets, op_outlets, client ) + add_status( + operator=operator, + pipeline=pipeline, + client=client, + context=context, + ) + + except Exception as exc: # pylint: disable=broad-except + logging.error("Lineage Callback exception %s", exc) + + +def success_callback(context: Dict[str, str]) -> None: + """ + Add this function to the args of your DAG or Task + as the value of `on_success_callback` to track + task status on task success + + :param context: Airflow runtime context + """ + try: + + config = get_lineage_config() + metadata_config = get_metadata_config(config) + client = OpenMetadata(metadata_config) + + operator: "BaseOperator" = context["task"] + dag: "DAG" = context["dag"] + + operator.log.info("Updating pipeline status on success...") + + airflow_service_entity = client.get_by_name( + entity=PipelineService, fqdn=config.airflow_service_name + ) + pipeline: Pipeline = client.get_by_name( + entity=Pipeline, + fqdn=f"{airflow_service_entity.name}.{dag.dag_id}", + ) + + add_status( + operator=operator, + pipeline=pipeline, + client=client, + context=context, 
+ ) + except Exception as exc: # pylint: disable=broad-except logging.error("Lineage Callback exception %s", exc) diff --git a/ingestion/src/airflow_provider_openmetadata/lineage/openmetadata.py b/ingestion/src/airflow_provider_openmetadata/lineage/openmetadata.py index eec5c3c9828..db12a90bfa1 100644 --- a/ingestion/src/airflow_provider_openmetadata/lineage/openmetadata.py +++ b/ingestion/src/airflow_provider_openmetadata/lineage/openmetadata.py @@ -22,65 +22,13 @@ from airflow_provider_openmetadata.lineage.config import ( get_lineage_config, get_metadata_config, ) -from airflow_provider_openmetadata.lineage.utils import ( - get_xlets, - parse_lineage_to_openmetadata, -) +from airflow_provider_openmetadata.lineage.utils import get_xlets, parse_lineage from metadata.ingestion.ometa.ometa_api import OpenMetadata if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator -allowed_task_keys = [ - "_downstream_task_ids", - "_inlets", - "_outlets", - "_task_type", - "_task_module", - "depends_on_past", - "email", - "label", - "execution_timeout", - "end_date", - "start_date", - "sla", - "sql", - "task_id", - "trigger_rule", - "wait_for_downstream", -] -allowed_flow_keys = [ - "_access_control", - "_concurrency", - "_default_view", - "catchup", - "fileloc", - "is_paused_upon_creation", - "start_date", - "tags", - "timezone", -] - - -# pylint: disable=import-outside-toplevel, unused-import -def is_airflow_version_1() -> bool: - """ - Manage airflow submodule import based airflow version - - Returns - bool - """ - try: - from airflow.hooks.base import BaseHook - - return False - except ModuleNotFoundError: - from airflow.hooks.base_hook import BaseHook - - return True - - # pylint: disable=too-few-public-methods class OpenMetadataLineageBackend(LineageBackend): """ @@ -110,8 +58,8 @@ class OpenMetadataLineageBackend(LineageBackend): _ = get_lineage_config() # pylint: disable=protected-access - @staticmethod def send_lineage( + self, operator: "BaseOperator", inlets: Optional[List] = None, outlets: Optional[List] = None, @@ -137,9 +85,7 @@ class OpenMetadataLineageBackend(LineageBackend): op_inlets = get_xlets(operator, "_inlets") op_outlets = get_xlets(operator, "_outlets") - parse_lineage_to_openmetadata( - config, context, operator, op_inlets, op_outlets, client - ) + parse_lineage(config, context, operator, op_inlets, op_outlets, client) except Exception as exc: # pylint: disable=broad-except operator.log.error(traceback.format_exc()) operator.log.error(exc) diff --git a/ingestion/src/airflow_provider_openmetadata/lineage/utils.py b/ingestion/src/airflow_provider_openmetadata/lineage/utils.py index ab712aca7b4..1ca50037924 100644 --- a/ingestion/src/airflow_provider_openmetadata/lineage/utils.py +++ b/ingestion/src/airflow_provider_openmetadata/lineage/utils.py @@ -24,7 +24,13 @@ from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest from metadata.generated.schema.api.services.createPipelineService import ( CreatePipelineServiceRequest, ) -from metadata.generated.schema.entity.data.pipeline import Pipeline, Task +from metadata.generated.schema.entity.data.pipeline import ( + Pipeline, + PipelineStatus, + StatusType, + Task, + TaskStatus, +) from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.services.pipelineService import ( PipelineService, @@ -39,7 +45,7 @@ if TYPE_CHECKING: from airflow import DAG from airflow.models.baseoperator import BaseOperator -ALLOWED_TASK_KEYS = { +_ALLOWED_TASK_KEYS = { 
"_downstream_task_ids", "_inlets", "_outlets", @@ -58,7 +64,7 @@ ALLOWED_TASK_KEYS = { "wait_for_downstream", } -ALLOWED_FLOW_KEYS = { +_ALLOWED_FLOW_KEYS = { "_access_control", "_concurrency", "_default_view", @@ -68,6 +74,13 @@ ALLOWED_FLOW_KEYS = { "start_date", "tags", "timezone", + "_task_group", # We can get children information from here +} + +_STATUS_MAP = { + "running": StatusType.Pending, + "success": StatusType.Successful, + "failed": StatusType.Failed, } @@ -95,7 +108,7 @@ def get_properties( :return: properties dict """ - props: Dict[str, str] = {key: value for (key, value) in serializer(obj).items()} + props: Dict[str, str] = dict(serializer(obj).items()) for key in obj.get_serialized_fields(): if key not in props: @@ -172,16 +185,21 @@ def iso_task_start_end_date( return task_start_date, task_end_date -def create_pipeline_entity( - dag_properties: Dict[str, str], - task_properties: Dict[str, str], +def create_or_update_pipeline( # pylint: disable=too-many-locals + dag_properties: Dict[str, Any], + task_properties: Dict[str, Any], operator: "BaseOperator", dag: "DAG", airflow_service_entity: PipelineService, client: OpenMetadata, ) -> Pipeline: """ - Prepare the upsert the pipeline entity with the given task + Prepare the upsert of pipeline entity with the given task + + We will: + - Create the pipeline Entity + - Append the task being processed + - Clean deleted tasks based on the DAG children information :param dag_properties: attributes of the dag object :param task_properties: attributes of the task object @@ -215,28 +233,170 @@ def create_pipeline_entity( endDate=task_end_date, downstreamTasks=downstream_tasks, ) - create_pipeline = CreatePipelineRequest( + + # Check if the pipeline already exists + current_pipeline: Pipeline = client.get_by_name( + entity=Pipeline, + fqdn=f"{airflow_service_entity.name}.{dag.dag_id}", + fields=["tasks"], + ) + + # Create pipeline if not exists or update its properties + pipeline_request = CreatePipelineRequest( name=dag.dag_id, displayName=dag.dag_id, description=dag.description, pipelineUrl=dag_url, + concurrency=current_pipeline.concurrency if current_pipeline else None, + pipelineLocation=current_pipeline.pipelineLocation + if current_pipeline + else None, startDate=dag_start_date, - tasks=[task], # TODO: should we GET + append? 
+ tasks=current_pipeline.tasks + if current_pipeline + else None, # use the current tasks, if any service=EntityReference(id=airflow_service_entity.id, type="pipelineService"), + owner=current_pipeline.owner if current_pipeline else None, + tags=current_pipeline.tags if current_pipeline else None, + ) + pipeline = client.create_or_update(pipeline_request) + + # Add the task we are processing in the lineage backend + operator.log.info("Adding tasks to pipeline...") + updated_pipeline = client.add_task_to_pipeline(pipeline, task) + + # Clean pipeline + try: + operator.log.info("Cleaning pipeline tasks...") + children = dag_properties.get("_task_group").get("children") + dag_tasks = [Task(name=name) for name in children.keys()] + updated_pipeline = client.clean_pipeline_tasks(updated_pipeline, dag_tasks) + except Exception as exc: # pylint: disable=broad-except + operator.log.warning(f"Error cleaning pipeline tasks {exc}") + + return updated_pipeline + + +def get_context_properties( + operator: "BaseOperator", dag: "DAG" +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Prepare DAG and Task properties based on attributes + and serializers + """ + # Move this import to avoid circular import error when airflow parses the config + # pylint: disable=import-outside-toplevel + from airflow.serialization.serialized_objects import ( + SerializedBaseOperator, + SerializedDAG, ) - return client.create_or_update(create_pipeline) + dag_properties = get_properties( + dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS + ) + task_properties = get_properties( + operator, SerializedBaseOperator.serialize_operator, _ALLOWED_TASK_KEYS + ) + + operator.log.info(f"Task Properties {task_properties}") + operator.log.info(f"DAG properties {dag_properties}") + + return dag_properties, task_properties + + +def get_dag_status(dag_properties: Dict[str, Any], task_status: List[TaskStatus]): + """ + Based on the task information and the total DAG tasks, cook the + DAG status. + We are not directly using `context["dag_run"]._state` as it always + gets flagged as "running" during the callbacks. + """ + + children = dag_properties.get("_task_group").get("children") + + if len(children) < len(task_status): + raise ValueError( + "We have more status than children:" + + f"children {children} vs. status {task_status}" + ) + + # We are still processing tasks... 
+ if len(children) > len(task_status): + return StatusType.Pending + + # Check for any failure if all tasks have been processed + if len(children) == len(task_status) and StatusType.Failed in { + task.executionStatus for task in task_status + }: + return StatusType.Failed + + return StatusType.Successful + + +def add_status( + operator: "BaseOperator", + pipeline: Pipeline, + client: OpenMetadata, + context: Dict, +) -> None: + """ + Add status information for this execution date + """ + + dag: "DAG" = context["dag"] + dag_properties, task_properties = get_context_properties(operator, dag) + + # Let this fail if we cannot properly extract & cast the start_date + execution_date = int(dag_properties.get("start_date")) + operator.log.info(f"Logging pipeline status for execution {execution_date}") + + # Check if we already have a pipelineStatus for + # our execution_date that we should update + pipeline_status: List[PipelineStatus] = client.get_by_id( + entity=Pipeline, entity_id=pipeline.id, fields=["pipelineStatus"] + ).pipelineStatus + + task_status = [] + # We will append based on the current registered status + if pipeline_status and pipeline_status[0].executionDate.__root__ == execution_date: + # If we are clearing a task, use the status of the new execution + task_status = [ + task + for task in pipeline_status[0].taskStatus + if task.name != task_properties["task_id"] + ] + + # Prepare the new task status information based on the tasks already + # visited and the current task + updated_task_status = [ + TaskStatus( + name=task_properties["task_id"], + executionStatus=_STATUS_MAP.get(context["task_instance"].state), + ), + *task_status, + ] + + updated_status = PipelineStatus( + executionDate=execution_date, + executionStatus=get_dag_status( + dag_properties=dag_properties, task_status=updated_task_status + ), + taskStatus=updated_task_status, + ) + + operator.log.info(f"Added status to DAG {updated_status}") + client.add_pipeline_status(pipeline=pipeline, status=updated_status) # pylint: disable=too-many-arguments,too-many-locals -def parse_lineage_to_openmetadata( +def parse_lineage( config: OpenMetadataLineageConfig, context: Dict, operator: "BaseOperator", inlets: List, outlets: List, client: OpenMetadata, -) -> None: +) -> Optional[Pipeline]: """ Main logic to extract properties from DAG and the triggered operator to ingest lineage data into @@ -249,36 +409,23 @@ def parse_lineage_to_openmetadata( :param outlets: list of downstream tables :param client: OpenMetadata client """ - # Move this import to avoid circular import error when airflow parses the config - # pylint: disable=import-outside-toplevel - from airflow.serialization.serialized_objects import ( - SerializedBaseOperator, - SerializedDAG, - ) - operator.log.info("Parsing Lineage for OpenMetadata") + dag: "DAG" = context["dag"] - - dag_properties = get_properties(dag, SerializedDAG.serialize_dag, ALLOWED_FLOW_KEYS) - task_properties = get_properties( - operator, SerializedBaseOperator.serialize_operator, ALLOWED_TASK_KEYS - ) - - operator.log.info(f"Task Properties {task_properties}") - operator.log.info(f"DAG properties {dag_properties}") + dag_properties, task_properties = get_context_properties(operator, dag) try: airflow_service_entity = get_or_create_pipeline_service( operator, client, config ) - pipeline = create_pipeline_entity( - dag_properties, - task_properties, - operator, - dag, - airflow_service_entity, - client, + pipeline = create_or_update_pipeline( + dag_properties=dag_properties, + 
task_properties=task_properties, + operator=operator, + dag=dag, + airflow_service_entity=airflow_service_entity, + client=client, ) operator.log.info("Parsing Lineage") @@ -291,27 +438,31 @@ def parse_lineage_to_openmetadata( toEntity=EntityReference(id=pipeline.id, type="pipeline"), ) ) - operator.log.debug(f"from lineage {lineage}") + operator.log.debug(f"From lineage {lineage}") client.add_lineage(lineage) for table in outlets if outlets else []: table_entity = client.get_by_name(entity=Table, fqdn=table) - operator.log.debug(f"to entity {table_entity}") + operator.log.debug(f"To entity {table_entity}") lineage = AddLineageRequest( edge=EntitiesEdge( fromEntity=EntityReference(id=pipeline.id, type="pipeline"), toEntity=EntityReference(id=table_entity.id, type="table"), ) ) - operator.log.debug(f"to lineage {lineage}") + operator.log.debug(f"To lineage {lineage}") client.add_lineage(lineage) + return pipeline + except Exception as exc: # pylint: disable=broad-except operator.log.error( f"Failed to parse Airflow DAG task and publish to OpenMetadata due to {exc}" ) operator.log.error(traceback.format_exc()) + return None + def get_or_create_pipeline_service( operator: "BaseOperator", client: OpenMetadata, config: OpenMetadataLineageConfig @@ -320,7 +471,7 @@ def get_or_create_pipeline_service( Check if we already have the airflow instance as a PipelineService, otherwise create it. - :param operator: task from which we extract the lienage + :param operator: task from which we extract the lineage :param client: OpenMetadata API wrapper :param config: lineage config :return: PipelineService diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py index ae2e84c4faf..caee50870bb 100644 --- a/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py +++ b/ingestion/src/metadata/ingestion/ometa/mixins/mlmodel_mixin.py @@ -1,3 +1,13 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Mixin class containing Lineage specific methods diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/pipeline_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/pipeline_mixin.py new file mode 100644 index 00000000000..e51a1e2efc6 --- /dev/null +++ b/ingestion/src/metadata/ingestion/ometa/mixins/pipeline_mixin.py @@ -0,0 +1,123 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Mixin class containing Pipeline specific methods + +To be used by OpenMetadata class +""" +import logging +from typing import List + +from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest +from metadata.generated.schema.entity.data.pipeline import ( + Pipeline, + PipelineStatus, + Task, +) +from metadata.ingestion.ometa.client import REST + +logger = logging.getLogger(__name__) + + +class OMetaPipelineMixin: + """ + OpenMetadata API methods related to the Pipeline Entity + + To be inherited by OpenMetadata + """ + + client: REST + + def add_pipeline_status( + self, pipeline: Pipeline, status: PipelineStatus + ) -> Pipeline: + """ + Given a pipeline and a PipelineStatus, send it + to the Pipeline Entity + """ + resp = self.client.put( + f"{self.get_suffix(Pipeline)}/{pipeline.id.__root__}/status", + data=status.json(), + ) + return Pipeline(**resp) + + def add_task_to_pipeline(self, pipeline: Pipeline, *tasks: Task) -> Pipeline: + """ + The background logic for this method is that during + Airflow backend lineage, we compute one task at + a time. + + Let's generalise a bit the approach by preparing + a method capable of updating a tuple of tasks + from the client. + + Latest changes leave all the task management + to the client. Therefore, a Pipeline will only contain + the tasks sent in each PUT from the client. + """ + + # Get the names of all incoming tasks + updated_tasks_names = {task.name for task in tasks} + + # Check which tasks are currently in the pipeline but not being updated + not_updated_tasks = [] + if pipeline.tasks: + not_updated_tasks = [ + task for task in pipeline.tasks if task.name not in updated_tasks_names + ] + + # All tasks are the union of the incoming tasks & the not updated tasks + all_tasks = [*tasks, *not_updated_tasks] + + updated_pipeline = CreatePipelineRequest( + name=pipeline.name, + displayName=pipeline.displayName, + description=pipeline.description, + pipelineUrl=pipeline.pipelineUrl, + concurrency=pipeline.concurrency, + pipelineLocation=pipeline.pipelineLocation, + startDate=pipeline.startDate, + service=pipeline.service, + tasks=all_tasks, + owner=pipeline.owner, + tags=pipeline.tags, + ) + + return self.create_or_update(updated_pipeline) + + def clean_pipeline_tasks(self, pipeline: Pipeline, tasks: List[Task]) -> Pipeline: + """ + Given a list of tasks, remove from the + Pipeline Entity those that are not received + as an input. 
+ + e.g., if a Pipeline has tasks A, B, C, + but we only receive A & C, we will + remove the task B from the entity + """ + + names = {task.name for task in tasks} + + updated_pipeline = CreatePipelineRequest( + name=pipeline.name, + displayName=pipeline.displayName, + description=pipeline.description, + pipelineUrl=pipeline.pipelineUrl, + concurrency=pipeline.concurrency, + pipelineLocation=pipeline.pipelineLocation, + startDate=pipeline.startDate, + service=pipeline.service, + tasks=[task for task in pipeline.tasks if task.name in names], + owner=pipeline.owner, + tags=pipeline.tags, + ) + + return self.create_or_update(updated_pipeline) diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py index fae92fb018b..616c0284672 100644 --- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py +++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py @@ -48,6 +48,7 @@ from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.ometa.auth_provider import AuthenticationProvider from metadata.ingestion.ometa.client import REST, APIError, ClientConfig from metadata.ingestion.ometa.mixins.mlmodel_mixin import OMetaMlModelMixin +from metadata.ingestion.ometa.mixins.pipeline_mixin import OMetaPipelineMixin from metadata.ingestion.ometa.mixins.table_mixin import OMetaTableMixin from metadata.ingestion.ometa.mixins.tag_mixin import OMetaTagMixin from metadata.ingestion.ometa.mixins.version_mixin import OMetaVersionMixin @@ -97,7 +98,12 @@ class EntityList(Generic[T], BaseModel): class OpenMetadata( - OMetaMlModelMixin, OMetaTableMixin, OMetaVersionMixin, OMetaTagMixin, Generic[T, C] + OMetaPipelineMixin, + OMetaMlModelMixin, + OMetaTableMixin, + OMetaVersionMixin, + OMetaTagMixin, + Generic[T, C], ): """ Generic interface to the OpenMetadata API diff --git a/ingestion/src/metadata/utils/helpers.py b/ingestion/src/metadata/utils/helpers.py index 4678a95c7d5..62a8ecdf3b3 100644 --- a/ingestion/src/metadata/utils/helpers.py +++ b/ingestion/src/metadata/utils/helpers.py @@ -179,3 +179,10 @@ def convert_epoch_to_iso(seconds_since_epoch): dt = datetime.utcfromtimestamp(seconds_since_epoch) iso_format = dt.isoformat() + "Z" return iso_format + + +def datetime_to_ts(date: datetime) -> int: + """ + Convert a given date to a timestamp as an Int + """ + return int(date.timestamp()) diff --git a/ingestion/tests/integration/lineage/airflow/test_airflow_lineage.py b/ingestion/tests/integration/lineage/airflow/test_airflow_lineage.py index 5c10f3cbe16..5b0966da739 100644 --- a/ingestion/tests/integration/lineage/airflow/test_airflow_lineage.py +++ b/ingestion/tests/integration/lineage/airflow/test_airflow_lineage.py @@ -25,13 +25,13 @@ from airflow.serialization.serialized_objects import ( ) from airflow_provider_openmetadata.lineage.openmetadata import ( - ALLOWED_FLOW_KEYS, - ALLOWED_TASK_KEYS, OpenMetadataLineageBackend, - get_properties, - get_xlets, ) from airflow_provider_openmetadata.lineage.utils import ( + _ALLOWED_FLOW_KEYS, + _ALLOWED_TASK_KEYS, + get_properties, + get_xlets, iso_dag_start_date, iso_task_start_end_date, ) @@ -83,7 +83,9 @@ class AirflowLineageTest(TestCase): cls.create_db_entity = cls.metadata.create_or_update(data=cls.create_db) - cls.db_reference = EntityReference(id=cls.create_db_entity.id, name="test-db", type="database") + cls.db_reference = EntityReference( + id=cls.create_db_entity.id, name="test-db", type="database" + ) cls.create = CreateTableRequest( name="lineage-test", @@ 
-147,37 +149,37 @@ class AirflowLineageTest(TestCase): """ dag_props = get_properties( - self.dag, SerializedDAG.serialize_dag, ALLOWED_FLOW_KEYS + self.dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS ) - self.assertTrue(set(dag_props.keys()).issubset(ALLOWED_FLOW_KEYS)) + self.assertTrue(set(dag_props.keys()).issubset(_ALLOWED_FLOW_KEYS)) task1_props = get_properties( self.dag.get_task("task1"), SerializedBaseOperator.serialize_operator, - ALLOWED_TASK_KEYS, + _ALLOWED_TASK_KEYS, ) - self.assertTrue(set(task1_props.keys()).issubset(ALLOWED_TASK_KEYS)) + self.assertTrue(set(task1_props.keys()).issubset(_ALLOWED_TASK_KEYS)) task2_props = get_properties( self.dag.get_task("task2"), SerializedBaseOperator.serialize_operator, - ALLOWED_TASK_KEYS, + _ALLOWED_TASK_KEYS, ) - self.assertTrue(set(task2_props.keys()).issubset(ALLOWED_TASK_KEYS)) + self.assertTrue(set(task2_props.keys()).issubset(_ALLOWED_TASK_KEYS)) task3_props = get_properties( self.dag.get_task("task3"), SerializedBaseOperator.serialize_operator, - ALLOWED_TASK_KEYS, + _ALLOWED_TASK_KEYS, ) - self.assertTrue(set(task3_props.keys()).issubset(ALLOWED_TASK_KEYS)) + self.assertTrue(set(task3_props.keys()).issubset(_ALLOWED_TASK_KEYS)) def test_times(self): """ Check the ISO date extraction for DAG and Tasks instances """ dag_props = get_properties( - self.dag, SerializedDAG.serialize_dag, ALLOWED_FLOW_KEYS + self.dag, SerializedDAG.serialize_dag, _ALLOWED_FLOW_KEYS ) dag_date = iso_dag_start_date(dag_props) @@ -195,7 +197,7 @@ class AirflowLineageTest(TestCase): task1_props = get_properties( self.dag.get_task("task1"), SerializedBaseOperator.serialize_operator, - ALLOWED_TASK_KEYS, + _ALLOWED_TASK_KEYS, ) task_start_date, task_end_date = iso_task_start_end_date(task1_props) diff --git a/ingestion/tests/integration/ometa/test_ometa_pipeline_api.py b/ingestion/tests/integration/ometa/test_ometa_pipeline_api.py index e95fa2f3e49..6a30dff6ecf 100644 --- a/ingestion/tests/integration/ometa/test_ometa_pipeline_api.py +++ b/ingestion/tests/integration/ometa/test_ometa_pipeline_api.py @@ -14,13 +14,14 @@ OpenMetadata high-level API Pipeline test """ import uuid from unittest import TestCase +from datetime import datetime from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest from metadata.generated.schema.api.services.createPipelineService import ( CreatePipelineServiceRequest, ) from metadata.generated.schema.api.teams.createUser import CreateUserRequest -from metadata.generated.schema.entity.data.pipeline import Pipeline +from metadata.generated.schema.entity.data.pipeline import Pipeline, PipelineStatus, StatusType, Task, TaskStatus from metadata.generated.schema.entity.services.pipelineService import ( PipelineService, PipelineServiceType, @@ -28,6 +29,7 @@ from metadata.generated.schema.entity.services.pipelineService import ( from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig +from metadata.utils.helpers import datetime_to_ts class OMetaPipelineTest(TestCase): @@ -196,6 +198,153 @@ class OMetaPipelineTest(TestCase): None, ) + def test_add_status(self): + """ + We can add status data + """ + + create_pipeline = CreatePipelineRequest( + name="pipeline-test", + service=EntityReference(id=self.service_entity.id, type=self.service_type), + tasks=[ + Task(name="task1"), + Task(name="task2"), + ] + ) + + pipeline = 
self.metadata.create_or_update(data=create_pipeline) + execution_ts = datetime_to_ts(datetime.strptime("2021-03-07", "%Y-%m-%d")) + + updated = self.metadata.add_pipeline_status( + pipeline=pipeline, + status=PipelineStatus( + executionDate=execution_ts, + executionStatus=StatusType.Successful, + taskStatus=[ + TaskStatus(name="task1", executionStatus=StatusType.Successful), + ] + ) + ) + + # We get a list of status + assert updated.pipelineStatus[0].executionDate.__root__ == execution_ts + assert len(updated.pipelineStatus[0].taskStatus) == 1 + + # Check that we can update a given status properly + updated = self.metadata.add_pipeline_status( + pipeline=pipeline, + status=PipelineStatus( + executionDate=execution_ts, + executionStatus=StatusType.Successful, + taskStatus=[ + TaskStatus(name="task1", executionStatus=StatusType.Successful), + TaskStatus(name="task2", executionStatus=StatusType.Successful), + ] + ) + ) + + assert updated.pipelineStatus[0].executionDate.__root__ == execution_ts + assert len(updated.pipelineStatus[0].taskStatus) == 2 + + # Cleanup + self.metadata.delete(entity=Pipeline, entity_id=pipeline.id) + + def test_add_tasks(self): + """ + Check the add task logic + """ + + create_pipeline = CreatePipelineRequest( + name="pipeline-test", + service=EntityReference(id=self.service_entity.id, type=self.service_type), + tasks=[ + Task(name="task1"), + Task(name="task2"), + ] + ) + + pipeline = self.metadata.create_or_update(data=create_pipeline) + + # Add new tasks + updated_pipeline = self.metadata.add_task_to_pipeline( + pipeline, Task(name="task3"), + ) + + assert len(updated_pipeline.tasks) == 3 + + # Update a task already added + updated_pipeline = self.metadata.add_task_to_pipeline( + pipeline, Task(name="task3", displayName="TaskDisplay"), + ) + + assert len(updated_pipeline.tasks) == 3 + assert next( + iter( + task for task in updated_pipeline.tasks + if task.displayName == "TaskDisplay" + ) + ) + + # Add more than one task at a time + new_tasks = [ + Task(name="task3"), + Task(name="task4"), + ] + updated_pipeline = self.metadata.add_task_to_pipeline( + pipeline, *new_tasks + ) + + assert len(updated_pipeline.tasks) == 4 + + # Cleanup + self.metadata.delete(entity=Pipeline, entity_id=pipeline.id) + + def test_add_tasks_to_empty_pipeline(self): + """ + We can add tasks to a pipeline without tasks + """ + + pipeline = self.metadata.create_or_update(data=self.create) + + updated_pipeline = self.metadata.add_task_to_pipeline( + pipeline, Task(name="task", displayName="TaskDisplay"), + ) + + assert len(updated_pipeline.tasks) == 1 + + def test_clean_tasks(self): + """ + Check that we can remove Pipeline tasks + if they are not part of the list arg + """ + + create_pipeline = CreatePipelineRequest( + name="pipeline-test", + service=EntityReference(id=self.service_entity.id, type=self.service_type), + tasks=[ + Task(name="task1"), + Task(name="task2"), + Task(name="task3"), + Task(name="task4"), + ] + ) + + pipeline = self.metadata.create_or_update(data=create_pipeline) + + updated_pipeline = self.metadata.clean_pipeline_tasks( + pipeline=pipeline, + tasks=[ + Task(name="task3"), + Task(name="task4") + ] + ) + + assert len(updated_pipeline.tasks) == 2 + assert {task.name for task in updated_pipeline.tasks} == {"task3", "task4"} + + # Cleanup + self.metadata.delete(entity=Pipeline, entity_id=pipeline.id) + def test_list_versions(self): """ test list pipeline entity versions