# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test airflow lineage operator and hook.
This test is coupled with the example DAG `lineage_tutorial_operator`.
With the `docker compose up` setup, you can debug the progress
by setting breakpoints in this file.
"""
import time
from typing import Optional
from unittest import TestCase

import requests

from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
    CreateDatabaseSchemaRequest,
)
from metadata.generated.schema.api.data.createTable import CreateTableRequest
from metadata.generated.schema.api.services.createDatabaseService import (
    CreateDatabaseServiceRequest,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline, StatusType
from metadata.generated.schema.entity.data.table import Column, DataType, Table
from metadata.generated.schema.entity.services.connections.database.common.basicAuth import (
    BasicAuth,
)
from metadata.generated.schema.entity.services.connections.database.mysqlConnection import (
    MysqlConnection,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
    OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
    DatabaseConnection,
    DatabaseService,
    DatabaseServiceType,
)
from metadata.generated.schema.entity.services.pipelineService import PipelineService
from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
    OpenMetadataJWTClientConfig,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata

# These variables are just here to validate elements in the local deployment
OM_HOST_PORT = "http://localhost:8585/api"
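# The JWT below is assumed to be the default admin token shipped with the
# local quickstart deployment; replace it if your deployment issues another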
OM_JWT = "eyJraWQiOiJHYjM4OWEtOWY3Ni1nZGpzLWE5MmotMDI0MmJrOTQzNTYiLCJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImlzQm90IjpmYWxzZSwiaXNzIjoib3Blbi1tZXRhZGF0YS5vcmciLCJpYXQiOjE2NjM5Mzg0NjIsImVtYWlsIjoiYWRtaW5Ab3Blbm1ldGFkYXRhLm9yZyJ9.tS8um_5DKu7HgzGBzS1VTA5uUjKWOCU0B_j08WXBiEC0mr0zNREkqVfwFDD-d24HlNEbrqioLsBuFRiwIWKc1m_ZlVQbG7P36RUxhuv2vbSp80FKyNM-Tj93FDzq91jsyNmsQhyNv_fNr3TXfzzSPjHt8Go0FMMP66weoKMgW2PbXlhVKwEuXUHyakLLzewm9UMeQaEiRzhiTMU3UkLXcKbYEJJvfNFcLwSl9W8JCO_l0Yj3ud-qt_nQYEZwqW6u5nfdQllN133iikV4fM5QZsMCnm8Rq1mvLR0y9bmJiD7fwM1tmJ791TUWqmKaTnP49U493VanKpUAfzIiOiIbhg"
AIRFLOW_HOST_API_ROOT = "http://localhost:8080/api/v1/"
DEFAULT_OM_AIRFLOW_CONNECTION = "openmetadata_conn_id"
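# The Basic token below is base64("admin:admin"), the default credentials
# of the local Airflow deployment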
DEFAULT_AIRFLOW_HEADERS = {
    "Content-Type": "application/json",
    "Authorization": "Basic YWRtaW46YWRtaW4=",
}
OM_LINEAGE_DAG_NAME = "lineage_tutorial_operator"
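# Expected to match the pipeline service name configured in the example DAG's
# lineage operator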
PIPELINE_SERVICE_NAME = "airflow_lineage_op_service"


def get_task_status_type_by_name(pipeline: Pipeline, name: str) -> Optional[StatusType]:
    """
    Given a pipeline, get the execution status of the task with the given name
    """
    return next(
        (
            status.executionStatus
            for status in pipeline.pipelineStatus.taskStatus
            if status.name == name
        ),
        None,
    )


class AirflowLineageTest(TestCase):
    """
    Trigger an Airflow DAG and validate that the OpenMetadata
    Lineage Operator properly handles metadata ingestion and
    processes inlets and outlets.
    """
    server_config = OpenMetadataConnection(
        hostPort=OM_HOST_PORT,
        authProvider="openmetadata",
        securityConfig=OpenMetadataJWTClientConfig(jwtToken=OM_JWT),
    )
    metadata = OpenMetadata(server_config)

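    # Fail fast at import time if the OpenMetadata server is not reachable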
    assert metadata.health_check()

    service = CreateDatabaseServiceRequest(
        name="test-service-table-lineage",
        serviceType=DatabaseServiceType.Mysql,
        connection=DatabaseConnection(
            config=MysqlConnection(
                username="username",
                authType=BasicAuth(password="password"),
                hostPort="http://localhost:1234",
            )
        ),
    )
    service_type = "databaseService"

    @classmethod
    def setUpClass(cls) -> None:
        """
        Prepare ingredients: Table Entity + DAG
        """
        service_entity = cls.metadata.create_or_update(data=cls.service)

        create_db = CreateDatabaseRequest(
            name="test-db",
            service=service_entity.fullyQualifiedName,
        )
        create_db_entity = cls.metadata.create_or_update(data=create_db)

        create_schema = CreateDatabaseSchemaRequest(
            name="test-schema",
            database=create_db_entity.fullyQualifiedName,
        )
        create_schema_entity = cls.metadata.create_or_update(data=create_schema)

        create_inlet = CreateTableRequest(
            name="lineage-test-inlet",
            databaseSchema=create_schema_entity.fullyQualifiedName,
            columns=[Column(name="id", dataType=DataType.BIGINT)],
        )
        create_inlet_2 = CreateTableRequest(
            name="lineage-test-inlet2",
            databaseSchema=create_schema_entity.fullyQualifiedName,
            columns=[Column(name="id", dataType=DataType.BIGINT)],
        )
        create_outlet = CreateTableRequest(
            name="lineage-test-outlet",
            databaseSchema=create_schema_entity.fullyQualifiedName,
            columns=[Column(name="id", dataType=DataType.BIGINT)],
        )

        cls.metadata.create_or_update(data=create_inlet)
        cls.metadata.create_or_update(data=create_inlet_2)
        cls.table_outlet = cls.metadata.create_or_update(data=create_outlet)

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Clean up
        """
        service_id = str(
            cls.metadata.get_by_name(
                entity=DatabaseService, fqn="test-service-table-lineage"
            ).id.root
        )
        cls.metadata.delete(
            entity=DatabaseService,
            entity_id=service_id,
            recursive=True,
            hard_delete=True,
        )

        # Pipeline service created by the Airflow Lineage Operator in the
        # example DAG
        pipeline_service_id = str(
            cls.metadata.get_by_name(
                entity=PipelineService, fqn=PIPELINE_SERVICE_NAME
            ).id.root
        )
        cls.metadata.delete(
            entity=PipelineService,
            entity_id=pipeline_service_id,
            recursive=True,
            hard_delete=True,
        )

    def test_dag_runs(self) -> None:
        """
        Trigger the Airflow DAG and wait until it runs.
        Note that the DAG definition is in examples/airflow_lineage_operator.py.
        The run is expected to succeed, which lets us validate
        the task statuses afterward.
        """
        # 1. Validate that the OpenMetadata connection exists
        res = requests.get(
            AIRFLOW_HOST_API_ROOT + f"connections/{DEFAULT_OM_AIRFLOW_CONNECTION}",
            headers=DEFAULT_AIRFLOW_HEADERS,
        )
        if res.status_code != 200:
            raise RuntimeError(
                f"Could not fetch {DEFAULT_OM_AIRFLOW_CONNECTION} connection"
            )

        # 2. Enable the DAG
        res = requests.patch(
            AIRFLOW_HOST_API_ROOT + f"dags/{OM_LINEAGE_DAG_NAME}",
            json={"is_paused": False},
            headers=DEFAULT_AIRFLOW_HEADERS,
        )
        if res.status_code != 200:
            raise RuntimeError(f"Could not enable {OM_LINEAGE_DAG_NAME} DAG")

        # 3. Trigger the DAG
        res = requests.post(
            AIRFLOW_HOST_API_ROOT + f"dags/{OM_LINEAGE_DAG_NAME}/dagRuns",
            json={},
            headers=DEFAULT_AIRFLOW_HEADERS,
        )
        if res.status_code != 200:
            raise RuntimeError(f"Could not trigger {OM_LINEAGE_DAG_NAME} DAG")
        dag_run_id = res.json()["dag_run_id"]

        # 4. Wait until the DAG run reaches the `success` state,
        #    polling every 5 seconds for up to 6 tries (~30 seconds)
        state = "queued"
        tries = 0
        while state != "success" and tries <= 5:
            tries += 1
            time.sleep(5)
            res = requests.get(
                AIRFLOW_HOST_API_ROOT
                + f"dags/{OM_LINEAGE_DAG_NAME}/dagRuns/{dag_run_id}",
                headers=DEFAULT_AIRFLOW_HEADERS,
            )
            state = res.json().get("state")

        if state != "success":
            raise RuntimeError(f"DAG {OM_LINEAGE_DAG_NAME} has not finished on time.")

    def test_pipeline_created(self) -> None:
        """
        Validate that the pipeline has been created
        """
        pipeline_service: PipelineService = self.metadata.get_by_name(
            entity=PipelineService, fqn=PIPELINE_SERVICE_NAME
        )
        self.assertIsNotNone(pipeline_service)

        pipeline: Pipeline = self.metadata.get_by_name(
            entity=Pipeline,
            fqn=f"{PIPELINE_SERVICE_NAME}.{OM_LINEAGE_DAG_NAME}",
            fields=["tasks", "pipelineStatus"],
        )
        self.assertIsNotNone(pipeline)

        task_names = {task.name for task in pipeline.tasks}
        self.assertEqual(
            task_names, {"print_date", "sleep", "templated", "lineage_op"}
        )
        self.assertEqual(pipeline.description.root, "A simple tutorial DAG")

        # Validate status
        self.assertEqual(
            get_task_status_type_by_name(pipeline, "print_date"), StatusType.Successful
        )
        self.assertEqual(
            get_task_status_type_by_name(pipeline, "sleep"), StatusType.Successful
        )
        self.assertEqual(
            get_task_status_type_by_name(pipeline, "templated"), StatusType.Successful
        )

    def test_pipeline_lineage(self) -> None:
        """
        Validate that the pipeline has proper lineage
        """
        root_name = "test-service-table-lineage.test-db.test-schema"

        # Check that both inlets have the same outlet
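        # and that the downstream edge carries the pipeline as lineage details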
        for inlet_table in [
            f"{root_name}.lineage-test-inlet",
            f"{root_name}.lineage-test-inlet2",
        ]:
            lineage = self.metadata.get_lineage_by_name(
                entity=Table,
                fqn=inlet_table,
            )

            node_names = {node["name"] for node in lineage.get("nodes") or []}
            self.assertEqual(node_names, {"lineage-test-outlet"})

            self.assertEqual(len(lineage.get("downstreamEdges")), 1)
            self.assertEqual(
                lineage["downstreamEdges"][0]["toEntity"],
                str(self.table_outlet.id.root),
            )
            self.assertEqual(
                lineage["downstreamEdges"][0]["lineageDetails"]["pipeline"][
                    "fullyQualifiedName"
                ],
                f"{PIPELINE_SERVICE_NAME}.{OM_LINEAGE_DAG_NAME}",
            )