ci: separate airflow build and test (#8688)

Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
Authored by Mayuri Nehate on 2023-08-31 02:38:42 +05:30; committed by GitHub
parent 1282e5bf93
commit e867dbc3da
52 changed files with 2037 additions and 1874 deletions

.github/workflows/airflow-plugin.yml

@ -0,0 +1,85 @@
name: Airflow Plugin
on:
push:
branches:
- master
paths:
- ".github/workflows/airflow-plugin.yml"
- "metadata-ingestion-modules/airflow-plugin/**"
- "metadata-ingestion/**"
- "metadata-models/**"
pull_request:
branches:
- master
paths:
- ".github/**"
- "metadata-ingestion-modules/airflow-plugin/**"
- "metadata-ingestion/**"
- "metadata-models/**"
release:
types: [published]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
airflow-plugin:
runs-on: ubuntu-latest
env:
SPARK_VERSION: 3.0.3
DATAHUB_TELEMETRY_ENABLED: false
strategy:
matrix:
include:
- python-version: "3.7"
extraPythonRequirement: "apache-airflow~=2.1.0"
- python-version: "3.7"
extraPythonRequirement: "apache-airflow~=2.2.0"
- python-version: "3.10"
extraPythonRequirement: "apache-airflow~=2.4.0"
- python-version: "3.10"
extraPythonRequirement: "apache-airflow~=2.6.0"
- python-version: "3.10"
extraPythonRequirement: "apache-airflow>2.6.0"
fail-fast: false
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Install dependencies
run: ./metadata-ingestion/scripts/install_deps.sh
- name: Install airflow package and test (extras ${{ matrix.extraPythonRequirement }})
run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion-modules:airflow-plugin:lint :metadata-ingestion-modules:airflow-plugin:testQuick
- name: pip freeze show list installed
if: always()
run: source metadata-ingestion-modules/airflow-plugin/venv/bin/activate && pip freeze
- uses: actions/upload-artifact@v3
if: ${{ always() && matrix.python-version == '3.10' && matrix.extraPythonRequirement == 'apache-airflow>2.6.0' }}
with:
name: Test Results (Airflow Plugin ${{ matrix.python-version }})
path: |
**/build/reports/tests/test/**
**/build/test-results/test/**
**/junit.*.xml
- name: Upload coverage to Codecov
if: always()
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
directory: .
fail_ci_if_error: false
flags: airflow-${{ matrix.python-version }}-${{ matrix.extraPythonRequirement }}
name: pytest-airflow
verbose: true
event-file:
runs-on: ubuntu-latest
steps:
- name: Upload
uses: actions/upload-artifact@v3
with:
name: Event File
path: ${{ github.event_path }}


@ -42,9 +42,7 @@ jobs:
]
include:
- python-version: "3.7"
extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0"
- python-version: "3.10"
extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow>=2.4.0"
fail-fast: false
steps:
- uses: actions/checkout@v3
@ -56,8 +54,8 @@ jobs:
run: ./metadata-ingestion/scripts/install_deps.sh
- name: Install package
run: ./gradlew :metadata-ingestion:installPackageOnly
- name: Run metadata-ingestion tests (extras ${{ matrix.extraPythonRequirement }})
run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion:${{ matrix.command }}
- name: Run metadata-ingestion tests
run: ./gradlew :metadata-ingestion:${{ matrix.command }}
- name: pip freeze show list installed
if: always()
run: source metadata-ingestion/venv/bin/activate && pip freeze
@ -80,7 +78,6 @@ jobs:
name: pytest-${{ matrix.command }}
verbose: true
event-file:
runs-on: ubuntu-latest
steps:


@ -2,7 +2,7 @@ name: Test Results
on:
workflow_run:
workflows: ["build & test", "metadata ingestion"]
workflows: ["build & test", "metadata ingestion", "Airflow Plugin"]
types:
- completed


@ -65,7 +65,7 @@ lazy_load_plugins = False
| datahub.capture_executions | true | If true, we'll capture task runs in DataHub in addition to DAG definitions. |
| datahub.graceful_exceptions | true | If set to true, most runtime errors in the lineage backend will be suppressed and will not cause the overall task to fail. Note that configuration issues will still throw exceptions. |
5. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html).
5. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html).
6. [optional] Learn more about [Airflow lineage](https://airflow.apache.org/docs/apache-airflow/stable/lineage.html), including shorthand notation and some automation.
### How to validate installation
@ -160,14 +160,14 @@ pip install acryl-datahub[airflow,datahub-kafka]
- `capture_executions` (defaults to false): If true, it captures task runs as DataHub DataProcessInstances.
- `graceful_exceptions` (defaults to true): If set to true, most runtime errors in the lineage backend will be suppressed and will not cause the overall task to fail. Note that configuration issues will still throw exceptions.
4. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html).
4. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html).
5. [optional] Learn more about [Airflow lineage](https://airflow.apache.org/docs/apache-airflow/stable/lineage.html), including shorthand notation and some automation.
## Emitting lineage via a separate operator
Take a look at this sample DAG:
- [`lineage_emission_dag.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_emission_dag.py) - emits lineage using the DatahubEmitterOperator.
- [`lineage_emission_dag.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_emission_dag.py) - emits lineage using the DatahubEmitterOperator.
In order to use this example, you must first configure the Datahub hook. Like in ingestion, we support a Datahub REST hook and a Kafka-based hook. See step 1 above for details.
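As a quick orientation for this doc change, here is a minimal sketch of the `inlets`/`outlets` configuration the updated docs point at, using the post-move import path. The DAG id and dataset names below are hypothetical; the canonical versions live in the `example_dags` files referenced above.

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

# Post-move import path; before this change it was datahub_provider.entities.
from datahub_airflow_plugin.entities import Dataset, Urn

with DAG(
    dag_id="datahub_lineage_sketch",  # hypothetical DAG id
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    BashOperator(
        task_id="run_data_task",
        bash_command="echo 'transforming data'",
        # Upstream datasets this task reads (hypothetical names).
        inlets=[Dataset("snowflake", "mydb.schema.tableA")],
        # Downstream datasets this task writes; a raw dataset URN also works.
        outlets=[
            Dataset("snowflake", "mydb.schema.tableD"),
            Urn("urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableE,PROD)"),
        ],
    )
```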


@ -7,6 +7,10 @@ ext {
venv_name = 'venv'
}
if (!project.hasProperty("extra_pip_requirements")) {
ext.extra_pip_requirements = ""
}
def pip_install_command = "${venv_name}/bin/pip install -e ../../metadata-ingestion"
task checkPythonVersion(type: Exec) {
@ -14,30 +18,37 @@ task checkPythonVersion(type: Exec) {
}
task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
def sentinel_file = "${venv_name}/.venv_environment_sentinel"
inputs.file file('setup.py')
outputs.dir("${venv_name}")
commandLine 'bash', '-c', "${python_executable} -m venv ${venv_name} && ${venv_name}/bin/python -m pip install --upgrade pip wheel 'setuptools>=63.0.0'"
outputs.file(sentinel_file)
commandLine 'bash', '-c',
"${python_executable} -m venv ${venv_name} &&" +
"${venv_name}/bin/python -m pip install --upgrade pip wheel 'setuptools>=63.0.0' && " +
"touch ${sentinel_file}"
}
task installPackage(type: Exec, dependsOn: environmentSetup) {
task installPackage(type: Exec, dependsOn: [environmentSetup, ':metadata-ingestion:codegen']) {
def sentinel_file = "${venv_name}/.build_install_package_sentinel"
inputs.file file('setup.py')
outputs.dir("${venv_name}")
outputs.file(sentinel_file)
// Workaround for https://github.com/yaml/pyyaml/issues/601.
// See https://github.com/yaml/pyyaml/issues/601#issuecomment-1638509577.
// and https://github.com/datahub-project/datahub/pull/8435.
commandLine 'bash', '-x', '-c',
"${pip_install_command} install 'Cython<3.0' 'PyYAML<6' --no-build-isolation && " +
"${pip_install_command} -e ."
"${pip_install_command} -e . ${extra_pip_requirements} &&" +
"touch ${sentinel_file}"
}
task install(dependsOn: [installPackage])
task installDev(type: Exec, dependsOn: [install]) {
def sentinel_file = "${venv_name}/.build_install_dev_sentinel"
inputs.file file('setup.py')
outputs.dir("${venv_name}")
outputs.file("${venv_name}/.build_install_dev_sentinel")
outputs.file("${sentinel_file}")
commandLine 'bash', '-x', '-c',
"${pip_install_command} -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
"${pip_install_command} -e .[dev] ${extra_pip_requirements} && " +
"touch ${sentinel_file}"
}
task lint(type: Exec, dependsOn: installDev) {
@ -45,9 +56,13 @@ task lint(type: Exec, dependsOn: installDev) {
The find/sed combo below is a temporary work-around for the following mypy issue with airflow 2.2.0:
"venv/lib/python3.8/site-packages/airflow/_vendor/connexion/spec.py:169: error: invalid syntax".
*/
commandLine 'bash', '-x', '-c',
commandLine 'bash', '-c',
"find ${venv_name}/lib -path *airflow/_vendor/connexion/spec.py -exec sed -i.bak -e '169,169s/ # type: List\\[str\\]//g' {} \\; && " +
"source ${venv_name}/bin/activate && black --check --diff src/ tests/ && isort --check --diff src/ tests/ && flake8 --count --statistics src/ tests/ && mypy src/ tests/"
"source ${venv_name}/bin/activate && set -x && " +
"black --check --diff src/ tests/ && " +
"isort --check --diff src/ tests/ && " +
"flake8 --count --statistics src/ tests/ && " +
"mypy --show-traceback --show-error-codes src/ tests/"
}
task lintFix(type: Exec, dependsOn: installDev) {
commandLine 'bash', '-x', '-c',
@ -58,21 +73,13 @@ task lintFix(type: Exec, dependsOn: installDev) {
"mypy src/ tests/ "
}
task testQuick(type: Exec, dependsOn: installDev) {
// We can't enforce the coverage requirements if we run a subset of the tests.
inputs.files(project.fileTree(dir: "src/", include: "**/*.py"))
inputs.files(project.fileTree(dir: "tests/"))
outputs.dir("${venv_name}")
commandLine 'bash', '-x', '-c',
"source ${venv_name}/bin/activate && pytest -vv --continue-on-collection-errors --junit-xml=junit.quick.xml"
}
task installDevTest(type: Exec, dependsOn: [installDev]) {
def sentinel_file = "${venv_name}/.build_install_dev_test_sentinel"
inputs.file file('setup.py')
outputs.dir("${venv_name}")
outputs.file("${venv_name}/.build_install_dev_test_sentinel")
outputs.file("${sentinel_file}")
commandLine 'bash', '-x', '-c',
"${pip_install_command} -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
"${pip_install_command} -e .[dev,integration-tests] && touch ${sentinel_file}"
}
def testFile = hasProperty('testFile') ? testFile : 'unknown'
@ -89,6 +96,16 @@ task testSingle(dependsOn: [installDevTest]) {
}
}
task testQuick(type: Exec, dependsOn: installDevTest) {
// We can't enforce the coverage requirements if we run a subset of the tests.
inputs.files(project.fileTree(dir: "src/", include: "**/*.py"))
inputs.files(project.fileTree(dir: "tests/"))
outputs.dir("${venv_name}")
commandLine 'bash', '-x', '-c',
"source ${venv_name}/bin/activate && pytest -vv --continue-on-collection-errors --junit-xml=junit.quick.xml"
}
task testFull(type: Exec, dependsOn: [testQuick, installDevTest]) {
commandLine 'bash', '-x', '-c',
"source ${venv_name}/bin/activate && pytest -m 'not slow_integration' -vv --continue-on-collection-errors --junit-xml=junit.full.xml"


@ -9,7 +9,6 @@ extend-exclude = '''
^/tmp
'''
include = '\.pyi?$'
target-version = ['py36', 'py37', 'py38']
[tool.isort]
indent = ' '


@ -69,4 +69,6 @@ exclude_lines =
pragma: no cover
@abstract
if TYPE_CHECKING:
#omit =
omit =
# omit example dags
src/datahub_airflow_plugin/example_dags/*


@ -13,16 +13,21 @@ def get_long_description():
return pathlib.Path(os.path.join(root, "README.md")).read_text()
rest_common = {"requests", "requests_file"}
base_requirements = {
# Compatibility.
"dataclasses>=0.6; python_version < '3.7'",
"typing_extensions>=3.10.0.2",
# typing_extensions should ideally be >=3.10.0.2, but we can't enforce that due to a dependency conflict with Airflow 2.0.2
"typing_extensions>=3.7.4.3 ; python_version < '3.8'",
"typing_extensions>=3.10.0.2,<4.6.0 ; python_version >= '3.8'",
"mypy_extensions>=0.4.3",
# Actual dependencies.
"typing-inspect",
"pydantic>=1.5.1",
"apache-airflow >= 2.0.2",
f"acryl-datahub[airflow] == {package_metadata['__version__']}",
*rest_common,
f"acryl-datahub == {package_metadata['__version__']}",
}
@ -47,19 +52,18 @@ mypy_stubs = {
base_dev_requirements = {
*base_requirements,
*mypy_stubs,
"black>=21.12b0",
"black==22.12.0",
"coverage>=5.1",
"flake8>=3.8.3",
"flake8-tidy-imports>=4.3.0",
"isort>=5.7.0",
"mypy>=0.920",
"mypy>=1.4.0",
# pydantic 1.8.2 is incompatible with mypy 0.910.
# See https://github.com/samuelcolvin/pydantic/pull/3175#issuecomment-995382910.
"pydantic>=1.9.0",
"pydantic>=1.10",
"pytest>=6.2.2",
"pytest-asyncio>=0.16.0",
"pytest-cov>=2.8.1",
"pytest-docker>=0.10.3,<0.12",
"tox",
"deepdiff",
"requests-mock",
@ -127,5 +131,13 @@ setuptools.setup(
"datahub-kafka": [
f"acryl-datahub[datahub-kafka] == {package_metadata['__version__']}"
],
"integration-tests": [
f"acryl-datahub[datahub-kafka] == {package_metadata['__version__']}",
# Extra requirements for Airflow.
"apache-airflow[snowflake]>=2.0.2", # snowflake is used in example dags
# Because of https://github.com/snowflakedb/snowflake-sqlalchemy/issues/350 we need to restrict SQLAlchemy's max version.
"SQLAlchemy<1.4.42",
"virtualenv", # needed by PythonVirtualenvOperator
],
},
)


@ -0,0 +1,12 @@
# This module must be imported before any Airflow imports in any of our files.
# The AIRFLOW_PATCHED just helps avoid flake8 errors.
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
assert MARKUPSAFE_PATCHED
AIRFLOW_PATCHED = True
__all__ = [
"AIRFLOW_PATCHED",
]
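To make the comment above concrete, a module elsewhere in the plugin that needs Airflow would follow roughly this pattern. The module below is an illustrative sketch, not part of this diff.

```python
# Illustrative sketch of the intended import ordering.
from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED

# Airflow is only imported after the compat shim has applied its patches.
from airflow.models.baseoperator import BaseOperator

# Asserting the flag gives the compat import a use, which keeps flake8
# from reporting it as an unused import.
assert AIRFLOW_PATCHED


def is_plain_operator(obj: object) -> bool:
    """Tiny helper so the sketch has a reason to import BaseOperator."""
    return isinstance(obj, BaseOperator)
```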


@ -0,0 +1,29 @@
from airflow.models.baseoperator import BaseOperator
from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED
try:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.operators.empty import EmptyOperator
except ModuleNotFoundError:
# Operator isn't a real class, but rather a type alias defined
# as the union of BaseOperator and MappedOperator.
# Since older versions of Airflow don't have MappedOperator, we can just use BaseOperator.
Operator = BaseOperator # type: ignore
MappedOperator = None # type: ignore
from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore
try:
from airflow.sensors.external_task import ExternalTaskSensor
except ImportError:
from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore
assert AIRFLOW_PATCHED
__all__ = [
"Operator",
"MappedOperator",
"EmptyOperator",
"ExternalTaskSensor",
]


@ -0,0 +1,115 @@
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List
import datahub.emitter.mce_builder as builder
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub.configuration.common import ConfigModel
from datahub.utilities.urns.dataset_urn import DatasetUrn
from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator
from datahub_airflow_plugin.entities import _Entity
if TYPE_CHECKING:
from airflow import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from datahub_airflow_plugin._airflow_shims import Operator
from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook
def _entities_to_urn_list(iolets: List[_Entity]) -> List[DatasetUrn]:
return [DatasetUrn.create_from_string(let.urn) for let in iolets]
class DatahubBasicLineageConfig(ConfigModel):
enabled: bool = True
# DataHub hook connection ID.
datahub_conn_id: str
# Cluster to associate with the pipelines and tasks. Defaults to "prod".
cluster: str = builder.DEFAULT_FLOW_CLUSTER
# If true, the owners field of the DAG will be captured as a DataHub corpuser.
capture_ownership_info: bool = True
# If true, the tags field of the DAG will be captured as DataHub tags.
capture_tags_info: bool = True
capture_executions: bool = False
def make_emitter_hook(self) -> "DatahubGenericHook":
# This is necessary to avoid issues with circular imports.
from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook
return DatahubGenericHook(self.datahub_conn_id)
def send_lineage_to_datahub(
config: DatahubBasicLineageConfig,
operator: "Operator",
inlets: List[_Entity],
outlets: List[_Entity],
context: Dict,
) -> None:
if not config.enabled:
return
dag: "DAG" = context["dag"]
task: "Operator" = context["task"]
ti: "TaskInstance" = context["task_instance"]
hook = config.make_emitter_hook()
emitter = hook.make_emitter()
dataflow = AirflowGenerator.generate_dataflow(
cluster=config.cluster,
dag=dag,
capture_tags=config.capture_tags_info,
capture_owner=config.capture_ownership_info,
)
dataflow.emit(emitter)
operator.log.info(f"Emitted from Lineage: {dataflow}")
datajob = AirflowGenerator.generate_datajob(
cluster=config.cluster,
task=task,
dag=dag,
capture_tags=config.capture_tags_info,
capture_owner=config.capture_ownership_info,
)
datajob.inlets.extend(_entities_to_urn_list(inlets))
datajob.outlets.extend(_entities_to_urn_list(outlets))
datajob.emit(emitter)
operator.log.info(f"Emitted from Lineage: {datajob}")
if config.capture_executions:
dag_run: "DagRun" = context["dag_run"]
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=config.cluster,
ti=ti,
dag=dag,
dag_run=dag_run,
datajob=datajob,
emit_templates=False,
)
operator.log.info(f"Emitted from Lineage: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=config.cluster,
ti=ti,
dag=dag,
dag_run=dag_run,
datajob=datajob,
result=InstanceRunResult.SUCCESS,
end_timestamp_millis=int(datetime.utcnow().timestamp() * 1000),
)
operator.log.info(f"Emitted from Lineage: {dpi}")
emitter.flush()


@ -0,0 +1,512 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
from airflow.configuration import conf
from datahub.api.entities.datajob import DataFlow, DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import (
DataProcessInstance,
InstanceRunResult,
)
from datahub.metadata.schema_classes import DataProcessTypeClass
from datahub.utilities.urns.data_flow_urn import DataFlowUrn
from datahub.utilities.urns.data_job_urn import DataJobUrn
from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED
assert AIRFLOW_PATCHED
if TYPE_CHECKING:
from airflow import DAG
from airflow.models import DagRun, TaskInstance
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub_airflow_plugin._airflow_shims import Operator
def _task_downstream_task_ids(operator: "Operator") -> Set[str]:
if hasattr(operator, "downstream_task_ids"):
return operator.downstream_task_ids
return operator._downstream_task_id # type: ignore[attr-defined,union-attr]
class AirflowGenerator:
@staticmethod
def _get_dependencies(
task: "Operator", dag: "DAG", flow_urn: DataFlowUrn
) -> List[DataJobUrn]:
from datahub_airflow_plugin._airflow_shims import ExternalTaskSensor
# resolve URNs for upstream nodes in subdags upstream of the current task.
upstream_subdag_task_urns: List[DataJobUrn] = []
for upstream_task_id in task.upstream_task_ids:
upstream_task = dag.task_dict[upstream_task_id]
# if upstream task is not a subdag, then skip it
upstream_subdag = getattr(upstream_task, "subdag", None)
if upstream_subdag is None:
continue
# else, link the leaf tasks of the upstream subdag as upstream tasks
for upstream_subdag_task_id in upstream_subdag.task_dict:
upstream_subdag_task = upstream_subdag.task_dict[
upstream_subdag_task_id
]
upstream_subdag_task_urn = DataJobUrn.create_from_ids(
job_id=upstream_subdag_task_id, data_flow_urn=str(flow_urn)
)
# if subdag task is a leaf task, then link it as an upstream task
if len(_task_downstream_task_ids(upstream_subdag_task)) == 0:
upstream_subdag_task_urns.append(upstream_subdag_task_urn)
# resolve URNs for upstream nodes that trigger the subdag containing the current task.
# (if it is in a subdag at all)
upstream_subdag_triggers: List[DataJobUrn] = []
# subdags are always named with 'parent.child' style or Airflow won't run them
# add connection from subdag trigger(s) if subdag task has no upstreams
if (
dag.is_subdag
and dag.parent_dag is not None
and len(task.upstream_task_ids) == 0
):
# filter through the parent dag's tasks and find the subdag trigger(s)
subdags = [
x for x in dag.parent_dag.task_dict.values() if x.subdag is not None
]
matched_subdags = [
x for x in subdags if x.subdag and x.subdag.dag_id == dag.dag_id
]
# id of the task containing the subdag
subdag_task_id = matched_subdags[0].task_id
# iterate through the parent dag's tasks and find the ones that trigger the subdag
for upstream_task_id in dag.parent_dag.task_dict:
upstream_task = dag.parent_dag.task_dict[upstream_task_id]
upstream_task_urn = DataJobUrn.create_from_ids(
data_flow_urn=str(flow_urn), job_id=upstream_task_id
)
# if the task triggers the subdag, link it to this node in the subdag
if subdag_task_id in _task_downstream_task_ids(upstream_task):
upstream_subdag_triggers.append(upstream_task_urn)
# If the operator is an ExternalTaskSensor then we set the remote task as upstream.
# It is possible to tie an external sensor to a DAG if external_task_id is omitted, but currently we can't tie
# one jobflow to another jobflow.
external_task_upstreams = []
if task.task_type == "ExternalTaskSensor":
task = cast(ExternalTaskSensor, task)
if hasattr(task, "external_task_id") and task.external_task_id is not None:
external_task_upstreams = [
DataJobUrn.create_from_ids(
job_id=task.external_task_id,
data_flow_urn=str(
DataFlowUrn.create_from_ids(
orchestrator=flow_urn.get_orchestrator_name(),
flow_id=task.external_dag_id,
env=flow_urn.get_env(),
)
),
)
]
# exclude subdag operator tasks since these are not emitted, resulting in empty metadata
upstream_tasks = (
[
DataJobUrn.create_from_ids(job_id=task_id, data_flow_urn=str(flow_urn))
for task_id in task.upstream_task_ids
if getattr(dag.task_dict[task_id], "subdag", None) is None
]
+ upstream_subdag_task_urns
+ upstream_subdag_triggers
+ external_task_upstreams
)
return upstream_tasks
@staticmethod
def generate_dataflow(
cluster: str,
dag: "DAG",
capture_owner: bool = True,
capture_tags: bool = True,
) -> DataFlow:
"""
Generates a DataFlow object from an Airflow DAG
:param cluster: str - name of the cluster
:param dag: DAG - the Airflow DAG to generate the DataFlow from
:param capture_tags: bool - whether to capture the DAG's tags
:param capture_owner: bool - whether to capture the DAG's owner
:return: DataFlow - the generated DataFlow object
"""
id = dag.dag_id
orchestrator = "airflow"
description = f"{dag.description}\n\n{dag.doc_md or ''}"
data_flow = DataFlow(
env=cluster, id=id, orchestrator=orchestrator, description=description
)
flow_property_bag: Dict[str, str] = {}
allowed_flow_keys = [
"_access_control",
"_concurrency",
"_default_view",
"catchup",
"fileloc",
"is_paused_upon_creation",
"start_date",
"tags",
"timezone",
]
for key in allowed_flow_keys:
if hasattr(dag, key):
flow_property_bag[key] = repr(getattr(dag, key))
data_flow.properties = flow_property_bag
base_url = conf.get("webserver", "base_url")
data_flow.url = f"{base_url}/tree?dag_id={dag.dag_id}"
if capture_owner and dag.owner:
data_flow.owners.add(dag.owner)
if capture_tags and dag.tags:
data_flow.tags.update(dag.tags)
return data_flow
@staticmethod
def _get_description(task: "Operator") -> Optional[str]:
from airflow.models.baseoperator import BaseOperator
if not isinstance(task, BaseOperator):
# TODO: Get docs for mapped operators.
return None
if hasattr(task, "doc") and task.doc:
return task.doc
elif hasattr(task, "doc_md") and task.doc_md:
return task.doc_md
elif hasattr(task, "doc_json") and task.doc_json:
return task.doc_json
elif hasattr(task, "doc_yaml") and task.doc_yaml:
return task.doc_yaml
elif hasattr(task, "doc_rst") and task.doc_yaml:
return task.doc_yaml
return None
@staticmethod
def generate_datajob(
cluster: str,
task: "Operator",
dag: "DAG",
set_dependencies: bool = True,
capture_owner: bool = True,
capture_tags: bool = True,
) -> DataJob:
"""
:param cluster: str
:param task: Operator - the Airflow task to generate the DataJob from
:param dag: DAG
:param set_dependencies: bool - whether to extract dependencies from airflow task
:param capture_owner: bool - whether to extract owner from airflow task
:param capture_tags: bool - whether to set tags automatically from airflow task
:return: DataJob - returns the generated DataJob object
"""
dataflow_urn = DataFlowUrn.create_from_ids(
orchestrator="airflow", env=cluster, flow_id=dag.dag_id
)
datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
# TODO add support for MappedOperator
datajob.description = AirflowGenerator._get_description(task)
job_property_bag: Dict[str, str] = {}
allowed_task_keys = [
"_downstream_task_ids",
"_inlets",
"_outlets",
"_task_type",
"_task_module",
"depends_on_past",
"email",
"label",
"execution_timeout",
"sla",
"sql",
"task_id",
"trigger_rule",
"wait_for_downstream",
# In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids
"downstream_task_ids",
# In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions.
"inlets",
"outlets",
]
for key in allowed_task_keys:
if hasattr(task, key):
job_property_bag[key] = repr(getattr(task, key))
datajob.properties = job_property_bag
base_url = conf.get("webserver", "base_url")
datajob.url = f"{base_url}/taskinstance/list/?flt1_dag_id_equals={datajob.flow_urn.get_flow_id()}&_flt_3_task_id={task.task_id}"
if capture_owner and dag.owner:
datajob.owners.add(dag.owner)
if capture_tags and dag.tags:
datajob.tags.update(dag.tags)
if set_dependencies:
datajob.upstream_urns.extend(
AirflowGenerator._get_dependencies(
task=task, dag=dag, flow_urn=datajob.flow_urn
)
)
return datajob
@staticmethod
def create_datajob_instance(
cluster: str,
task: "Operator",
dag: "DAG",
data_job: Optional[DataJob] = None,
) -> DataProcessInstance:
if data_job is None:
data_job = AirflowGenerator.generate_datajob(cluster, task=task, dag=dag)
dpi = DataProcessInstance.from_datajob(
datajob=data_job, id=task.task_id, clone_inlets=True, clone_outlets=True
)
return dpi
@staticmethod
def run_dataflow(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
dag_run: "DagRun",
start_timestamp_millis: Optional[int] = None,
dataflow: Optional[DataFlow] = None,
) -> None:
if dataflow is None:
assert dag_run.dag
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)
if start_timestamp_millis is None:
assert dag_run.execution_date
start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000)
assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
# This property only exists in Airflow2
if hasattr(dag_run, "run_type"):
from airflow.utils.types import DagRunType
if dag_run.run_type == DagRunType.SCHEDULED:
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
elif dag_run.run_type == DagRunType.MANUAL:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
else:
if dag_run.run_id.startswith("scheduled__"):
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
else:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
property_bag: Dict[str, str] = {}
property_bag["run_id"] = str(dag_run.run_id)
property_bag["execution_date"] = str(dag_run.execution_date)
property_bag["end_date"] = str(dag_run.end_date)
property_bag["start_date"] = str(dag_run.start_date)
property_bag["creating_job_id"] = str(dag_run.creating_job_id)
# These properties only exist in Airflow>=2.2.0
if hasattr(dag_run, "data_interval_start") and hasattr(
dag_run, "data_interval_end"
):
property_bag["data_interval_start"] = str(dag_run.data_interval_start)
property_bag["data_interval_end"] = str(dag_run.data_interval_end)
property_bag["external_trigger"] = str(dag_run.external_trigger)
dpi.properties.update(property_bag)
dpi.emit_process_start(
emitter=emitter, start_timestamp_millis=start_timestamp_millis
)
@staticmethod
def complete_dataflow(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
dag_run: "DagRun",
end_timestamp_millis: Optional[int] = None,
dataflow: Optional[DataFlow] = None,
) -> None:
"""
:param emitter: DatahubRestEmitter - the datahub rest emitter to emit the generated mcps
:param cluster: str - name of the cluster
:param dag_run: DagRun
:param end_timestamp_millis: Optional[int] - the completion time in milliseconds; if not set, the current time will be used.
:param dataflow: Optional[Dataflow]
"""
if dataflow is None:
assert dag_run.dag
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)
assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
if end_timestamp_millis is None:
if dag_run.end_date is None:
raise Exception(
f"Dag {dag_run.dag_id}_{dag_run.run_id} is still running and unable to get end_date..."
)
end_timestamp_millis = int(dag_run.end_date.timestamp() * 1000)
# We should use DagRunState but it is not available in Airflow 1
if dag_run.state == "success":
result = InstanceRunResult.SUCCESS
elif dag_run.state == "failed":
result = InstanceRunResult.FAILURE
else:
raise Exception(
f"Result should be either success or failure and it was {dag_run.state}"
)
dpi.emit_process_end(
emitter=emitter,
end_timestamp_millis=end_timestamp_millis,
result=result,
result_type="airflow",
)
@staticmethod
def run_datajob(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
ti: "TaskInstance",
dag: "DAG",
dag_run: "DagRun",
start_timestamp_millis: Optional[int] = None,
datajob: Optional[DataJob] = None,
attempt: Optional[int] = None,
emit_templates: bool = True,
) -> DataProcessInstance:
if datajob is None:
datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag)
assert dag_run.run_id
dpi = DataProcessInstance.from_datajob(
datajob=datajob,
id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}",
clone_inlets=True,
clone_outlets=True,
)
job_property_bag: Dict[str, str] = {}
job_property_bag["run_id"] = str(dag_run.run_id)
job_property_bag["duration"] = str(ti.duration)
job_property_bag["start_date"] = str(ti.start_date)
job_property_bag["end_date"] = str(ti.end_date)
job_property_bag["execution_date"] = str(ti.execution_date)
job_property_bag["try_number"] = str(ti.try_number - 1)
job_property_bag["hostname"] = str(ti.hostname)
job_property_bag["max_tries"] = str(ti.max_tries)
# Not compatible with Airflow 1
if hasattr(ti, "external_executor_id"):
job_property_bag["external_executor_id"] = str(ti.external_executor_id)
job_property_bag["pid"] = str(ti.pid)
job_property_bag["state"] = str(ti.state)
job_property_bag["operator"] = str(ti.operator)
job_property_bag["priority_weight"] = str(ti.priority_weight)
job_property_bag["unixname"] = str(ti.unixname)
job_property_bag["log_url"] = ti.log_url
dpi.properties.update(job_property_bag)
dpi.url = ti.log_url
# This property only exists in Airflow2
if hasattr(ti, "dag_run") and hasattr(ti.dag_run, "run_type"):
from airflow.utils.types import DagRunType
if ti.dag_run.run_type == DagRunType.SCHEDULED:
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
elif ti.dag_run.run_type == DagRunType.MANUAL:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
else:
if dag_run.run_id.startswith("scheduled__"):
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
else:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
if start_timestamp_millis is None:
assert ti.start_date
start_timestamp_millis = int(ti.start_date.timestamp() * 1000)
if attempt is None:
attempt = ti.try_number
dpi.emit_process_start(
emitter=emitter,
start_timestamp_millis=start_timestamp_millis,
attempt=attempt,
emit_template=emit_templates,
)
return dpi
@staticmethod
def complete_datajob(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
ti: "TaskInstance",
dag: "DAG",
dag_run: "DagRun",
end_timestamp_millis: Optional[int] = None,
result: Optional[InstanceRunResult] = None,
datajob: Optional[DataJob] = None,
) -> DataProcessInstance:
"""
:param emitter: DatahubRestEmitter
:param cluster: str
:param ti: TaskInstance
:param dag: DAG
:param dag_run: DagRun
:param end_timestamp_millis: Optional[int]
:param result: Optional[InstanceRunResult] - one of the values of datahub.metadata.schema_classes.RunResultTypeClass
:param datajob: Optional[DataJob]
:return: DataProcessInstance
"""
if datajob is None:
datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag)
if end_timestamp_millis is None:
assert ti.end_date
end_timestamp_millis = int(ti.end_date.timestamp() * 1000)
if result is None:
# We should use TaskInstanceState but it is not available in Airflow 1
if ti.state == "success":
result = InstanceRunResult.SUCCESS
elif ti.state == "failed":
result = InstanceRunResult.FAILURE
else:
raise Exception(
f"Result should be either success or failure and it was {ti.state}"
)
dpi = DataProcessInstance.from_datajob(
datajob=datajob,
id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}",
clone_inlets=True,
clone_outlets=True,
)
dpi.emit_process_end(
emitter=emitter,
end_timestamp_millis=end_timestamp_millis,
result=result,
result_type="airflow",
)
return dpi


@ -1,4 +1,367 @@
# This package serves as a shim, but the actual implementation lives in datahub_provider
# from the acryl-datahub package. We leave this shim here to avoid breaking existing
# Airflow installs.
from datahub_provider._plugin import DatahubPlugin # noqa: F401
import contextlib
import logging
import traceback
from typing import Any, Callable, Iterable, List, Optional, Union
from airflow.configuration import conf
from airflow.lineage import PIPELINE_OUTLETS
from airflow.models.baseoperator import BaseOperator
from airflow.plugins_manager import AirflowPlugin
from airflow.utils.module_loading import import_string
from cattr import structure
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED
from datahub_airflow_plugin._airflow_shims import MappedOperator, Operator
from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator
from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook
from datahub_airflow_plugin.lineage.datahub import DatahubLineageConfig
assert AIRFLOW_PATCHED
logger = logging.getLogger(__name__)
TASK_ON_FAILURE_CALLBACK = "on_failure_callback"
TASK_ON_SUCCESS_CALLBACK = "on_success_callback"
def get_lineage_config() -> DatahubLineageConfig:
"""Load the lineage config from airflow.cfg."""
enabled = conf.get("datahub", "enabled", fallback=True)
datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default")
cluster = conf.get("datahub", "cluster", fallback="prod")
graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True)
capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True)
capture_ownership_info = conf.get(
"datahub", "capture_ownership_info", fallback=True
)
capture_executions = conf.get("datahub", "capture_executions", fallback=True)
return DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=datahub_conn_id,
cluster=cluster,
graceful_exceptions=graceful_exceptions,
capture_ownership_info=capture_ownership_info,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
)
def _task_inlets(operator: "Operator") -> List:
# From Airflow 2.4, _inlets is dropped and inlets is used consistently. Earlier versions were not consistent, so we have to stick with _inlets where it exists.
if hasattr(operator, "_inlets"):
return operator._inlets # type: ignore[attr-defined, union-attr]
return operator.inlets
def _task_outlets(operator: "Operator") -> List:
# From Airflow 2.4, _outlets is dropped and outlets is used consistently. Earlier versions were not consistent, so we have to stick with _outlets where it exists.
# We have to use _outlets because outlets is empty in Airflow < 2.4.0.
if hasattr(operator, "_outlets"):
return operator._outlets # type: ignore[attr-defined, union-attr]
return operator.outlets
def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]:
# TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae
# in Airflow 2.4.
# TODO: ignore/handle airflow's dataset type in our lineage
inlets: List[Any] = []
task_inlets = _task_inlets(task)
# From Airflow 2.3 this should be AbstractOperator, but for compatibility reasons let's use BaseOperator
if isinstance(task_inlets, (str, BaseOperator)):
inlets = [
task_inlets,
]
if task_inlets and isinstance(task_inlets, list):
inlets = []
task_ids = (
{o for o in task_inlets if isinstance(o, str)}
.union(op.task_id for op in task_inlets if isinstance(op, BaseOperator))
.intersection(task.get_flat_relative_ids(upstream=True))
)
from airflow.lineage import AUTO
# pick up unique direct upstream task_ids if AUTO is specified
if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets:
print("Picking up unique direct upstream task_ids as AUTO is specified")
task_ids = task_ids.union(
task_ids.symmetric_difference(task.upstream_task_ids)
)
inlets = task.xcom_pull(
context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS
)
# re-instantiate the obtained inlets
inlets = [
structure(item["data"], import_string(item["type_name"]))
# _get_instance(structure(item, Metadata))
for sublist in inlets
if sublist
for item in sublist
]
for inlet in task_inlets:
if not isinstance(inlet, str):
inlets.append(inlet)
return inlets
def _make_emit_callback(
logger: logging.Logger,
) -> Callable[[Optional[Exception], str], None]:
def emit_callback(err: Optional[Exception], msg: str) -> None:
if err:
logger.error(f"Error sending metadata to datahub: {msg}", exc_info=err)
return emit_callback
def datahub_task_status_callback(context, status):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
dataflow = AirflowGenerator.generate_dataflow(
cluster=context["_datahub_config"].cluster,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
task.log.info(f"Emitting Datahub Dataflow: {dataflow}")
dataflow.emit(emitter, callback=_make_emit_callback(task.log))
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
task_outlets = _task_outlets(task)
for outlet in task_outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitting Datahub Datajob: {datajob}")
datajob.emit(emitter, callback=_make_emit_callback(task.log))
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag_run=context["dag_run"],
result=status,
dag=dag,
datajob=datajob,
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
)
task.log.info(f"Emitted Completed Data Process Instance: {dpi}")
emitter.flush()
def datahub_pre_execution(context):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
task.log.info("Running Datahub pre_execute method")
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=context["ti"].task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
task_outlets = _task_outlets(task)
for outlet in task_outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitting Datahub dataJob {datajob}")
datajob.emit(emitter, callback=_make_emit_callback(task.log))
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}")
emitter.flush()
def _wrap_pre_execution(pre_execution):
def custom_pre_execution(context):
config = get_lineage_config()
if config.enabled:
context["_datahub_config"] = config
datahub_pre_execution(context)
# Call original policy
if pre_execution:
pre_execution(context)
return custom_pre_execution
def _wrap_on_failure_callback(on_failure_callback):
def custom_on_failure_callback(context):
config = get_lineage_config()
if config.enabled:
context["_datahub_config"] = config
try:
datahub_task_status_callback(context, status=InstanceRunResult.FAILURE)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
# Call original policy
if on_failure_callback:
on_failure_callback(context)
return custom_on_failure_callback
def _wrap_on_success_callback(on_success_callback):
def custom_on_success_callback(context):
config = get_lineage_config()
if config.enabled:
context["_datahub_config"] = config
try:
datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
# Call original policy
if on_success_callback:
on_success_callback(context)
return custom_on_success_callback
def task_policy(task: Union[BaseOperator, MappedOperator]) -> None:
task.log.debug(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
# task.add_inlets(["auto"])
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
# MappedOperator's callbacks don't have setters until Airflow 2.X.X
# https://github.com/apache/airflow/issues/24547
# We can bypass this by going through partial_kwargs for now
if MappedOperator and isinstance(task, MappedOperator): # type: ignore
on_failure_callback_prop: property = getattr(
MappedOperator, TASK_ON_FAILURE_CALLBACK
)
on_success_callback_prop: property = getattr(
MappedOperator, TASK_ON_SUCCESS_CALLBACK
)
if not on_failure_callback_prop.fset or not on_success_callback_prop.fset:
task.log.debug(
"Using MappedOperator's partial_kwargs instead of callback properties"
)
task.partial_kwargs[TASK_ON_FAILURE_CALLBACK] = _wrap_on_failure_callback(
task.on_failure_callback
)
task.partial_kwargs[TASK_ON_SUCCESS_CALLBACK] = _wrap_on_success_callback(
task.on_success_callback
)
return
task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback) # type: ignore
task.on_success_callback = _wrap_on_success_callback(task.on_success_callback) # type: ignore
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
def _wrap_task_policy(policy):
if policy and hasattr(policy, "_task_policy_patched_by"):
return policy
def custom_task_policy(task):
policy(task)
task_policy(task)
# Add a flag to the policy to indicate that we've patched it.
custom_task_policy._task_policy_patched_by = "datahub_plugin" # type: ignore[attr-defined]
return custom_task_policy
def _patch_policy(settings):
if hasattr(settings, "task_policy"):
datahub_task_policy = _wrap_task_policy(settings.task_policy)
settings.task_policy = datahub_task_policy
def _patch_datahub_policy():
with contextlib.suppress(ImportError):
import airflow_local_settings
_patch_policy(airflow_local_settings)
from airflow.models.dagbag import settings
_patch_policy(settings)
_patch_datahub_policy()
class DatahubPlugin(AirflowPlugin):
name = "datahub_plugin"


@ -0,0 +1,47 @@
from abc import abstractmethod
from typing import Optional
import attr
import datahub.emitter.mce_builder as builder
from datahub.utilities.urns.urn import guess_entity_type
class _Entity:
@property
@abstractmethod
def urn(self) -> str:
pass
@attr.s(auto_attribs=True, str=True)
class Dataset(_Entity):
platform: str
name: str
env: str = builder.DEFAULT_ENV
platform_instance: Optional[str] = None
@property
def urn(self):
return builder.make_dataset_urn_with_platform_instance(
platform=self.platform,
name=self.name,
platform_instance=self.platform_instance,
env=self.env,
)
@attr.s(str=True)
class Urn(_Entity):
_urn: str = attr.ib()
@_urn.validator
def _validate_urn(self, attribute, value):
if not value.startswith("urn:"):
raise ValueError("invalid urn provided: urns must start with 'urn:'")
if guess_entity_type(value) != "dataset":
# This is because DataJobs only support Dataset lineage.
raise ValueError("Airflow lineage currently only supports datasets")
@property
def urn(self):
return self._urn


@ -9,7 +9,6 @@ from datetime import timedelta
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from datahub.configuration.config_loader import load_config_file
from datahub.ingestion.run.pipeline import Pipeline
@ -41,6 +40,7 @@ with DAG(
schedule_interval=timedelta(days=1),
start_date=days_ago(2),
catchup=False,
default_view="tree",
) as dag:
ingest_task = PythonOperator(
task_id="ingest_using_recipe",


@ -9,7 +9,7 @@ from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago
from datahub_provider.entities import Dataset, Urn
from datahub_airflow_plugin.entities import Dataset, Urn
default_args = {
"owner": "airflow",
@ -28,6 +28,7 @@ with DAG(
start_date=days_ago(2),
tags=["example_tag"],
catchup=False,
default_view="tree",
) as dag:
task1 = BashOperator(
task_id="run_data_task",


@ -8,7 +8,7 @@ from datetime import timedelta
from airflow.decorators import dag, task
from airflow.utils.dates import days_ago
from datahub_provider.entities import Dataset, Urn
from datahub_airflow_plugin.entities import Dataset, Urn
default_args = {
"owner": "airflow",
@ -26,6 +26,7 @@ default_args = {
start_date=days_ago(2),
tags=["example_tag"],
catchup=False,
default_view="tree",
)
def datahub_lineage_backend_taskflow_demo():
@task(


@ -5,12 +5,12 @@ This example demonstrates how to emit lineage to DataHub within an Airflow DAG.
from datetime import timedelta
import datahub.emitter.mce_builder as builder
from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
from airflow.utils.dates import days_ago
import datahub.emitter.mce_builder as builder
from datahub_provider.operators.datahub import DatahubEmitterOperator
from datahub_airflow_plugin.operators.datahub import DatahubEmitterOperator
default_args = {
"owner": "airflow",
@ -31,6 +31,7 @@ with DAG(
schedule_interval=timedelta(days=1),
start_date=days_ago(2),
catchup=False,
default_view="tree",
) as dag:
# This example shows a SnowflakeOperator followed by a lineage emission. However, the
# same DatahubEmitterOperator can be used to emit lineage in any context.


@ -47,6 +47,7 @@ with DAG(
start_date=datetime(2022, 1, 1),
schedule_interval=timedelta(days=1),
catchup=False,
default_view="tree",
) as dag:
# While it is also possible to use the PythonOperator, we recommend using
# the PythonVirtualenvOperator to ensure that there are no dependency


@ -57,6 +57,7 @@ with DAG(
start_date=datetime(2022, 1, 1),
schedule_interval=timedelta(days=1),
catchup=False,
default_view="tree",
) as dag:
# This example pulls credentials from Airflow's connection store.
# For this to work, you must have previously configured these connections in Airflow.


@ -0,0 +1,214 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
MetadataChangeEvent,
MetadataChangeProposal,
)
if TYPE_CHECKING:
from airflow.models.connection import Connection
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig
class DatahubRestHook(BaseHook):
"""
Creates a DataHub Rest API connection used to send metadata to DataHub.
Takes the endpoint for your DataHub Rest API in the Server Endpoint (host) field.
URI example: ::
AIRFLOW_CONN_DATAHUB_REST_DEFAULT='datahub-rest://rest-endpoint'
:param datahub_rest_conn_id: Reference to the DataHub Rest connection.
:type datahub_rest_conn_id: str
"""
conn_name_attr = "datahub_rest_conn_id"
default_conn_name = "datahub_rest_default"
conn_type = "datahub_rest"
hook_name = "DataHub REST Server"
def __init__(self, datahub_rest_conn_id: str = default_conn_name) -> None:
super().__init__()
self.datahub_rest_conn_id = datahub_rest_conn_id
@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
return {}
@staticmethod
def get_ui_field_behaviour() -> Dict:
"""Returns custom field behavior"""
return {
"hidden_fields": ["port", "schema", "login"],
"relabeling": {
"host": "Server Endpoint",
},
}
def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]:
conn: "Connection" = self.get_connection(self.datahub_rest_conn_id)
host = conn.host
if not host:
raise AirflowException("host parameter is required")
if conn.port:
if ":" in host:
raise AirflowException(
"host parameter should not contain a port number if the port is specified separately"
)
host = f"{host}:{conn.port}"
password = conn.password
timeout_sec = conn.extra_dejson.get("timeout_sec")
return (host, password, timeout_sec)
def make_emitter(self) -> "DatahubRestEmitter":
import datahub.emitter.rest_emitter
return datahub.emitter.rest_emitter.DatahubRestEmitter(*self._get_config())
def emit_mces(self, mces: List[MetadataChangeEvent]) -> None:
emitter = self.make_emitter()
for mce in mces:
emitter.emit_mce(mce)
def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None:
emitter = self.make_emitter()
for mce in mcps:
emitter.emit_mcp(mce)
class DatahubKafkaHook(BaseHook):
"""
Creates a DataHub Kafka connection used to send metadata to DataHub.
Takes your Kafka broker in the Kafka Broker (host) field.
URI example: ::
AIRFLOW_CONN_DATAHUB_KAFKA_DEFAULT='datahub-kafka://kafka-broker'
:param datahub_kafka_conn_id: Reference to the DataHub Kafka connection.
:type datahub_kafka_conn_id: str
"""
conn_name_attr = "datahub_kafka_conn_id"
default_conn_name = "datahub_kafka_default"
conn_type = "datahub_kafka"
hook_name = "DataHub Kafka Sink"
def __init__(self, datahub_kafka_conn_id: str = default_conn_name) -> None:
super().__init__()
self.datahub_kafka_conn_id = datahub_kafka_conn_id
@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
return {}
@staticmethod
def get_ui_field_behaviour() -> Dict:
"""Returns custom field behavior"""
return {
"hidden_fields": ["port", "schema", "login", "password"],
"relabeling": {
"host": "Kafka Broker",
},
}
def _get_config(self) -> "KafkaSinkConfig":
import datahub.ingestion.sink.datahub_kafka
conn = self.get_connection(self.datahub_kafka_conn_id)
obj = conn.extra_dejson
obj.setdefault("connection", {})
if conn.host is not None:
if "bootstrap" in obj["connection"]:
raise AirflowException(
"Kafka broker specified twice (present in host and extra)"
)
obj["connection"]["bootstrap"] = ":".join(
map(str, filter(None, [conn.host, conn.port]))
)
config = datahub.ingestion.sink.datahub_kafka.KafkaSinkConfig.parse_obj(obj)
return config
def make_emitter(self) -> "DatahubKafkaEmitter":
import datahub.emitter.kafka_emitter
sink_config = self._get_config()
return datahub.emitter.kafka_emitter.DatahubKafkaEmitter(sink_config)
def emit_mces(self, mces: List[MetadataChangeEvent]) -> None:
emitter = self.make_emitter()
errors = []
def callback(exc, msg):
if exc:
errors.append(exc)
for mce in mces:
emitter.emit_mce_async(mce, callback)
emitter.flush()
if errors:
raise AirflowException(f"failed to push some MCEs: {errors}")
def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None:
emitter = self.make_emitter()
errors = []
def callback(exc, msg):
if exc:
errors.append(exc)
for mcp in mcps:
emitter.emit_mcp_async(mcp, callback)
emitter.flush()
if errors:
raise AirflowException(f"failed to push some MCPs: {errors}")
class DatahubGenericHook(BaseHook):
"""
Emits Metadata Change Events using either the DatahubRestHook or the
DatahubKafkaHook. Set up a DataHub Rest or Kafka connection to use.
:param datahub_conn_id: Reference to the DataHub connection.
:type datahub_conn_id: str
"""
def __init__(self, datahub_conn_id: str) -> None:
super().__init__()
self.datahub_conn_id = datahub_conn_id
def get_underlying_hook(self) -> Union[DatahubRestHook, DatahubKafkaHook]:
conn = self.get_connection(self.datahub_conn_id)
# We need to figure out the underlying hook type. First check the
# conn_type. If that fails, attempt to guess using the conn id name.
if conn.conn_type == DatahubRestHook.conn_type:
return DatahubRestHook(self.datahub_conn_id)
elif conn.conn_type == DatahubKafkaHook.conn_type:
return DatahubKafkaHook(self.datahub_conn_id)
elif "rest" in self.datahub_conn_id:
return DatahubRestHook(self.datahub_conn_id)
elif "kafka" in self.datahub_conn_id:
return DatahubKafkaHook(self.datahub_conn_id)
else:
raise AirflowException(
f"DataHub cannot handle conn_type {conn.conn_type} in {conn}"
)
def make_emitter(self) -> Union["DatahubRestEmitter", "DatahubKafkaEmitter"]:
return self.get_underlying_hook().make_emitter()
def emit_mces(self, mces: List[MetadataChangeEvent]) -> None:
return self.get_underlying_hook().emit_mces(mces)


@ -0,0 +1,91 @@
import json
from typing import TYPE_CHECKING, Dict, List, Optional
from airflow.configuration import conf
from airflow.lineage.backend import LineageBackend
from datahub_airflow_plugin._lineage_core import (
DatahubBasicLineageConfig,
send_lineage_to_datahub,
)
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
class DatahubLineageConfig(DatahubBasicLineageConfig):
# If set to true, most runtime errors in the lineage backend will be
# suppressed and will not cause the overall task to fail. Note that
# configuration issues will still throw exceptions.
graceful_exceptions: bool = True
def get_lineage_config() -> DatahubLineageConfig:
"""Load the lineage config from airflow.cfg."""
# The kwargs pattern is also used for secret backends.
kwargs_str = conf.get("lineage", "datahub_kwargs", fallback="{}")
kwargs = json.loads(kwargs_str)
# Continue to support top-level datahub_conn_id config.
datahub_conn_id = conf.get("lineage", "datahub_conn_id", fallback=None)
if datahub_conn_id:
kwargs["datahub_conn_id"] = datahub_conn_id
return DatahubLineageConfig.parse_obj(kwargs)
class DatahubLineageBackend(LineageBackend):
"""
Sends lineage data from tasks to DataHub.
Configurable via ``airflow.cfg`` as follows: ::
# For REST-based:
airflow connections add --conn-type 'datahub_rest' 'datahub_rest_default' --conn-host 'http://localhost:8080'
# For Kafka-based (standard Kafka sink config can be passed via extras):
airflow connections add --conn-type 'datahub_kafka' 'datahub_kafka_default' --conn-host 'broker:9092' --conn-extra '{}'
[lineage]
backend = datahub_provider.lineage.datahub.DatahubLineageBackend
datahub_kwargs = {
"datahub_conn_id": "datahub_rest_default",
"capture_ownership_info": true,
"capture_tags_info": true,
"graceful_exceptions": true }
# The above indentation is important!
"""
def __init__(self) -> None:
super().__init__()
# By attempting to get and parse the config, we can detect configuration errors
# ahead of time. The init method is only called in Airflow 2.x.
_ = get_lineage_config()
# With Airflow 2.0, this can be an instance method. However, with Airflow 1.10.x, this
# method is used statically, even though LineageBackend declares it as an instance method.
@staticmethod
def send_lineage(
operator: "BaseOperator",
inlets: Optional[List] = None, # unused
outlets: Optional[List] = None, # unused
context: Optional[Dict] = None,
) -> None:
config = get_lineage_config()
if not config.enabled:
return
try:
context = context or {} # ensure not None to satisfy mypy
send_lineage_to_datahub(
config, operator, operator.inlets, operator.outlets, context
)
except Exception as e:
if config.graceful_exceptions:
operator.log.error(e)
operator.log.info(
"Suppressing error because graceful_exceptions is set"
)
else:
raise
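# ---------------------------------------------------------------------------
# A minimal enablement sketch, not part of the original module. It assumes a
# DataHub REST connection named "datahub_rest_default" exists, and shows the
# same [lineage] settings from the docstring above expressed through Airflow's
# AIRFLOW__SECTION__KEY environment-variable convention.
# ---------------------------------------------------------------------------
def _example_enable_lineage_backend() -> None:
    import json
    import os

    os.environ["AIRFLOW__LINEAGE__BACKEND"] = (
        "datahub_provider.lineage.datahub.DatahubLineageBackend"
    )
    os.environ["AIRFLOW__LINEAGE__DATAHUB_KWARGS"] = json.dumps(
        {"datahub_conn_id": "datahub_rest_default", "graceful_exceptions": True}
    )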

View File

@ -0,0 +1,63 @@
from typing import List, Union
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub_airflow_plugin.hooks.datahub import (
DatahubGenericHook,
DatahubKafkaHook,
DatahubRestHook,
)
class DatahubBaseOperator(BaseOperator):
"""
The DatahubBaseOperator is used as a base operator for all DataHub operators.
"""
ui_color = "#4398c8"
hook: Union[DatahubRestHook, DatahubKafkaHook]
# mypy is not a fan of this. Newer versions of Airflow support proper typing for the decorator
# using PEP 612. However, there is not yet a good way to inherit the types of the kwargs from
# the superclass.
@apply_defaults # type: ignore[misc]
def __init__( # type: ignore[no-untyped-def]
self,
*,
datahub_conn_id: str,
**kwargs,
):
super().__init__(**kwargs)
self.datahub_conn_id = datahub_conn_id
self.generic_hook = DatahubGenericHook(datahub_conn_id)
class DatahubEmitterOperator(DatahubBaseOperator):
"""
Emits a Metadata Change Event to DataHub using either a DataHub
Rest or Kafka connection.
:param datahub_conn_id: Reference to the DataHub Rest or Kafka Connection.
:type datahub_conn_id: str
"""
# See above for why these mypy type issues are ignored here.
@apply_defaults # type: ignore[misc]
def __init__( # type: ignore[no-untyped-def]
self,
mces: List[MetadataChangeEvent],
datahub_conn_id: str,
**kwargs,
):
super().__init__(
datahub_conn_id=datahub_conn_id,
**kwargs,
)
self.mces = mces
def execute(self, context):
self.generic_hook.get_underlying_hook().emit_mces(self.mces)
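# ---------------------------------------------------------------------------
# A minimal DAG sketch, not part of the original module; in practice it would
# live in its own file under the Airflow dags folder. The connection id and
# dataset names are assumptions; builder.make_lineage_mce comes from the core
# datahub package.
# ---------------------------------------------------------------------------
from datetime import datetime

import datahub.emitter.mce_builder as builder
from airflow import DAG

with DAG(
    "datahub_emitter_example",
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
) as dag:
    emit_lineage = DatahubEmitterOperator(
        task_id="emit_lineage",
        datahub_conn_id="datahub_rest_default",
        mces=[
            # make_lineage_mce(upstream_urns, downstream_urn) builds one lineage MCE.
            builder.make_lineage_mce(
                [builder.make_dataset_urn("mysql", "db.upstream")],
                builder.make_dataset_urn("mysql", "db.downstream"),
            )
        ],
    )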

View File

@ -0,0 +1,78 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.models import BaseOperator
from datahub.api.circuit_breaker import (
AssertionCircuitBreaker,
AssertionCircuitBreakerConfig,
)
from datahub_airflow_plugin.hooks.datahub import DatahubRestHook
class DataHubAssertionOperator(BaseOperator):
r"""
DataHub Assertion Circuit Breaker Operator.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The Airflow connection id of the DataHub REST connection
used to communicate with DataHub.
:param check_last_assertion_time: If set, only assertions run after the dataset's last
operation are checked. Defaults to True.
:param time_delta: If check_last_assertion_time is False, assertions within this time delta
are checked. Defaults to one day.
"""
template_fields: Sequence[str] = ("urn",)
circuit_breaker: AssertionCircuitBreaker
urn: Union[List[str], str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
check_last_assertion_time: bool = True,
time_delta: Optional[datetime.timedelta] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
verify_after_last_update=check_last_assertion_time,
time_delta=time_delta if time_delta else datetime.timedelta(days=1),
)
self.circuit_breaker = AssertionCircuitBreaker(config=config)
def execute(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn)
if ret:
raise Exception(f"Dataset {self.urn} is not in consumable state")
return True
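# ---------------------------------------------------------------------------
# A minimal DAG sketch, not part of the original module; it would normally live
# in its own file under the Airflow dags folder. The dataset urn and connection
# id are assumptions. The gate task fails, blocking the downstream task, when
# the assertion circuit breaker is active for the dataset.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    "assertion_gate_example",
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
) as dag:
    gate = DataHubAssertionOperator(
        task_id="datahub_assertion_gate",
        urn="urn:li:dataset:(urn:li:dataPlatform:postgres,db.schema.table,PROD)",
        datahub_rest_conn_id="datahub_rest_default",
    )
    consume = BashOperator(task_id="consume", bash_command="echo 'dataset is consumable'")
    gate >> consume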

View File

@ -0,0 +1,78 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.sensors.base import BaseSensorOperator
from datahub.api.circuit_breaker import (
AssertionCircuitBreaker,
AssertionCircuitBreakerConfig,
)
from datahub_airflow_plugin.hooks.datahub import DatahubRestHook
class DataHubAssertionSensor(BaseSensorOperator):
r"""
DataHub Assertion Circuit Breaker Sensor.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The Airflow connection id of the DataHub REST connection
used to communicate with DataHub.
:param check_last_assertion_time: If set, only assertions run after the dataset's last
operation are checked. Defaults to True.
:param time_delta: If check_last_assertion_time is False, assertions within this time delta
are checked. Defaults to one day.
"""
template_fields: Sequence[str] = ("urn",)
circuit_breaker: AssertionCircuitBreaker
urn: Union[List[str], str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
check_last_assertion_time: bool = True,
time_delta: datetime.timedelta = datetime.timedelta(days=1),
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
verify_after_last_update=check_last_assertion_time,
time_delta=time_delta,
)
self.circuit_breaker = AssertionCircuitBreaker(config=config)
def poke(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn)
if ret:
self.log.info(f"Dataset {self.urn} is not in consumable state")
return False
return True

View File

@ -0,0 +1,97 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.sensors.base import BaseSensorOperator
from datahub.api.circuit_breaker import (
OperationCircuitBreaker,
OperationCircuitBreakerConfig,
)
from datahub_airflow_plugin.hooks.datahub import DatahubRestHook
class DataHubOperationCircuitBreakerOperator(BaseSensorOperator):
r"""
DataHub Operation Circuit Breaker Operator.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The Airflow connection id of the DataHub REST connection
used to communicate with DataHub.
:param partition: The partition to check the operation on.
:param source_type: The source type to filter on. If not set it will accept any source type.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationsourcetype
:param operation_type: The operation type to filter on. If not set it will accept any operation type.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationtype
"""
template_fields: Sequence[str] = (
"urn",
"partition",
"source_type",
"operation_type",
)
circuit_breaker: OperationCircuitBreaker
urn: Union[List[str], str]
partition: Optional[str]
source_type: Optional[str]
operation_type: Optional[str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1),
partition: Optional[str] = None,
source_type: Optional[str] = None,
operation_type: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
self.partition = partition
self.operation_type = operation_type
self.source_type = source_type
config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
time_delta=time_delta,
)
self.circuit_breaker = OperationCircuitBreaker(config=config)
def execute(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(
urn=urn,
partition=self.partition,
operation_type=self.operation_type,
source_type=self.source_type,
)
if ret:
raise Exception(f"Dataset {self.urn} is not in consumable state")
return True

View File

@ -0,0 +1,100 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.sensors.base import BaseSensorOperator
from datahub.api.circuit_breaker import (
OperationCircuitBreaker,
OperationCircuitBreakerConfig,
)
from datahub_airflow_plugin.hooks.datahub import DatahubRestHook
class DataHubOperationCircuitBreakerSensor(BaseSensorOperator):
r"""
DataHub Operation Circuit Breaker Sensor.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The Airflow connection id of the DataHub REST connection
used to communicate with DataHub.
:param partition: The partition to check the operation on.
:param source_type: The source type to filter on. If not set it will accept any source type.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationsourcetype
:param operation_type: The operation type to filter on. If not set it will accept any operation type.
See valid values at: https://datahubproject.io/docs/graphql/enums/#operationtype
"""
template_fields: Sequence[str] = (
"urn",
"partition",
"source_type",
"operation_type",
)
circuit_breaker: OperationCircuitBreaker
urn: Union[List[str], str]
partition: Optional[str]
source_type: Optional[str]
operation_type: Optional[str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1),
partition: Optional[str] = None,
source_type: Optional[str] = None,
operation_type: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
self.partition = partition
self.operation_type = operation_type
self.source_type = source_type
config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
time_delta=time_delta,
)
self.circuit_breaker = OperationCircuitBreaker(config=config)
def poke(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(
urn=urn,
partition=self.partition,
operation_type=self.operation_type,
source_type=self.source_type,
)
if ret:
self.log.info(f"Dataset {self.urn} is not in consumable state")
return False
return True
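# ---------------------------------------------------------------------------
# A minimal DAG sketch, not part of the original module; it would normally live
# in its own file under the Airflow dags folder. The urn, partition, source and
# operation types, and connection id are assumptions. The sensor keeps poking
# until a matching operation is reported for the dataset.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    "operation_gate_example",
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
) as dag:
    # Pokes every minute until the circuit breaker reports a matching operation.
    wait_for_operation = DataHubOperationCircuitBreakerSensor(
        task_id="wait_for_operation",
        urn="urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.table,PROD)",
        datahub_rest_conn_id="datahub_rest_default",
        partition="2023-08-01",
        source_type="DATA_PROCESS",
        operation_type="INSERT",
        poke_interval=60,
    )
    consume = BashOperator(task_id="consume", bash_command="echo 'fresh data available'")
    wait_for_operation >> consume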

View File

@ -9,12 +9,11 @@ from unittest.mock import Mock
import airflow.configuration
import airflow.version
import datahub.emitter.mce_builder as builder
import packaging.version
import pytest
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models import DAG, Connection, DagBag, DagRun, TaskInstance
import datahub.emitter.mce_builder as builder
from datahub_provider import get_provider_info
from datahub_provider._airflow_shims import AIRFLOW_PATCHED, EmptyOperator
from datahub_provider.entities import Dataset, Urn
@ -23,7 +22,7 @@ from datahub_provider.operators.datahub import DatahubEmitterOperator
assert AIRFLOW_PATCHED
pytestmark = pytest.mark.airflow
# TODO: Remove the default_view="tree" arg. Figure out why default_view is being picked as "grid" and how to fix it.
# Approach suggested by https://stackoverflow.com/a/11887885/5004662.
AIRFLOW_VERSION = packaging.version.parse(airflow.version.version)
@ -75,7 +74,7 @@ def test_airflow_provider_info():
@pytest.mark.filterwarnings("ignore:.*is deprecated.*")
def test_dags_load_with_no_errors(pytestconfig: pytest.Config) -> None:
airflow_examples_folder = (
pytestconfig.rootpath / "src/datahub_provider/example_dags"
pytestconfig.rootpath / "src/datahub_airflow_plugin/example_dags"
)
# Note: the .airflowignore file skips the snowflake DAG.
@ -233,7 +232,11 @@ def test_lineage_backend(mock_emit, inlets, outlets, capture_executions):
func = mock.Mock()
func.__name__ = "foo"
dag = DAG(dag_id="test_lineage_is_sent_to_backend", start_date=DEFAULT_DATE)
dag = DAG(
dag_id="test_lineage_is_sent_to_backend",
start_date=DEFAULT_DATE,
default_view="tree",
)
with dag:
op1 = EmptyOperator(
@ -252,6 +255,7 @@ def test_lineage_backend(mock_emit, inlets, outlets, capture_executions):
# versions do not require it, but will attempt to find the associated
# run_id in the database if execution_date is provided. As such, we
# must fake the run_id parameter for newer Airflow versions.
# We need to add type: ignore in the else branch to suppress a mypy error in Airflow < 2.2
if AIRFLOW_VERSION < packaging.version.parse("2.2.0"):
ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE)
# Ignoring type here because DagRun state is just a string at Airflow 1
@ -259,7 +263,7 @@ def test_lineage_backend(mock_emit, inlets, outlets, capture_executions):
else:
from airflow.utils.state import DagRunState
ti = TaskInstance(task=op2, run_id=f"test_airflow-{DEFAULT_DATE}")
ti = TaskInstance(task=op2, run_id=f"test_airflow-{DEFAULT_DATE}") # type: ignore[call-arg]
dag_run = DagRun(
state=DagRunState.SUCCESS,
run_id=f"scheduled_{DEFAULT_DATE.isoformat()}",

View File

@ -26,6 +26,16 @@ source venv/bin/activate
datahub version # should print "DataHub CLI version: unavailable (installed in develop mode)"
```
### (Optional) Set up your Python environment for developing on Airflow Plugin
From the repository root:
```shell
cd metadata-ingestion-modules/airflow-plugin
../../gradlew :metadata-ingestion-modules:airflow-plugin:installDev
source venv/bin/activate
datahub version # should print "DataHub CLI version: unavailable (installed in develop mode)"
```
### Common setup issues
Common issues (click to expand):
@ -183,7 +193,7 @@ pytest -m 'slow_integration'
../gradlew :metadata-ingestion:testFull
../gradlew :metadata-ingestion:check
# Run all tests in a single file
../gradlew :metadata-ingestion:testSingle -PtestFile=tests/unit/test_airflow.py
../gradlew :metadata-ingestion:testSingle -PtestFile=tests/unit/test_bigquery_source.py
# Run all tests under tests/unit
../gradlew :metadata-ingestion:testSingle -PtestFile=tests/unit
```

View File

@ -4,9 +4,9 @@ If you are using Apache Airflow for your scheduling then you might want to also
We've provided a few examples of how to configure your DAG:
- [`mysql_sample_dag`](../src/datahub_provider/example_dags/mysql_sample_dag.py) embeds the full MySQL ingestion configuration inside the DAG.
- [`mysql_sample_dag`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/mysql_sample_dag.py) embeds the full MySQL ingestion configuration inside the DAG.
- [`snowflake_sample_dag`](../src/datahub_provider/example_dags/snowflake_sample_dag.py) avoids embedding credentials inside the recipe, and instead fetches them from Airflow's [Connections](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection/index.html) feature. You must configure your connections in Airflow to use this approach.
- [`snowflake_sample_dag`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/snowflake_sample_dag.py) avoids embedding credentials inside the recipe, and instead fetches them from Airflow's [Connections](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection/index.html) feature. You must configure your connections in Airflow to use this approach.
:::tip
@ -37,6 +37,6 @@ In more advanced cases, you might want to store your ingestion recipe in a file
- Create a DAG task to read your DataHub ingestion recipe file and run it. See the example below for reference.
- Deploy the DAG file into Airflow for scheduling. Typically this involves checking the DAG file into your dags folder, which is accessible to your Airflow instance.
Example: [`generic_recipe_sample_dag`](../src/datahub_provider/example_dags/generic_recipe_sample_dag.py)
Example: [`generic_recipe_sample_dag`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/generic_recipe_sample_dag.py)
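A minimal sketch of such a task, assuming a recipe stored at `/opt/airflow/recipes/mysql_to_datahub.yml` (the path, DAG id, and schedule below are placeholders; see the linked `generic_recipe_sample_dag` for the maintained example):
```python
from datetime import datetime

import yaml
from airflow import DAG
from airflow.operators.python import PythonOperator
from datahub.ingestion.run.pipeline import Pipeline


def _run_recipe():
    # Load the recipe from disk and hand it to the DataHub ingestion pipeline.
    with open("/opt/airflow/recipes/mysql_to_datahub.yml") as f:
        config = yaml.safe_load(f)
    pipeline = Pipeline.create(config)
    pipeline.run()
    pipeline.raise_from_status()


with DAG(
    "datahub_recipe_example",
    start_date=datetime(2023, 1, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    ingest = PythonOperator(task_id="run_datahub_recipe", python_callable=_run_recipe)
```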
</details>

View File

@ -75,7 +75,6 @@ disallow_untyped_defs = yes
asyncio_mode = auto
addopts = --cov=src --cov-report= --cov-config setup.cfg --strict-markers
markers =
airflow: marks tests related to airflow (deselect with '-m not airflow')
slow_unit: marks tests to only run slow unit tests (deselect with '-m not slow_unit')
integration: marks tests to only run in integration (deselect with '-m "not integration"')
integration_batch_1: mark tests to only run in batch 1 of integration tests. This is done mainly for parallelisation (deselect with '-m not integration_batch_1')
@ -112,5 +111,3 @@ exclude_lines =
omit =
# omit codegen
src/datahub/metadata/*
# omit example dags
src/datahub_provider/example_dags/*

View File

@ -283,8 +283,7 @@ plugins: Dict[str, Set[str]] = {
},
# Integrations.
"airflow": {
"apache-airflow >= 2.0.2",
*rest_common,
f"acryl-datahub-airflow-plugin == {package_metadata['__version__']}",
},
"circuit-breaker": {
"gql>=3.3.0",
@ -508,8 +507,8 @@ base_dev_requirements = {
"salesforce",
"unity-catalog",
"nifi",
"vertica"
# airflow is added below
"vertica",
"mode",
]
if plugin
for dependency in plugins[plugin]
@ -518,9 +517,6 @@ base_dev_requirements = {
dev_requirements = {
*base_dev_requirements,
# Extra requirements for Airflow.
"apache-airflow[snowflake]>=2.0.2", # snowflake is used in example dags
"virtualenv", # needed by PythonVirtualenvOperator
}
full_test_dev_requirements = {

View File

@ -1,28 +1 @@
import datahub
# This is needed to allow Airflow to pick up specific metadata fields it needs for
# certain features. We recognize it's a bit unclean to define these in multiple places,
# but at this point it's the only workaround if you'd like your custom conn type to
# show up in the Airflow UI.
def get_provider_info():
return {
"name": "DataHub",
"description": "`DataHub <https://datahubproject.io/>`__\n",
"connection-types": [
{
"hook-class-name": "datahub_provider.hooks.datahub.DatahubRestHook",
"connection-type": "datahub_rest",
},
{
"hook-class-name": "datahub_provider.hooks.datahub.DatahubKafkaHook",
"connection-type": "datahub_kafka",
},
],
"hook-class-names": [
"datahub_provider.hooks.datahub.DatahubRestHook",
"datahub_provider.hooks.datahub.DatahubKafkaHook",
],
"package-name": datahub.__package_name__,
"versions": [datahub.__version__],
}
from datahub_airflow_plugin import get_provider_info

View File

@ -1,12 +1,3 @@
# This module must be imported before any Airflow imports in any of our files.
# The AIRFLOW_PATCHED just helps avoid flake8 errors.
from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
assert MARKUPSAFE_PATCHED
AIRFLOW_PATCHED = True
__all__ = [
"AIRFLOW_PATCHED",
]
__all__ = ["AIRFLOW_PATCHED"]

View File

@ -1,29 +1,15 @@
from datahub_provider._airflow_compat import AIRFLOW_PATCHED
from airflow.models.baseoperator import BaseOperator
try:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.operators.empty import EmptyOperator
except ModuleNotFoundError:
# Operator isn't a real class, but rather a type alias defined
# as the union of BaseOperator and MappedOperator.
# Since older versions of Airflow don't have MappedOperator, we can just use BaseOperator.
Operator = BaseOperator # type: ignore
MappedOperator = None # type: ignore
from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore
try:
from airflow.sensors.external_task import ExternalTaskSensor
except ImportError:
from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore
assert AIRFLOW_PATCHED
from datahub_airflow_plugin._airflow_shims import (
AIRFLOW_PATCHED,
EmptyOperator,
ExternalTaskSensor,
MappedOperator,
Operator,
)
__all__ = [
"Operator",
"MappedOperator",
"AIRFLOW_PATCHED",
"EmptyOperator",
"ExternalTaskSensor",
"Operator",
"MappedOperator",
]

View File

@ -1,114 +1,3 @@
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List
from datahub_airflow_plugin._lineage_core import DatahubBasicLineageConfig
import datahub.emitter.mce_builder as builder
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub.configuration.common import ConfigModel
from datahub.utilities.urns.dataset_urn import DatasetUrn
from datahub_provider.client.airflow_generator import AirflowGenerator
from datahub_provider.entities import _Entity
if TYPE_CHECKING:
from airflow import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from datahub_provider._airflow_shims import Operator
from datahub_provider.hooks.datahub import DatahubGenericHook
def _entities_to_urn_list(iolets: List[_Entity]) -> List[DatasetUrn]:
return [DatasetUrn.create_from_string(let.urn) for let in iolets]
class DatahubBasicLineageConfig(ConfigModel):
enabled: bool = True
# DataHub hook connection ID.
datahub_conn_id: str
# Cluster to associate with the pipelines and tasks. Defaults to "prod".
cluster: str = builder.DEFAULT_FLOW_CLUSTER
# If true, the owners field of the DAG will be captured as a DataHub corpuser.
capture_ownership_info: bool = True
# If true, the tags field of the DAG will be captured as DataHub tags.
capture_tags_info: bool = True
capture_executions: bool = False
def make_emitter_hook(self) -> "DatahubGenericHook":
# This is necessary to avoid issues with circular imports.
from datahub_provider.hooks.datahub import DatahubGenericHook
return DatahubGenericHook(self.datahub_conn_id)
def send_lineage_to_datahub(
config: DatahubBasicLineageConfig,
operator: "Operator",
inlets: List[_Entity],
outlets: List[_Entity],
context: Dict,
) -> None:
if not config.enabled:
return
dag: "DAG" = context["dag"]
task: "Operator" = context["task"]
ti: "TaskInstance" = context["task_instance"]
hook = config.make_emitter_hook()
emitter = hook.make_emitter()
dataflow = AirflowGenerator.generate_dataflow(
cluster=config.cluster,
dag=dag,
capture_tags=config.capture_tags_info,
capture_owner=config.capture_ownership_info,
)
dataflow.emit(emitter)
operator.log.info(f"Emitted from Lineage: {dataflow}")
datajob = AirflowGenerator.generate_datajob(
cluster=config.cluster,
task=task,
dag=dag,
capture_tags=config.capture_tags_info,
capture_owner=config.capture_ownership_info,
)
datajob.inlets.extend(_entities_to_urn_list(inlets))
datajob.outlets.extend(_entities_to_urn_list(outlets))
datajob.emit(emitter)
operator.log.info(f"Emitted from Lineage: {datajob}")
if config.capture_executions:
dag_run: "DagRun" = context["dag_run"]
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=config.cluster,
ti=ti,
dag=dag,
dag_run=dag_run,
datajob=datajob,
emit_templates=False,
)
operator.log.info(f"Emitted from Lineage: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=config.cluster,
ti=ti,
dag=dag,
dag_run=dag_run,
datajob=datajob,
result=InstanceRunResult.SUCCESS,
end_timestamp_millis=int(datetime.utcnow().timestamp() * 1000),
)
operator.log.info(f"Emitted from Lineage: {dpi}")
emitter.flush()
__all__ = ["DatahubBasicLineageConfig"]

View File

@ -1,368 +1,3 @@
from datahub_provider._airflow_compat import AIRFLOW_PATCHED
from datahub_airflow_plugin.datahub_plugin import DatahubPlugin
import contextlib
import logging
import traceback
from typing import Any, Callable, Iterable, List, Optional, Union
from airflow.configuration import conf
from airflow.lineage import PIPELINE_OUTLETS
from airflow.models.baseoperator import BaseOperator
from airflow.plugins_manager import AirflowPlugin
from airflow.utils.module_loading import import_string
from cattr import structure
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub_provider._airflow_shims import MappedOperator, Operator
from datahub_provider.client.airflow_generator import AirflowGenerator
from datahub_provider.hooks.datahub import DatahubGenericHook
from datahub_provider.lineage.datahub import DatahubLineageConfig
assert AIRFLOW_PATCHED
logger = logging.getLogger(__name__)
TASK_ON_FAILURE_CALLBACK = "on_failure_callback"
TASK_ON_SUCCESS_CALLBACK = "on_success_callback"
def get_lineage_config() -> DatahubLineageConfig:
"""Load the lineage config from airflow.cfg."""
enabled = conf.get("datahub", "enabled", fallback=True)
datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default")
cluster = conf.get("datahub", "cluster", fallback="prod")
graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True)
capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True)
capture_ownership_info = conf.get(
"datahub", "capture_ownership_info", fallback=True
)
capture_executions = conf.get("datahub", "capture_executions", fallback=True)
return DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=datahub_conn_id,
cluster=cluster,
graceful_exceptions=graceful_exceptions,
capture_ownership_info=capture_ownership_info,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
)
def _task_inlets(operator: "Operator") -> List:
# From Airflow 2.4, _inlets is dropped and inlets is used consistently. Earlier that was not the case, so we have to stick with _inlets.
if hasattr(operator, "_inlets"):
return operator._inlets # type: ignore[attr-defined, union-attr]
return operator.inlets
def _task_outlets(operator: "Operator") -> List:
# From Airflow 2.4, _outlets is dropped and outlets is used consistently. Earlier that was not the case, so we have to stick with _outlets.
# We have to use _outlets because outlets is empty in Airflow < 2.4.0
if hasattr(operator, "_outlets"):
return operator._outlets # type: ignore[attr-defined, union-attr]
return operator.outlets
def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]:
# TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae
# in Airflow 2.4.
# TODO: ignore/handle airflow's dataset type in our lineage
inlets: List[Any] = []
task_inlets = _task_inlets(task)
# From Airflow 2.3 this should be AbstractOperator, but for compatibility reasons let's use BaseOperator
if isinstance(task_inlets, (str, BaseOperator)):
inlets = [
task_inlets,
]
if task_inlets and isinstance(task_inlets, list):
inlets = []
task_ids = (
{o for o in task_inlets if isinstance(o, str)}
.union(op.task_id for op in task_inlets if isinstance(op, BaseOperator))
.intersection(task.get_flat_relative_ids(upstream=True))
)
from airflow.lineage import AUTO
# pick up unique direct upstream task_ids if AUTO is specified
if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets:
print("Picking up unique direct upstream task_ids as AUTO is specified")
task_ids = task_ids.union(
task_ids.symmetric_difference(task.upstream_task_ids)
)
inlets = task.xcom_pull(
context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS
)
# re-instantiate the obtained inlets
inlets = [
structure(item["data"], import_string(item["type_name"]))
# _get_instance(structure(item, Metadata))
for sublist in inlets
if sublist
for item in sublist
]
for inlet in task_inlets:
if not isinstance(inlet, str):
inlets.append(inlet)
return inlets
def _make_emit_callback(
logger: logging.Logger,
) -> Callable[[Optional[Exception], str], None]:
def emit_callback(err: Optional[Exception], msg: str) -> None:
if err:
logger.error(f"Error sending metadata to datahub: {msg}", exc_info=err)
return emit_callback
def datahub_task_status_callback(context, status):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
dataflow = AirflowGenerator.generate_dataflow(
cluster=context["_datahub_config"].cluster,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
task.log.info(f"Emitting Datahub Dataflow: {dataflow}")
dataflow.emit(emitter, callback=_make_emit_callback(task.log))
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
task_outlets = _task_outlets(task)
for outlet in task_outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitting Datahub Datajob: {datajob}")
datajob.emit(emitter, callback=_make_emit_callback(task.log))
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag_run=context["dag_run"],
result=status,
dag=dag,
datajob=datajob,
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
)
task.log.info(f"Emitted Completed Data Process Instance: {dpi}")
emitter.flush()
def datahub_pre_execution(context):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
task.log.info("Running Datahub pre_execute method")
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=context["ti"].task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
task_outlets = _task_outlets(task)
for outlet in task_outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitting Datahub dataJob {datajob}")
datajob.emit(emitter, callback=_make_emit_callback(task.log))
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}")
emitter.flush()
def _wrap_pre_execution(pre_execution):
def custom_pre_execution(context):
config = get_lineage_config()
if config.enabled:
context["_datahub_config"] = config
datahub_pre_execution(context)
# Call original policy
if pre_execution:
pre_execution(context)
return custom_pre_execution
def _wrap_on_failure_callback(on_failure_callback):
def custom_on_failure_callback(context):
config = get_lineage_config()
if config.enabled:
context["_datahub_config"] = config
try:
datahub_task_status_callback(context, status=InstanceRunResult.FAILURE)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
# Call original policy
if on_failure_callback:
on_failure_callback(context)
return custom_on_failure_callback
def _wrap_on_success_callback(on_success_callback):
def custom_on_success_callback(context):
config = get_lineage_config()
if config.enabled:
context["_datahub_config"] = config
try:
datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
# Call original policy
if on_success_callback:
on_success_callback(context)
return custom_on_success_callback
def task_policy(task: Union[BaseOperator, MappedOperator]) -> None:
task.log.debug(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
# task.add_inlets(["auto"])
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
# MappedOperator's callbacks don't have setters until Airflow 2.X.X
# https://github.com/apache/airflow/issues/24547
# We can bypass this by going through partial_kwargs for now
if MappedOperator and isinstance(task, MappedOperator): # type: ignore
on_failure_callback_prop: property = getattr(
MappedOperator, TASK_ON_FAILURE_CALLBACK
)
on_success_callback_prop: property = getattr(
MappedOperator, TASK_ON_SUCCESS_CALLBACK
)
if not on_failure_callback_prop.fset or not on_success_callback_prop.fset:
task.log.debug(
"Using MappedOperator's partial_kwargs instead of callback properties"
)
task.partial_kwargs[TASK_ON_FAILURE_CALLBACK] = _wrap_on_failure_callback(
task.on_failure_callback
)
task.partial_kwargs[TASK_ON_SUCCESS_CALLBACK] = _wrap_on_success_callback(
task.on_success_callback
)
return
task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback) # type: ignore
task.on_success_callback = _wrap_on_success_callback(task.on_success_callback) # type: ignore
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
def _wrap_task_policy(policy):
if policy and hasattr(policy, "_task_policy_patched_by"):
return policy
def custom_task_policy(task):
policy(task)
task_policy(task)
# Add a flag to the policy to indicate that we've patched it.
custom_task_policy._task_policy_patched_by = "datahub_plugin" # type: ignore[attr-defined]
return custom_task_policy
def _patch_policy(settings):
if hasattr(settings, "task_policy"):
datahub_task_policy = _wrap_task_policy(settings.task_policy)
settings.task_policy = datahub_task_policy
def _patch_datahub_policy():
with contextlib.suppress(ImportError):
import airflow_local_settings
_patch_policy(airflow_local_settings)
from airflow.models.dagbag import settings
_patch_policy(settings)
_patch_datahub_policy()
class DatahubPlugin(AirflowPlugin):
name = "datahub_plugin"
__all__ = ["DatahubPlugin"]

View File

@ -1,509 +1,3 @@
from datahub_provider._airflow_compat import AIRFLOW_PATCHED
from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
from airflow.configuration import conf
from datahub.api.entities.datajob import DataFlow, DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import (
DataProcessInstance,
InstanceRunResult,
)
from datahub.metadata.schema_classes import DataProcessTypeClass
from datahub.utilities.urns.data_flow_urn import DataFlowUrn
from datahub.utilities.urns.data_job_urn import DataJobUrn
assert AIRFLOW_PATCHED
if TYPE_CHECKING:
from airflow import DAG
from airflow.models import DagRun, TaskInstance
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub_provider._airflow_shims import Operator
def _task_downstream_task_ids(operator: "Operator") -> Set[str]:
if hasattr(operator, "downstream_task_ids"):
return operator.downstream_task_ids
return operator._downstream_task_id # type: ignore[attr-defined,union-attr]
class AirflowGenerator:
@staticmethod
def _get_dependencies(
task: "Operator", dag: "DAG", flow_urn: DataFlowUrn
) -> List[DataJobUrn]:
from datahub_provider._airflow_shims import ExternalTaskSensor
# resolve URNs for upstream nodes in subdags upstream of the current task.
upstream_subdag_task_urns: List[DataJobUrn] = []
for upstream_task_id in task.upstream_task_ids:
upstream_task = dag.task_dict[upstream_task_id]
# if upstream task is not a subdag, then skip it
upstream_subdag = getattr(upstream_task, "subdag", None)
if upstream_subdag is None:
continue
# else, link the leaf tasks of the upstream subdag as upstream tasks
for upstream_subdag_task_id in upstream_subdag.task_dict:
upstream_subdag_task = upstream_subdag.task_dict[
upstream_subdag_task_id
]
upstream_subdag_task_urn = DataJobUrn.create_from_ids(
job_id=upstream_subdag_task_id, data_flow_urn=str(flow_urn)
)
# if subdag task is a leaf task, then link it as an upstream task
if len(_task_downstream_task_ids(upstream_subdag_task)) == 0:
upstream_subdag_task_urns.append(upstream_subdag_task_urn)
# resolve URNs for upstream nodes that trigger the subdag containing the current task.
# (if it is in a subdag at all)
upstream_subdag_triggers: List[DataJobUrn] = []
# subdags are always named with 'parent.child' style or Airflow won't run them
# add connection from subdag trigger(s) if subdag task has no upstreams
if (
dag.is_subdag
and dag.parent_dag is not None
and len(task.upstream_task_ids) == 0
):
# filter through the parent dag's tasks and find the subdag trigger(s)
subdags = [
x for x in dag.parent_dag.task_dict.values() if x.subdag is not None
]
matched_subdags = [
x for x in subdags if x.subdag and x.subdag.dag_id == dag.dag_id
]
# id of the task containing the subdag
subdag_task_id = matched_subdags[0].task_id
# iterate through the parent dag's tasks and find the ones that trigger the subdag
for upstream_task_id in dag.parent_dag.task_dict:
upstream_task = dag.parent_dag.task_dict[upstream_task_id]
upstream_task_urn = DataJobUrn.create_from_ids(
data_flow_urn=str(flow_urn), job_id=upstream_task_id
)
# if the task triggers the subdag, link it to this node in the subdag
if subdag_task_id in _task_downstream_task_ids(upstream_task):
upstream_subdag_triggers.append(upstream_task_urn)
# If the operator is an ExternalTaskSensor then we set the remote task as upstream.
# It is possible to tie an external sensor to a DAG if external_task_id is omitted, but currently we can't tie
# one jobflow to another jobflow.
external_task_upstreams = []
if task.task_type == "ExternalTaskSensor":
task = cast(ExternalTaskSensor, task)
if hasattr(task, "external_task_id") and task.external_task_id is not None:
external_task_upstreams = [
DataJobUrn.create_from_ids(
job_id=task.external_task_id,
data_flow_urn=str(
DataFlowUrn.create_from_ids(
orchestrator=flow_urn.get_orchestrator_name(),
flow_id=task.external_dag_id,
env=flow_urn.get_env(),
)
),
)
]
# exclude subdag operator tasks since these are not emitted, resulting in empty metadata
upstream_tasks = (
[
DataJobUrn.create_from_ids(job_id=task_id, data_flow_urn=str(flow_urn))
for task_id in task.upstream_task_ids
if getattr(dag.task_dict[task_id], "subdag", None) is None
]
+ upstream_subdag_task_urns
+ upstream_subdag_triggers
+ external_task_upstreams
)
return upstream_tasks
@staticmethod
def generate_dataflow(
cluster: str,
dag: "DAG",
capture_owner: bool = True,
capture_tags: bool = True,
) -> DataFlow:
"""
Generates a Dataflow object from an Airflow DAG
:param cluster: str - name of the cluster
:param dag: DAG - the Airflow DAG to generate the DataFlow from
:param capture_tags: bool - whether to capture the DAG's tags
:param capture_owner: bool - whether to capture the DAG's owner
:return: DataFlow - the generated DataFlow object
"""
id = dag.dag_id
orchestrator = "airflow"
description = f"{dag.description}\n\n{dag.doc_md or ''}"
data_flow = DataFlow(
env=cluster, id=id, orchestrator=orchestrator, description=description
)
flow_property_bag: Dict[str, str] = {}
allowed_flow_keys = [
"_access_control",
"_concurrency",
"_default_view",
"catchup",
"fileloc",
"is_paused_upon_creation",
"start_date",
"tags",
"timezone",
]
for key in allowed_flow_keys:
if hasattr(dag, key):
flow_property_bag[key] = repr(getattr(dag, key))
data_flow.properties = flow_property_bag
base_url = conf.get("webserver", "base_url")
data_flow.url = f"{base_url}/tree?dag_id={dag.dag_id}"
if capture_owner and dag.owner:
data_flow.owners.add(dag.owner)
if capture_tags and dag.tags:
data_flow.tags.update(dag.tags)
return data_flow
@staticmethod
def _get_description(task: "Operator") -> Optional[str]:
from airflow.models.baseoperator import BaseOperator
if not isinstance(task, BaseOperator):
# TODO: Get docs for mapped operators.
return None
if hasattr(task, "doc") and task.doc:
return task.doc
elif hasattr(task, "doc_md") and task.doc_md:
return task.doc_md
elif hasattr(task, "doc_json") and task.doc_json:
return task.doc_json
elif hasattr(task, "doc_yaml") and task.doc_yaml:
return task.doc_yaml
elif hasattr(task, "doc_rst") and task.doc_yaml:
return task.doc_yaml
return None
@staticmethod
def generate_datajob(
cluster: str,
task: "Operator",
dag: "DAG",
set_dependencies: bool = True,
capture_owner: bool = True,
capture_tags: bool = True,
) -> DataJob:
"""
:param cluster: str
:param task: Operator - the Airflow task
:param dag: DAG
:param set_dependencies: bool - whether to extract dependencies from airflow task
:param capture_owner: bool - whether to extract owner from airflow task
:param capture_tags: bool - whether to set tags automatically from airflow task
:return: DataJob - returns the generated DataJob object
"""
dataflow_urn = DataFlowUrn.create_from_ids(
orchestrator="airflow", env=cluster, flow_id=dag.dag_id
)
datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
# TODO add support for MappedOperator
datajob.description = AirflowGenerator._get_description(task)
job_property_bag: Dict[str, str] = {}
allowed_task_keys = [
"_downstream_task_ids",
"_inlets",
"_outlets",
"_task_type",
"_task_module",
"depends_on_past",
"email",
"label",
"execution_timeout",
"sla",
"sql",
"task_id",
"trigger_rule",
"wait_for_downstream",
# In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids
"downstream_task_ids",
# In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions.
"inlets",
"outlets",
]
for key in allowed_task_keys:
if hasattr(task, key):
job_property_bag[key] = repr(getattr(task, key))
datajob.properties = job_property_bag
base_url = conf.get("webserver", "base_url")
datajob.url = f"{base_url}/taskinstance/list/?flt1_dag_id_equals={datajob.flow_urn.get_flow_id()}&_flt_3_task_id={task.task_id}"
if capture_owner and dag.owner:
datajob.owners.add(dag.owner)
if capture_tags and dag.tags:
datajob.tags.update(dag.tags)
if set_dependencies:
datajob.upstream_urns.extend(
AirflowGenerator._get_dependencies(
task=task, dag=dag, flow_urn=datajob.flow_urn
)
)
return datajob
@staticmethod
def create_datajob_instance(
cluster: str,
task: "Operator",
dag: "DAG",
data_job: Optional[DataJob] = None,
) -> DataProcessInstance:
if data_job is None:
data_job = AirflowGenerator.generate_datajob(cluster, task=task, dag=dag)
dpi = DataProcessInstance.from_datajob(
datajob=data_job, id=task.task_id, clone_inlets=True, clone_outlets=True
)
return dpi
@staticmethod
def run_dataflow(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
dag_run: "DagRun",
start_timestamp_millis: Optional[int] = None,
dataflow: Optional[DataFlow] = None,
) -> None:
if dataflow is None:
assert dag_run.dag
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)
if start_timestamp_millis is None:
assert dag_run.execution_date
start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000)
assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
# This property only exists in Airflow2
if hasattr(dag_run, "run_type"):
from airflow.utils.types import DagRunType
if dag_run.run_type == DagRunType.SCHEDULED:
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
elif dag_run.run_type == DagRunType.MANUAL:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
else:
if dag_run.run_id.startswith("scheduled__"):
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
else:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
property_bag: Dict[str, str] = {}
property_bag["run_id"] = str(dag_run.run_id)
property_bag["execution_date"] = str(dag_run.execution_date)
property_bag["end_date"] = str(dag_run.end_date)
property_bag["start_date"] = str(dag_run.start_date)
property_bag["creating_job_id"] = str(dag_run.creating_job_id)
property_bag["data_interval_start"] = str(dag_run.data_interval_start)
property_bag["data_interval_end"] = str(dag_run.data_interval_end)
property_bag["external_trigger"] = str(dag_run.external_trigger)
dpi.properties.update(property_bag)
dpi.emit_process_start(
emitter=emitter, start_timestamp_millis=start_timestamp_millis
)
@staticmethod
def complete_dataflow(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
dag_run: "DagRun",
end_timestamp_millis: Optional[int] = None,
dataflow: Optional[DataFlow] = None,
) -> None:
"""
:param emitter: DatahubRestEmitter - the datahub rest emitter to emit the generated mcps
:param cluster: str - name of the cluster
:param dag_run: DagRun
:param end_timestamp_millis: Optional[int] - the completion time in milliseconds; if not set, the DagRun's end_date is used.
:param dataflow: Optional[Dataflow]
"""
if dataflow is None:
assert dag_run.dag
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)
assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
if end_timestamp_millis is None:
if dag_run.end_date is None:
raise Exception(
f"Dag {dag_run.dag_id}_{dag_run.run_id} is still running and unable to get end_date..."
)
end_timestamp_millis = int(dag_run.end_date.timestamp() * 1000)
# We should use DagRunState but it is not available in Airflow 1
if dag_run.state == "success":
result = InstanceRunResult.SUCCESS
elif dag_run.state == "failed":
result = InstanceRunResult.FAILURE
else:
raise Exception(
f"Result should be either success or failure and it was {dag_run.state}"
)
dpi.emit_process_end(
emitter=emitter,
end_timestamp_millis=end_timestamp_millis,
result=result,
result_type="airflow",
)
@staticmethod
def run_datajob(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
ti: "TaskInstance",
dag: "DAG",
dag_run: "DagRun",
start_timestamp_millis: Optional[int] = None,
datajob: Optional[DataJob] = None,
attempt: Optional[int] = None,
emit_templates: bool = True,
) -> DataProcessInstance:
if datajob is None:
datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag)
assert dag_run.run_id
dpi = DataProcessInstance.from_datajob(
datajob=datajob,
id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}",
clone_inlets=True,
clone_outlets=True,
)
job_property_bag: Dict[str, str] = {}
job_property_bag["run_id"] = str(dag_run.run_id)
job_property_bag["duration"] = str(ti.duration)
job_property_bag["start_date"] = str(ti.start_date)
job_property_bag["end_date"] = str(ti.end_date)
job_property_bag["execution_date"] = str(ti.execution_date)
job_property_bag["try_number"] = str(ti.try_number - 1)
job_property_bag["hostname"] = str(ti.hostname)
job_property_bag["max_tries"] = str(ti.max_tries)
# Not compatible with Airflow 1
if hasattr(ti, "external_executor_id"):
job_property_bag["external_executor_id"] = str(ti.external_executor_id)
job_property_bag["pid"] = str(ti.pid)
job_property_bag["state"] = str(ti.state)
job_property_bag["operator"] = str(ti.operator)
job_property_bag["priority_weight"] = str(ti.priority_weight)
job_property_bag["unixname"] = str(ti.unixname)
job_property_bag["log_url"] = ti.log_url
dpi.properties.update(job_property_bag)
dpi.url = ti.log_url
# This property only exists in Airflow2
if hasattr(ti, "dag_run") and hasattr(ti.dag_run, "run_type"):
from airflow.utils.types import DagRunType
if ti.dag_run.run_type == DagRunType.SCHEDULED:
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
elif ti.dag_run.run_type == DagRunType.MANUAL:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
else:
if dag_run.run_id.startswith("scheduled__"):
dpi.type = DataProcessTypeClass.BATCH_SCHEDULED
else:
dpi.type = DataProcessTypeClass.BATCH_AD_HOC
if start_timestamp_millis is None:
assert ti.start_date
start_timestamp_millis = int(ti.start_date.timestamp() * 1000)
if attempt is None:
attempt = ti.try_number
dpi.emit_process_start(
emitter=emitter,
start_timestamp_millis=start_timestamp_millis,
attempt=attempt,
emit_template=emit_templates,
)
return dpi
@staticmethod
def complete_datajob(
emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"],
cluster: str,
ti: "TaskInstance",
dag: "DAG",
dag_run: "DagRun",
end_timestamp_millis: Optional[int] = None,
result: Optional[InstanceRunResult] = None,
datajob: Optional[DataJob] = None,
) -> DataProcessInstance:
"""
:param emitter: DatahubRestEmitter
:param cluster: str
:param ti: TaskInstance
:param dag: DAG
:param dag_run: DagRun
:param end_timestamp_millis: Optional[int]
:param result: Optional[InstanceRunResult] - one of the values from datahub.metadata.schema_classes.RunResultTypeClass
:param datajob: Optional[DataJob]
:return: DataProcessInstance
"""
if datajob is None:
datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag)
if end_timestamp_millis is None:
assert ti.end_date
end_timestamp_millis = int(ti.end_date.timestamp() * 1000)
if result is None:
# We should use TaskInstanceState but it is not available in Airflow 1
if ti.state == "success":
result = InstanceRunResult.SUCCESS
elif ti.state == "failed":
result = InstanceRunResult.FAILURE
else:
raise Exception(
f"Result should be either success or failure and it was {ti.state}"
)
dpi = DataProcessInstance.from_datajob(
datajob=datajob,
id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}",
clone_inlets=True,
clone_outlets=True,
)
dpi.emit_process_end(
emitter=emitter,
end_timestamp_millis=end_timestamp_millis,
result=result,
result_type="airflow",
)
return dpi
__all__ = ["AirflowGenerator"]

View File

@ -1,48 +1,3 @@
from abc import abstractmethod
from typing import Optional
from datahub_airflow_plugin.entities import Dataset, Urn, _Entity
import attr
import datahub.emitter.mce_builder as builder
from datahub.utilities.urns.urn import guess_entity_type
class _Entity:
@property
@abstractmethod
def urn(self) -> str:
pass
@attr.s(auto_attribs=True, str=True)
class Dataset(_Entity):
platform: str
name: str
env: str = builder.DEFAULT_ENV
platform_instance: Optional[str] = None
@property
def urn(self):
return builder.make_dataset_urn_with_platform_instance(
platform=self.platform,
name=self.name,
platform_instance=self.platform_instance,
env=self.env,
)
@attr.s(str=True)
class Urn(_Entity):
_urn: str = attr.ib()
@_urn.validator
def _validate_urn(self, attribute, value):
if not value.startswith("urn:"):
raise ValueError("invalid urn provided: urns must start with 'urn:'")
if guess_entity_type(value) != "dataset":
# This is because DataJobs only support Dataset lineage.
raise ValueError("Airflow lineage currently only supports datasets")
@property
def urn(self):
return self._urn
__all__ = ["_Entity", "Dataset", "Urn"]

View File

@ -1,216 +1,8 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
MetadataChangeEvent,
MetadataChangeProposal,
from datahub_airflow_plugin.hooks.datahub import (
BaseHook,
DatahubGenericHook,
DatahubKafkaHook,
DatahubRestHook,
)
if TYPE_CHECKING:
from airflow.models.connection import Connection
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig
class DatahubRestHook(BaseHook):
"""
Creates a DataHub Rest API connection used to send metadata to DataHub.
Takes the endpoint for your DataHub Rest API in the Server Endpoint (host) field.
URI example: ::
AIRFLOW_CONN_DATAHUB_REST_DEFAULT='datahub-rest://rest-endpoint'
:param datahub_rest_conn_id: Reference to the DataHub Rest connection.
:type datahub_rest_conn_id: str
"""
conn_name_attr = "datahub_rest_conn_id"
default_conn_name = "datahub_rest_default"
conn_type = "datahub_rest"
hook_name = "DataHub REST Server"
def __init__(self, datahub_rest_conn_id: str = default_conn_name) -> None:
super().__init__()
self.datahub_rest_conn_id = datahub_rest_conn_id
@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
return {}
@staticmethod
def get_ui_field_behaviour() -> Dict:
"""Returns custom field behavior"""
return {
"hidden_fields": ["port", "schema", "login"],
"relabeling": {
"host": "Server Endpoint",
},
}
def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]:
conn: "Connection" = self.get_connection(self.datahub_rest_conn_id)
host = conn.host
if not host:
raise AirflowException("host parameter is required")
if conn.port:
if ":" in host:
raise AirflowException(
"host parameter should not contain a port number if the port is specified separately"
)
host = f"{host}:{conn.port}"
password = conn.password
timeout_sec = conn.extra_dejson.get("timeout_sec")
return (host, password, timeout_sec)
def make_emitter(self) -> "DatahubRestEmitter":
import datahub.emitter.rest_emitter
return datahub.emitter.rest_emitter.DatahubRestEmitter(*self._get_config())
def emit_mces(self, mces: List[MetadataChangeEvent]) -> None:
emitter = self.make_emitter()
for mce in mces:
emitter.emit_mce(mce)
def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None:
emitter = self.make_emitter()
for mcp in mcps:
emitter.emit_mcp(mcp)
class DatahubKafkaHook(BaseHook):
"""
Creates a DataHub Kafka connection used to send metadata to DataHub.
Takes your Kafka broker in the Kafka Broker (host) field.
URI example: ::
AIRFLOW_CONN_DATAHUB_KAFKA_DEFAULT='datahub-kafka://kafka-broker'
:param datahub_kafka_conn_id: Reference to the DataHub Kafka connection.
:type datahub_kafka_conn_id: str
"""
conn_name_attr = "datahub_kafka_conn_id"
default_conn_name = "datahub_kafka_default"
conn_type = "datahub_kafka"
hook_name = "DataHub Kafka Sink"
def __init__(self, datahub_kafka_conn_id: str = default_conn_name) -> None:
super().__init__()
self.datahub_kafka_conn_id = datahub_kafka_conn_id
@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
return {}
@staticmethod
def get_ui_field_behaviour() -> Dict:
"""Returns custom field behavior"""
return {
"hidden_fields": ["port", "schema", "login", "password"],
"relabeling": {
"host": "Kafka Broker",
},
}
def _get_config(self) -> "KafkaSinkConfig":
import datahub.ingestion.sink.datahub_kafka
conn = self.get_connection(self.datahub_kafka_conn_id)
obj = conn.extra_dejson
obj.setdefault("connection", {})
if conn.host is not None:
if "bootstrap" in obj["connection"]:
raise AirflowException(
"Kafka broker specified twice (present in host and extra)"
)
obj["connection"]["bootstrap"] = ":".join(
map(str, filter(None, [conn.host, conn.port]))
)
config = datahub.ingestion.sink.datahub_kafka.KafkaSinkConfig.parse_obj(obj)
return config
def make_emitter(self) -> "DatahubKafkaEmitter":
import datahub.emitter.kafka_emitter
sink_config = self._get_config()
return datahub.emitter.kafka_emitter.DatahubKafkaEmitter(sink_config)
def emit_mces(self, mces: List[MetadataChangeEvent]) -> None:
emitter = self.make_emitter()
errors = []
def callback(exc, msg):
if exc:
errors.append(exc)
for mce in mces:
emitter.emit_mce_async(mce, callback)
emitter.flush()
if errors:
raise AirflowException(f"failed to push some MCEs: {errors}")
def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None:
emitter = self.make_emitter()
errors = []
def callback(exc, msg):
if exc:
errors.append(exc)
for mcp in mcps:
emitter.emit_mcp_async(mcp, callback)
emitter.flush()
if errors:
raise AirflowException(f"failed to push some MCPs: {errors}")
class DatahubGenericHook(BaseHook):
"""
Emits Metadata Change Events using either the DatahubRestHook or the
DatahubKafkaHook. Set up a DataHub Rest or Kafka connection to use.
:param datahub_conn_id: Reference to the DataHub connection.
:type datahub_conn_id: str
"""
def __init__(self, datahub_conn_id: str) -> None:
super().__init__()
self.datahub_conn_id = datahub_conn_id
def get_underlying_hook(self) -> Union[DatahubRestHook, DatahubKafkaHook]:
conn = self.get_connection(self.datahub_conn_id)
# We need to figure out the underlying hook type. First check the
# conn_type. If that fails, attempt to guess using the conn id name.
if conn.conn_type == DatahubRestHook.conn_type:
return DatahubRestHook(self.datahub_conn_id)
elif conn.conn_type == DatahubKafkaHook.conn_type:
return DatahubKafkaHook(self.datahub_conn_id)
elif "rest" in self.datahub_conn_id:
return DatahubRestHook(self.datahub_conn_id)
elif "kafka" in self.datahub_conn_id:
return DatahubKafkaHook(self.datahub_conn_id)
else:
raise AirflowException(
f"DataHub cannot handle conn_type {conn.conn_type} in {conn}"
)
def make_emitter(self) -> Union["DatahubRestEmitter", "DatahubKafkaEmitter"]:
return self.get_underlying_hook().make_emitter()
def emit_mces(self, mces: List[MetadataChangeEvent]) -> None:
return self.get_underlying_hook().emit_mces(mces)
__all__ = ["DatahubRestHook", "DatahubKafkaHook", "DatahubGenericHook", "BaseHook"]

View File

@ -1,91 +1,6 @@
import json
from typing import TYPE_CHECKING, Dict, List, Optional
from airflow.configuration import conf
from airflow.lineage.backend import LineageBackend
from datahub_provider._lineage_core import (
DatahubBasicLineageConfig,
send_lineage_to_datahub,
)
from datahub_airflow_plugin.lineage.datahub import (
DatahubLineageBackend,
DatahubLineageConfig,
)
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
class DatahubLineageConfig(DatahubBasicLineageConfig):
# If set to true, most runtime errors in the lineage backend will be
# suppressed and will not cause the overall task to fail. Note that
# configuration issues will still throw exceptions.
graceful_exceptions: bool = True
def get_lineage_config() -> DatahubLineageConfig:
"""Load the lineage config from airflow.cfg."""
# The kwargs pattern is also used for secret backends.
kwargs_str = conf.get("lineage", "datahub_kwargs", fallback="{}")
kwargs = json.loads(kwargs_str)
# Continue to support top-level datahub_conn_id config.
datahub_conn_id = conf.get("lineage", "datahub_conn_id", fallback=None)
if datahub_conn_id:
kwargs["datahub_conn_id"] = datahub_conn_id
return DatahubLineageConfig.parse_obj(kwargs)
class DatahubLineageBackend(LineageBackend):
"""
Sends lineage data from tasks to DataHub.
Configurable via ``airflow.cfg`` as follows: ::
# For REST-based:
airflow connections add --conn-type 'datahub_rest' 'datahub_rest_default' --conn-host 'http://localhost:8080'
# For Kafka-based (standard Kafka sink config can be passed via extras):
airflow connections add --conn-type 'datahub_kafka' 'datahub_kafka_default' --conn-host 'broker:9092' --conn-extra '{}'
[lineage]
backend = datahub_provider.lineage.datahub.DatahubLineageBackend
datahub_kwargs = {
"datahub_conn_id": "datahub_rest_default",
"capture_ownership_info": true,
"capture_tags_info": true,
"graceful_exceptions": true }
# The above indentation is important!
"""
def __init__(self) -> None:
super().__init__()
# By attempting to get and parse the config, we can detect configuration errors
# ahead of time. The init method is only called in Airflow 2.x.
_ = get_lineage_config()
# With Airflow 2.0, this can be an instance method. However, with Airflow 1.10.x, this
# method is invoked statically, even though LineageBackend declares it as an instance method.
@staticmethod
def send_lineage(
operator: "BaseOperator",
inlets: Optional[List] = None, # unused
outlets: Optional[List] = None, # unused
context: Optional[Dict] = None,
) -> None:
config = get_lineage_config()
if not config.enabled:
return
try:
context = context or {} # ensure not None to satisfy mypy
send_lineage_to_datahub(
config, operator, operator.inlets, operator.outlets, context
)
except Exception as e:
if config.graceful_exceptions:
operator.log.error(e)
operator.log.info(
"Suppressing error because graceful_exceptions is set"
)
else:
raise
__all__ = ["DatahubLineageBackend", "DatahubLineageConfig"]

View File

@ -1,63 +1,6 @@
from typing import List, Union
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub_provider.hooks.datahub import (
DatahubGenericHook,
DatahubKafkaHook,
DatahubRestHook,
)
from datahub_airflow_plugin.operators.datahub import (
DatahubBaseOperator,
DatahubEmitterOperator,
)
class DatahubBaseOperator(BaseOperator):
"""
The DatahubBaseOperator is used as a base operator for all DataHub operators.
"""
ui_color = "#4398c8"
hook: Union[DatahubRestHook, DatahubKafkaHook]
# mypy is not a fan of this. Newer versions of Airflow support proper typing for the decorator
# using PEP 612. However, there is not yet a good way to inherit the types of the kwargs from
# the superclass.
@apply_defaults # type: ignore[misc]
def __init__( # type: ignore[no-untyped-def]
self,
*,
datahub_conn_id: str,
**kwargs,
):
super().__init__(**kwargs)
self.datahub_conn_id = datahub_conn_id
self.generic_hook = DatahubGenericHook(datahub_conn_id)
class DatahubEmitterOperator(DatahubBaseOperator):
"""
Emits a Metadata Change Event to DataHub using either a DataHub
Rest or Kafka connection.
:param datahub_conn_id: Reference to the DataHub Rest or Kafka Connection.
:type datahub_conn_id: str
"""
# See above for why these mypy type issues are ignored here.
@apply_defaults # type: ignore[misc]
def __init__( # type: ignore[no-untyped-def]
self,
mces: List[MetadataChangeEvent],
datahub_conn_id: str,
**kwargs,
):
super().__init__(
datahub_conn_id=datahub_conn_id,
**kwargs,
)
self.mces = mces
def execute(self, context):
self.generic_hook.get_underlying_hook().emit_mces(self.mces)
__all__ = ["DatahubEmitterOperator", "DatahubBaseOperator"]

View File

@ -1,78 +1,5 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.models import BaseOperator
from datahub.api.circuit_breaker import (
AssertionCircuitBreaker,
AssertionCircuitBreakerConfig,
)
from datahub_airflow_plugin.operators.datahub_assertion_operator import (
DataHubAssertionOperator,
)
from datahub_provider.hooks.datahub import DatahubRestHook
class DataHubAssertionOperator(BaseOperator):
r"""
DataHub Assertion Circuit Breaker Operator.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The ID of the DataHub REST Airflow connection used to communicate with DataHub.
:param check_last_assertion_time: If True (the default), only assertions run after the dataset's last
operation are checked.
:param time_delta: If check_last_assertion_time is False, assertions within this time delta are checked.
Defaults to one day.
"""
template_fields: Sequence[str] = ("urn",)
circuit_breaker: AssertionCircuitBreaker
urn: Union[List[str], str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
check_last_assertion_time: bool = True,
time_delta: Optional[datetime.timedelta] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
verify_after_last_update=check_last_assertion_time,
time_delta=time_delta if time_delta else datetime.timedelta(days=1),
)
self.circuit_breaker = AssertionCircuitBreaker(config=config)
def execute(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn)
if ret:
raise Exception(f"Dataset {urn} is not in consumable state")
return True
__all__ = ["DataHubAssertionOperator"]

View File

@ -1,78 +1,5 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.sensors.base import BaseSensorOperator
from datahub.api.circuit_breaker import (
AssertionCircuitBreaker,
AssertionCircuitBreakerConfig,
)
from datahub_airflow_plugin.operators.datahub_assertion_sensor import (
DataHubAssertionSensor,
)
from datahub_provider.hooks.datahub import DatahubRestHook
class DataHubAssertionSensor(BaseSensorOperator):
r"""
DataHub Assertion Circuit Breaker Sensor.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The ID of the DataHub REST Airflow connection used to communicate with DataHub.
:param check_last_assertion_time: If True (the default), only assertions run after the dataset's last
operation are checked.
:param time_delta: If check_last_assertion_time is False, assertions within this time delta are checked.
Defaults to one day.
"""
template_fields: Sequence[str] = ("urn",)
circuit_breaker: AssertionCircuitBreaker
urn: Union[List[str], str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
check_last_assertion_time: bool = True,
time_delta: datetime.timedelta = datetime.timedelta(days=1),
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
verify_after_last_update=check_last_assertion_time,
time_delta=time_delta,
)
self.circuit_breaker = AssertionCircuitBreaker(config=config)
def poke(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn)
if ret:
self.log.info(f"Dataset {urn} is not in consumable state")
return False
return True
__all__ = ["DataHubAssertionSensor"]

View File

@ -1,97 +1,5 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.sensors.base import BaseSensorOperator
from datahub.api.circuit_breaker import (
OperationCircuitBreaker,
OperationCircuitBreakerConfig,
)
from datahub_airflow_plugin.operators.datahub_operation_operator import (
DataHubOperationCircuitBreakerOperator,
)
from datahub_provider.hooks.datahub import DatahubRestHook
class DataHubOperationCircuitBreakerOperator(BaseSensorOperator):
r"""
DataHub Operation Circuit Breaker Operator.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The ID of the DataHub REST Airflow connection used to communicate with DataHub.
:param partition: The partition to check the operation against.
:param source_type: The source type to filter on. If not set, any source type is accepted.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationsourcetype
:param operation_type: The operation type to filter on. If not set, any operation type is accepted.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationtype
"""
template_fields: Sequence[str] = (
"urn",
"partition",
"source_type",
"operation_type",
)
circuit_breaker: OperationCircuitBreaker
urn: Union[List[str], str]
partition: Optional[str]
source_type: Optional[str]
operation_type: Optional[str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1),
partition: Optional[str] = None,
source_type: Optional[str] = None,
operation_type: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
self.partition = partition
self.operation_type = operation_type
self.source_type = source_type
config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
time_delta=time_delta,
)
self.circuit_breaker = OperationCircuitBreaker(config=config)
def execute(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(
urn=urn,
partition=self.partition,
operation_type=self.operation_type,
source_type=self.source_type,
)
if ret:
raise Exception(f"Dataset {urn} is not in consumable state")
return True
__all__ = ["DataHubOperationCircuitBreakerOperator"]

View File

@ -1,100 +1,5 @@
import datetime
from typing import Any, List, Optional, Sequence, Union
from airflow.sensors.base import BaseSensorOperator
from datahub.api.circuit_breaker import (
OperationCircuitBreaker,
OperationCircuitBreakerConfig,
)
from datahub_airflow_plugin.operators.datahub_operation_sensor import (
DataHubOperationCircuitBreakerSensor,
)
from datahub_provider.hooks.datahub import DatahubRestHook
class DataHubOperationCircuitBreakerSensor(BaseSensorOperator):
r"""
DataHub Operation Circuit Breaker Sensor.
:param urn: The DataHub dataset unique identifier. (templated)
:param datahub_rest_conn_id: The ID of the DataHub REST Airflow connection used to communicate with DataHub.
:param partition: The partition to check the operation against.
:param source_type: The source type to filter on. If not set, any source type is accepted.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationsourcetype
:param operation_type: The operation type to filter on. If not set, any operation type is accepted.
See valid values at: https://datahubproject.io/docs/graphql/enums#operationtype
"""
template_fields: Sequence[str] = (
"urn",
"partition",
"source_type",
"operation_type",
)
circuit_breaker: OperationCircuitBreaker
urn: Union[List[str], str]
partition: Optional[str]
source_type: Optional[str]
operation_type: Optional[str]
def __init__( # type: ignore[no-untyped-def]
self,
*,
urn: Union[List[str], str],
datahub_rest_conn_id: Optional[str] = None,
time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1),
partition: Optional[str] = None,
source_type: Optional[str] = None,
operation_type: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
hook: DatahubRestHook
if datahub_rest_conn_id is not None:
hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id)
else:
hook = DatahubRestHook()
host, password, timeout_sec = hook._get_config()
self.urn = urn
self.partition = partition
self.operation_type = operation_type
self.source_type = source_type
config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig(
datahub_host=host,
datahub_token=password,
timeout=timeout_sec,
time_delta=time_delta,
)
self.circuit_breaker = OperationCircuitBreaker(config=config)
def poke(self, context: Any) -> bool:
if "datahub_silence_circuit_breakers" in context["dag_run"].conf:
self.log.info(
"Circuit breaker is silenced because datahub_silence_circuit_breakers config is set"
)
return True
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
if isinstance(self.urn, str):
urns = [self.urn]
elif isinstance(self.urn, list):
urns = self.urn
else:
raise Exception(f"urn parameter has invalid type {type(self.urn)}")
for urn in urns:
self.log.info(f"Checking if dataset {self.urn} is ready to be consumed")
ret = self.circuit_breaker.is_circuit_breaker_active(
urn=urn,
partition=self.partition,
operation_type=self.operation_type,
source_type=self.source_type,
)
if ret:
self.log.info(f"Dataset {self.urn} is not in consumable state")
return False
return True
__all__ = ["DataHubOperationCircuitBreakerSensor"]