feat(ingest): loosen sqlalchemy dep & support airflow 2.3+ (#6204)
Co-authored-by: Tamas Nemeth <treff7es@gmail.com>
Parent: 6c42064332
Commit: 3e907ab0d1
.github/workflows/metadata-ingestion.yml (vendored, 9 lines changed)
@ -38,6 +38,11 @@ jobs:
|
||||
"testIntegrationBatch1",
|
||||
"testSlowIntegration",
|
||||
]
|
||||
include:
|
||||
- python-version: "3.7"
|
||||
extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0"
|
||||
- python-version: "3.10"
|
||||
extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow>=2.4.0"
|
||||
fail-fast: false
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
@ -50,8 +55,8 @@ jobs:
|
||||
hadoop-version: "3.2"
|
||||
- name: Install dependencies
|
||||
run: ./metadata-ingestion/scripts/install_deps.sh
|
||||
- name: Run metadata-ingestion tests
|
||||
run: ./gradlew :metadata-ingestion:build :metadata-ingestion:${{ matrix.command }}
|
||||
- name: Run metadata-ingestion tests (extras ${{ matrix.extraPythonRequirement }})
|
||||
run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion:${{ matrix.command }}
|
||||
- name: pip freeze show list installed
|
||||
if: always()
|
||||
run: source metadata-ingestion/venv/bin/activate && pip freeze
|
||||
|
||||
@ -7,8 +7,10 @@ ext {
|
||||
venv_name = 'venv'
|
||||
}
|
||||
|
||||
def pip_install_command = "${venv_name}/bin/pip install -e ../../metadata-ingestion"
|
||||
|
||||
task checkPythonVersion(type: Exec) {
|
||||
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 6)'
|
||||
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)'
|
||||
}
|
||||
|
||||
task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
|
||||
@ -20,7 +22,7 @@ task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
|
||||
task installPackage(type: Exec, dependsOn: environmentSetup) {
|
||||
inputs.file file('setup.py')
|
||||
outputs.dir("${venv_name}")
|
||||
commandLine "${venv_name}/bin/pip", 'install', '-e', '.'
|
||||
commandLine 'bash', '-x', '-c', "${pip_install_command} -e ."
|
||||
}
|
||||
|
||||
task install(dependsOn: [installPackage])
|
||||
@ -30,7 +32,7 @@ task installDev(type: Exec, dependsOn: [install]) {
|
||||
outputs.dir("${venv_name}")
|
||||
outputs.file("${venv_name}/.build_install_dev_sentinel")
|
||||
commandLine 'bash', '-x', '-c',
|
||||
"${venv_name}/bin/pip install -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
|
||||
"${pip_install_command} -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
|
||||
}
|
||||
|
||||
task lint(type: Exec, dependsOn: installDev) {
|
||||
@ -65,7 +67,7 @@ task installDevTest(type: Exec, dependsOn: [installDev]) {
|
||||
outputs.dir("${venv_name}")
|
||||
outputs.file("${venv_name}/.build_install_dev_test_sentinel")
|
||||
commandLine 'bash', '-x', '-c',
|
||||
"${venv_name}/bin/pip install -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
|
||||
"${pip_install_command} -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
|
||||
}
|
||||
|
||||
def testFile = hasProperty('testFile') ? testFile : 'unknown'
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Dict, Set
|
||||
|
||||
import setuptools
|
||||
|
||||
|
||||
package_metadata: dict = {}
|
||||
with open("./src/datahub_airflow_plugin/__init__.py") as fp:
|
||||
exec(fp.read(), package_metadata)
|
||||
@ -23,9 +23,7 @@ base_requirements = {
|
||||
"typing-inspect",
|
||||
"pydantic>=1.5.1",
|
||||
"apache-airflow >= 2.0.2",
|
||||
"acryl-datahub[airflow] >= 0.8.36",
|
||||
# Pinned dependencies to make dependency resolution faster.
|
||||
"sqlalchemy==1.3.24",
|
||||
f"acryl-datahub[airflow] == {package_metadata['__version__']}",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,362 +1,4 @@
|
||||
import contextlib
|
||||
import traceback
|
||||
from typing import Any, Iterable
|
||||
|
||||
import attr
|
||||
from airflow.configuration import conf
|
||||
from airflow.lineage import PIPELINE_OUTLETS
|
||||
from airflow.models.baseoperator import BaseOperator
|
||||
from airflow.plugins_manager import AirflowPlugin
|
||||
from airflow.utils.module_loading import import_string
|
||||
from cattr import structure
|
||||
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
|
||||
from datahub_provider.client.airflow_generator import AirflowGenerator
|
||||
from datahub_provider.hooks.datahub import DatahubGenericHook
|
||||
from datahub_provider.lineage.datahub import DatahubLineageConfig
|
||||
|
||||
|
||||
def get_lineage_config() -> DatahubLineageConfig:
|
||||
"""Load the lineage config from airflow.cfg."""
|
||||
|
||||
enabled = conf.get("datahub", "enabled", fallback=True)
|
||||
datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default")
|
||||
cluster = conf.get("datahub", "cluster", fallback="prod")
|
||||
graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True)
|
||||
capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True)
|
||||
capture_ownership_info = conf.get(
|
||||
"datahub", "capture_ownership_info", fallback=True
|
||||
)
|
||||
capture_executions = conf.get("datahub", "capture_executions", fallback=True)
|
||||
return DatahubLineageConfig(
|
||||
enabled=enabled,
|
||||
datahub_conn_id=datahub_conn_id,
|
||||
cluster=cluster,
|
||||
graceful_exceptions=graceful_exceptions,
|
||||
capture_ownership_info=capture_ownership_info,
|
||||
capture_tags_info=capture_tags_info,
|
||||
capture_executions=capture_executions,
|
||||
)
|
||||
|
||||
|
||||
def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]:
|
||||
inlets = []
|
||||
if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore
|
||||
inlets = [
|
||||
task._inlets,
|
||||
]
|
||||
|
||||
if task._inlets and isinstance(task._inlets, list):
|
||||
inlets = []
|
||||
task_ids = (
|
||||
{o for o in task._inlets if isinstance(o, str)}
|
||||
.union(op.task_id for op in task._inlets if isinstance(op, BaseOperator))
|
||||
.intersection(task.get_flat_relative_ids(upstream=True))
|
||||
)
|
||||
|
||||
from airflow.lineage import AUTO
|
||||
|
||||
# pick up unique direct upstream task_ids if AUTO is specified
|
||||
if AUTO.upper() in task._inlets or AUTO.lower() in task._inlets:
|
||||
print("Picking up unique direct upstream task_ids as AUTO is specified")
|
||||
task_ids = task_ids.union(
|
||||
task_ids.symmetric_difference(task.upstream_task_ids)
|
||||
)
|
||||
|
||||
inlets = task.xcom_pull(
|
||||
context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS
|
||||
)
|
||||
|
||||
# re-instantiate the obtained inlets
|
||||
inlets = [
|
||||
structure(item["data"], import_string(item["type_name"]))
|
||||
# _get_instance(structure(item, Metadata))
|
||||
for sublist in inlets
|
||||
if sublist
|
||||
for item in sublist
|
||||
]
|
||||
|
||||
for inlet in task._inlets:
|
||||
if type(inlet) != str:
|
||||
inlets.append(inlet)
|
||||
|
||||
return inlets
|
||||
|
||||
|
||||
def datahub_on_failure_callback(context, *args, **kwargs):
|
||||
ti = context["ti"]
|
||||
task: "BaseOperator" = ti.task
|
||||
dag = context["dag"]
|
||||
|
||||
# This code is from the original airflow lineage code ->
|
||||
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
|
||||
inlets = get_inlets_from_task(task, context)
|
||||
|
||||
emitter = (
|
||||
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
|
||||
.get_underlying_hook()
|
||||
.make_emitter()
|
||||
)
|
||||
|
||||
dataflow = AirflowGenerator.generate_dataflow(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
dataflow.emit(emitter)
|
||||
|
||||
task.log.info(f"Emitted Datahub DataFlow: {dataflow}")
|
||||
|
||||
datajob = AirflowGenerator.generate_datajob(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
task=context["ti"].task,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
|
||||
for inlet in inlets:
|
||||
datajob.inlets.append(inlet.urn)
|
||||
|
||||
for outlet in task._outlets:
|
||||
datajob.outlets.append(outlet.urn)
|
||||
|
||||
task.log.info(f"Emitted Datahub DataJob: {datajob}")
|
||||
datajob.emit(emitter)
|
||||
|
||||
if context["_datahub_config"].capture_executions:
|
||||
dpi = AirflowGenerator.run_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag=dag,
|
||||
dag_run=context["dag_run"],
|
||||
datajob=datajob,
|
||||
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
|
||||
)
|
||||
|
||||
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
|
||||
|
||||
dpi = AirflowGenerator.complete_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag_run=context["dag_run"],
|
||||
result=InstanceRunResult.FAILURE,
|
||||
dag=dag,
|
||||
datajob=datajob,
|
||||
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
|
||||
)
|
||||
task.log.info(f"Emitted Completed Datahub Dataprocess Instance: {dpi}")
|
||||
|
||||
|
||||
def datahub_on_success_callback(context, *args, **kwargs):
|
||||
ti = context["ti"]
|
||||
task: "BaseOperator" = ti.task
|
||||
dag = context["dag"]
|
||||
|
||||
# This code is from the original airflow lineage code ->
|
||||
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
|
||||
inlets = get_inlets_from_task(task, context)
|
||||
|
||||
emitter = (
|
||||
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
|
||||
.get_underlying_hook()
|
||||
.make_emitter()
|
||||
)
|
||||
|
||||
dataflow = AirflowGenerator.generate_dataflow(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
dataflow.emit(emitter)
|
||||
|
||||
task.log.info(f"Emitted Datahub DataFlow: {dataflow}")
|
||||
|
||||
datajob = AirflowGenerator.generate_datajob(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
task=task,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
|
||||
for inlet in inlets:
|
||||
datajob.inlets.append(inlet.urn)
|
||||
|
||||
# We have to use _outlets because outlets is empty
|
||||
for outlet in task._outlets:
|
||||
datajob.outlets.append(outlet.urn)
|
||||
|
||||
task.log.info(f"Emitted Datahub dataJob: {datajob}")
|
||||
datajob.emit(emitter)
|
||||
|
||||
if context["_datahub_config"].capture_executions:
|
||||
dpi = AirflowGenerator.run_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag=dag,
|
||||
dag_run=context["dag_run"],
|
||||
datajob=datajob,
|
||||
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
|
||||
)
|
||||
|
||||
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
|
||||
|
||||
dpi = AirflowGenerator.complete_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag_run=context["dag_run"],
|
||||
result=InstanceRunResult.SUCCESS,
|
||||
dag=dag,
|
||||
datajob=datajob,
|
||||
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
|
||||
)
|
||||
task.log.info(f"Emitted Completed Data Process Instance: {dpi}")
|
||||
|
||||
|
||||
def datahub_pre_execution(context):
|
||||
ti = context["ti"]
|
||||
task: "BaseOperator" = ti.task
|
||||
dag = context["dag"]
|
||||
|
||||
task.log.info("Running Datahub pre_execute method")
|
||||
|
||||
emitter = (
|
||||
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
|
||||
.get_underlying_hook()
|
||||
.make_emitter()
|
||||
)
|
||||
|
||||
# This code is from the original airflow lineage code ->
|
||||
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
|
||||
inlets = get_inlets_from_task(task, context)
|
||||
|
||||
datajob = AirflowGenerator.generate_datajob(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
task=context["ti"].task,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
|
||||
for inlet in inlets:
|
||||
datajob.inlets.append(inlet.urn)
|
||||
|
||||
for outlet in task._outlets:
|
||||
datajob.outlets.append(outlet.urn)
|
||||
|
||||
datajob.emit(emitter)
|
||||
task.log.info(f"Emitting Datahub DataJob: {datajob}")
|
||||
|
||||
if context["_datahub_config"].capture_executions:
|
||||
dpi = AirflowGenerator.run_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag=dag,
|
||||
dag_run=context["dag_run"],
|
||||
datajob=datajob,
|
||||
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
|
||||
)
|
||||
|
||||
task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}")
|
||||
|
||||
|
||||
def _wrap_pre_execution(pre_execution):
|
||||
def custom_pre_execution(context):
|
||||
config = get_lineage_config()
|
||||
context["_datahub_config"] = config
|
||||
datahub_pre_execution(context)
|
||||
|
||||
# Call original policy
|
||||
if pre_execution:
|
||||
pre_execution(context)
|
||||
|
||||
return custom_pre_execution
|
||||
|
||||
|
||||
def _wrap_on_failure_callback(on_failure_callback):
|
||||
def custom_on_failure_callback(context):
|
||||
config = get_lineage_config()
|
||||
context["_datahub_config"] = config
|
||||
try:
|
||||
datahub_on_failure_callback(context)
|
||||
except Exception as e:
|
||||
if not config.graceful_exceptions:
|
||||
raise e
|
||||
else:
|
||||
print(f"Exception: {traceback.format_exc()}")
|
||||
|
||||
# Call original policy
|
||||
if on_failure_callback:
|
||||
on_failure_callback(context)
|
||||
|
||||
return custom_on_failure_callback
|
||||
|
||||
|
||||
def _wrap_on_success_callback(on_success_callback):
|
||||
def custom_on_success_callback(context):
|
||||
config = get_lineage_config()
|
||||
context["_datahub_config"] = config
|
||||
try:
|
||||
datahub_on_success_callback(context)
|
||||
except Exception as e:
|
||||
if not config.graceful_exceptions:
|
||||
raise e
|
||||
else:
|
||||
print(f"Exception: {traceback.format_exc()}")
|
||||
|
||||
if on_success_callback:
|
||||
on_success_callback(context)
|
||||
|
||||
return custom_on_success_callback
|
||||
|
||||
|
||||
def task_policy(task: BaseOperator) -> None:
|
||||
print(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
|
||||
# task.add_inlets(["auto"])
|
||||
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
|
||||
task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback)
|
||||
task.on_success_callback = _wrap_on_success_callback(task.on_success_callback)
|
||||
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
|
||||
|
||||
|
||||
def _wrap_task_policy(policy):
|
||||
if policy and hasattr(policy, "_task_policy_patched_by"):
|
||||
return policy
|
||||
|
||||
def custom_task_policy(task):
|
||||
policy(task)
|
||||
task_policy(task)
|
||||
|
||||
setattr(custom_task_policy, "_task_policy_patched_by", "datahub_plugin")
|
||||
return custom_task_policy
|
||||
|
||||
|
||||
def _patch_policy(settings):
|
||||
print("Patching datahub policy")
|
||||
if hasattr(settings, "task_policy"):
|
||||
datahub_task_policy = _wrap_task_policy(settings.task_policy)
|
||||
settings.task_policy = datahub_task_policy
|
||||
|
||||
|
||||
def _patch_datahub_policy():
|
||||
with contextlib.suppress(ImportError):
|
||||
import airflow_local_settings
|
||||
|
||||
_patch_policy(airflow_local_settings)
|
||||
from airflow.models.dagbag import settings
|
||||
|
||||
_patch_policy(settings)
|
||||
|
||||
|
||||
_patch_datahub_policy()
|
||||
|
||||
|
||||
class DatahubPlugin(AirflowPlugin):
|
||||
name = "datahub_plugin"
|
||||
# This package serves as a shim, but the actual implementation lives in datahub_provider
|
||||
# from the acryl-datahub package. We leave this shim here to avoid breaking existing
|
||||
# Airflow installs.
|
||||
from datahub_provider._plugin import DatahubPlugin # noqa: F401
|
||||
|
||||
@ -7,8 +7,12 @@ ext {
|
||||
venv_name = 'venv'
|
||||
}
|
||||
|
||||
if (!project.hasProperty("extra_pip_requirements")) {
|
||||
ext.extra_pip_requirements = ""
|
||||
}
|
||||
|
||||
task checkPythonVersion(type: Exec) {
|
||||
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 6)'
|
||||
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)'
|
||||
}
|
||||
|
||||
task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
|
||||
@ -24,7 +28,7 @@ task runPreFlightScript(type: Exec, dependsOn: environmentSetup) {
|
||||
task installPackage(type: Exec, dependsOn: runPreFlightScript) {
|
||||
inputs.file file('setup.py')
|
||||
outputs.dir("${venv_name}")
|
||||
commandLine "${venv_name}/bin/pip", 'install', '-e', '.'
|
||||
commandLine 'bash', '-x', '-c', "${venv_name}/bin/pip install -e . ${extra_pip_requirements}"
|
||||
}
|
||||
|
||||
task codegen(type: Exec, dependsOn: [environmentSetup, installPackage, ':metadata-events:mxe-schemas:build']) {
|
||||
@ -40,7 +44,7 @@ task installDev(type: Exec, dependsOn: [install]) {
|
||||
outputs.dir("${venv_name}")
|
||||
outputs.file("${venv_name}/.build_install_dev_sentinel")
|
||||
commandLine 'bash', '-x', '-c',
|
||||
"${venv_name}/bin/pip install -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
|
||||
"${venv_name}/bin/pip install -e .[dev] ${extra_pip_requirements} && touch ${venv_name}/.build_install_dev_sentinel"
|
||||
}
|
||||
|
||||
|
||||
@ -67,15 +71,21 @@ task lint(type: Exec, dependsOn: installDev) {
|
||||
*/
|
||||
commandLine 'bash', '-c',
|
||||
"find ${venv_name}/lib -path *airflow/_vendor/connexion/spec.py -exec sed -i.bak -e '169,169s/ # type: List\\[str\\]//g' {} \\; && " +
|
||||
"source ${venv_name}/bin/activate && set -x && black --check --diff src/ tests/ examples/ && isort --check --diff src/ tests/ examples/ && flake8 --count --statistics src/ tests/ examples/ && mypy src/ tests/ examples/"
|
||||
"source ${venv_name}/bin/activate && set -x && " +
|
||||
"./scripts/install-sqlalchemy-stubs.sh && " +
|
||||
"black --check --diff src/ tests/ examples/ && " +
|
||||
"isort --check --diff src/ tests/ examples/ && " +
|
||||
"flake8 --count --statistics src/ tests/ examples/ && " +
|
||||
"mypy --show-traceback --show-error-codes src/ tests/ examples/"
|
||||
}
|
||||
task lintFix(type: Exec, dependsOn: installDev) {
|
||||
commandLine 'bash', '-c',
|
||||
"source ${venv_name}/bin/activate && set -x && " +
|
||||
"./scripts/install-sqlalchemy-stubs.sh && " +
|
||||
"black src/ tests/ examples/ && " +
|
||||
"isort src/ tests/ examples/ && " +
|
||||
"flake8 src/ tests/ examples/ && " +
|
||||
"mypy src/ tests/ examples/"
|
||||
"mypy --show-traceback --show-error-codes src/ tests/ examples/"
|
||||
}
|
||||
|
||||
task testQuick(type: Exec, dependsOn: installDev) {
|
||||
@ -92,7 +102,7 @@ task installDevTest(type: Exec, dependsOn: [install]) {
|
||||
outputs.dir("${venv_name}")
|
||||
outputs.file("${venv_name}/.build_install_dev_test_sentinel")
|
||||
commandLine 'bash', '-c',
|
||||
"${venv_name}/bin/pip install -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
|
||||
"${venv_name}/bin/pip install -e .[dev,integration-tests] ${extra_pip_requirements} && touch ${venv_name}/.build_install_dev_test_sentinel"
|
||||
}
|
||||
|
||||
def testFile = hasProperty('testFile') ? testFile : 'unknown'
|
||||
|
||||
@ -1 +1 @@
|
||||
As a SQL-based service, the Athena integration is also supported by our SQL profiler. See here for more details on configuration.
|
||||
As a SQL-based service, the Oracle integration is also supported by our SQL profiler. See here for more details on configuration.
|
||||
|
||||
@ -14,6 +14,7 @@ target-version = ['py36', 'py37', 'py38']
|
||||
[tool.isort]
|
||||
combine_as_imports = true
|
||||
indent = ' '
|
||||
known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub_provider._airflow_compat']
|
||||
profile = 'black'
|
||||
sections = 'FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER'
|
||||
skip_glob = 'src/datahub/metadata'
|
||||
|
||||
metadata-ingestion/scripts/install-sqlalchemy-stubs.sh (new executable file, 28 lines)
@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ASSUMPTION: this script is expected to run from inside the venv.
|
||||
|
||||
SQLALCHEMY_VERSION=$(python -c 'import sqlalchemy; print(sqlalchemy.__version__)')
|
||||
|
||||
if [[ $SQLALCHEMY_VERSION == 1.3.* ]]; then
|
||||
ENSURE_NOT_INSTALLED=sqlalchemy2-stubs
|
||||
ENSURE_INSTALLED=sqlalchemy-stubs
|
||||
elif [[ $SQLALCHEMY_VERSION == 1.4.* ]]; then
|
||||
ENSURE_NOT_INSTALLED=sqlalchemy-stubs
|
||||
ENSURE_INSTALLED=sqlalchemy2-stubs
|
||||
else
|
||||
echo "Unsupported SQLAlchemy version: $SQLALCHEMY_VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
FORCE_REINSTALL=""
|
||||
if pip show $ENSURE_NOT_INSTALLED >/dev/null 2>&1 ; then
|
||||
pip uninstall --yes $ENSURE_NOT_INSTALLED
|
||||
FORCE_REINSTALL="--force-reinstall"
|
||||
fi
|
||||
|
||||
if [ -n "$FORCE_REINSTALL" ] || ! pip show $ENSURE_INSTALLED >/dev/null 2>&1 ; then
|
||||
pip install $FORCE_REINSTALL $ENSURE_INSTALLED
|
||||
fi
|
||||
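For readers less at home in bash, a rough Python rendering of the branch above (illustrative only; the shell script remains the actual implementation):

    import sqlalchemy

    if sqlalchemy.__version__.startswith("1.3."):
        ensure_installed, ensure_not_installed = "sqlalchemy-stubs", "sqlalchemy2-stubs"
    elif sqlalchemy.__version__.startswith("1.4."):
        ensure_installed, ensure_not_installed = "sqlalchemy2-stubs", "sqlalchemy-stubs"
    else:
        raise RuntimeError(f"Unsupported SQLAlchemy version: {sqlalchemy.__version__}")
    # The script then uninstalls ensure_not_installed (forcing a reinstall if it was present)
    # and installs ensure_installed via pip.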
@ -22,7 +22,7 @@ ban-relative-imports = true
|
||||
|
||||
[mypy]
|
||||
plugins =
|
||||
sqlmypy,
|
||||
./tests/test_helpers/sqlalchemy_mypy_plugin.py,
|
||||
pydantic.mypy
|
||||
exclude = ^(venv|build|dist)/
|
||||
ignore_missing_imports = yes
|
||||
@ -55,6 +55,7 @@ disallow_untyped_defs = yes
|
||||
asyncio_mode = auto
|
||||
addopts = --cov=src --cov-report term-missing --cov-config setup.cfg --strict-markers
|
||||
markers =
|
||||
airflow: marks tests related to airflow (deselect with '-m not airflow')
|
||||
slow_unit: marks tests to only run slow unit tests (deselect with '-m not slow_unit')
|
||||
integration: marks tests to only run in integration (deselect with '-m "not integration"')
|
||||
integration_batch_1: mark tests to only run in batch 1 of integration tests. This is done mainly for parallelisation (deselect with '-m not integration_batch_1')
|
||||
|
||||
@ -55,6 +55,10 @@ framework_common = {
|
||||
"click-spinner",
|
||||
}
|
||||
|
||||
rest_common = {
|
||||
"requests",
|
||||
}
|
||||
|
||||
kafka_common = {
|
||||
# The confluent_kafka package provides a number of pre-built wheels for
|
||||
# various platforms and architectures. However, it does not provide wheels
|
||||
@ -103,7 +107,7 @@ kafka_protobuf = {
|
||||
|
||||
sql_common = {
|
||||
# Required for all SQL sources.
|
||||
"sqlalchemy==1.3.24",
|
||||
"sqlalchemy>=1.3.24, <2",
|
||||
# Required for SQL profiling.
|
||||
"great-expectations>=0.15.12",
|
||||
# GE added handling for higher version of jinja2
|
||||
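The loosened "sqlalchemy>=1.3.24, <2" specifier above is the heart of this change. A quick sketch of what it admits, using the packaging library (not part of this diff, shown only for illustration):

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=1.3.24, <2")
    assert "1.3.24" in spec      # the previously pinned version is still allowed
    assert "1.4.41" in spec      # SQLAlchemy 1.4.x is now allowed
    assert "2.0.0" not in spec   # SQLAlchemy 2.x remains excluded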
@ -147,6 +151,12 @@ bigquery_common = {
|
||||
"more-itertools>=8.12.0",
|
||||
}
|
||||
|
||||
clickhouse_common = {
|
||||
# clickhouse-sqlalchemy 0.1.8 requires SQLAlchemy 1.3.x, while newer versions
|
||||
# allow SQLAlchemy 1.4.x.
|
||||
"clickhouse-sqlalchemy>=0.1.8",
|
||||
}
|
||||
|
||||
redshift_common = {
|
||||
"sqlalchemy-redshift",
|
||||
"psycopg2-binary",
|
||||
@ -214,10 +224,12 @@ databricks_cli = {
|
||||
plugins: Dict[str, Set[str]] = {
|
||||
# Sink plugins.
|
||||
"datahub-kafka": kafka_common,
|
||||
"datahub-rest": {"requests"},
|
||||
"datahub-rest": rest_common,
|
||||
# Integrations.
|
||||
"airflow": {
|
||||
"apache-airflow >= 2.0.2",
|
||||
*rest_common,
|
||||
*kafka_common,
|
||||
},
|
||||
"circuit-breaker": {
|
||||
"gql>=3.3.0",
|
||||
@ -239,12 +251,8 @@ plugins: Dict[str, Set[str]] = {
|
||||
"sqllineage==1.3.6",
|
||||
"sql_metadata",
|
||||
}, # deprecated, but keeping the extra for backwards compatibility
|
||||
"clickhouse": sql_common | {"clickhouse-sqlalchemy==0.1.8"},
|
||||
"clickhouse-usage": sql_common
|
||||
| usage_common
|
||||
| {
|
||||
"clickhouse-sqlalchemy==0.1.8",
|
||||
},
|
||||
"clickhouse": sql_common | clickhouse_common,
|
||||
"clickhouse-usage": sql_common | usage_common | clickhouse_common,
|
||||
"datahub-lineage-file": set(),
|
||||
"datahub-business-glossary": set(),
|
||||
"delta-lake": {*data_lake_profiling, *delta_lake},
|
||||
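The ClickHouse refactor above simply factors the shared requirement into clickhouse_common and composes extras via set union. A simplified sketch of the composition (set contents abridged; the names match the definitions in this setup.py):

    sql_common = {"sqlalchemy>=1.3.24, <2", "great-expectations>=0.15.12"}  # abridged
    usage_common: set = set()  # contents elided; see the real definition above
    clickhouse_common = {"clickhouse-sqlalchemy>=0.1.8"}

    plugins = {
        "clickhouse": sql_common | clickhouse_common,
        "clickhouse-usage": sql_common | usage_common | clickhouse_common,
    }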
@ -339,7 +347,6 @@ all_exclude_plugins: Set[str] = {
|
||||
|
||||
mypy_stubs = {
|
||||
"types-dataclasses",
|
||||
"sqlalchemy-stubs",
|
||||
"types-pkg_resources",
|
||||
"types-six",
|
||||
"types-python-dateutil",
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
|
||||
|
||||
import collections
|
||||
import concurrent.futures
|
||||
import contextlib
|
||||
@ -51,6 +53,7 @@ from datahub.utilities.sqlalchemy_query_combiner import (
|
||||
get_query_columns,
|
||||
)
|
||||
|
||||
assert MARKUPSAFE_PATCHED
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
@ -276,7 +276,8 @@ class ConfluentJDBCSourceConnector:
|
||||
url_instance = make_url(url)
|
||||
source_platform = get_platform_from_sqlalchemy_uri(str(url_instance))
|
||||
database_name = url_instance.database
|
||||
db_connection_url = f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{url_instance.database}"
|
||||
assert database_name
|
||||
db_connection_url = f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{database_name}"
|
||||
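# Worked example (hypothetical values): for a connection URL of
#   "postgresql://user:pass@db.example.com:5432/sales"
# make_url() gives drivername="postgresql", host="db.example.com", port=5432 and
# database="sales", so db_connection_url becomes
#   "postgresql://db.example.com:5432/sales"
# (credentials are not carried over into the rebuilt URL).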
|
||||
topic_prefix = self.connector_manifest.config.get("topic.prefix", None)
|
||||
|
||||
|
||||
@ -109,7 +109,7 @@ class AthenaSource(SQLAlchemySource):
|
||||
self, inspector: Inspector, schema: str, table: str
|
||||
) -> Tuple[Optional[str], Dict[str, str], Optional[str]]:
|
||||
if not self.cursor:
|
||||
self.cursor = inspector.dialect._raw_connection(inspector.engine).cursor()
|
||||
self.cursor = inspector.engine.raw_connection().cursor()
|
||||
|
||||
assert self.cursor
|
||||
# Unfortunately, properties can only be fetched through private methods, as they are not otherwise exposed
|
||||
|
||||
@ -792,8 +792,11 @@ class BigQuerySource(SQLAlchemySource):
|
||||
# Bigquery only supports one partition column
|
||||
# https://stackoverflow.com/questions/62886213/adding-multiple-partitioned-columns-to-bigquery-table-from-sql-query
|
||||
row = result.fetchone()
|
||||
if row and hasattr(row, "_asdict"):
|
||||
# Compat with sqlalchemy 1.4 Row type.
|
||||
row = row._asdict()
|
||||
if row:
|
||||
return BigQueryPartitionColumn(**row)
|
||||
return BigQueryPartitionColumn(**row.items())
|
||||
return None
|
||||
|
||||
def get_shard_from_table(self, table: str) -> Tuple[str, Optional[str]]:
|
||||
|
||||
@ -11,7 +11,6 @@ from pydantic.fields import Field
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.result import ResultProxy, RowProxy
|
||||
|
||||
from datahub.configuration.common import AllowDenyPattern
|
||||
from datahub.ingestion.api.common import PipelineContext
|
||||
@ -135,7 +134,7 @@ class SQLServerSource(SQLAlchemySource):
|
||||
def _populate_table_descriptions(self, conn: Connection, db_name: str) -> None:
|
||||
# see https://stackoverflow.com/questions/5953330/how-do-i-map-the-id-in-sys-extended-properties-to-an-object-name
|
||||
# also see https://www.mssqltips.com/sqlservertip/5384/working-with-sql-server-extended-properties/
|
||||
table_metadata: ResultProxy = conn.execute(
|
||||
table_metadata = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
SCHEMA_NAME(T.SCHEMA_ID) AS schema_name,
|
||||
@ -149,13 +148,13 @@ class SQLServerSource(SQLAlchemySource):
|
||||
AND EP.CLASS = 1
|
||||
"""
|
||||
)
|
||||
for row in table_metadata: # type: RowProxy
|
||||
for row in table_metadata:
|
||||
self.table_descriptions[
|
||||
f"{db_name}.{row['schema_name']}.{row['table_name']}"
|
||||
] = row["table_description"]
|
||||
|
||||
def _populate_column_descriptions(self, conn: Connection, db_name: str) -> None:
|
||||
column_metadata: RowProxy = conn.execute(
|
||||
column_metadata = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
SCHEMA_NAME(T.SCHEMA_ID) AS schema_name,
|
||||
@ -172,7 +171,7 @@ class SQLServerSource(SQLAlchemySource):
|
||||
AND EP.CLASS = 1
|
||||
"""
|
||||
)
|
||||
for row in column_metadata: # type: RowProxy
|
||||
for row in column_metadata:
|
||||
self.column_descriptions[
|
||||
f"{db_name}.{row['schema_name']}.{row['table_name']}.{row['column_name']}"
|
||||
] = row["column_description"]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Tuple, cast
|
||||
from typing import Any, Iterable, List, NoReturn, Optional, Tuple, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
# This import verifies that the dependencies are available.
|
||||
@ -35,6 +35,10 @@ extra_oracle_types = {
|
||||
assert OracleDialect.ischema_names
|
||||
|
||||
|
||||
def _raise_err(exc: Exception) -> NoReturn:
|
||||
raise exc
|
||||
|
||||
|
||||
def output_type_handler(cursor, name, defaultType, size, precision, scale):
|
||||
"""Add CLOB and BLOB support to Oracle connection."""
|
||||
|
||||
@ -94,7 +98,9 @@ class OracleInspectorObjectWrapper:
|
||||
s = "SELECT username FROM dba_users ORDER BY username"
|
||||
cursor = self._inspector_instance.bind.execute(s)
|
||||
return [
|
||||
self._inspector_instance.dialect.normalize_name(row[0]) for row in cursor
|
||||
self._inspector_instance.dialect.normalize_name(row[0])
|
||||
or _raise_err(ValueError(f"Invalid schema name: {row[0]}"))
|
||||
for row in cursor
|
||||
]
|
||||
|
||||
def get_table_names(self, schema: str = None, order_by: str = None) -> List[str]:
|
||||
@ -121,7 +127,9 @@ class OracleInspectorObjectWrapper:
|
||||
cursor = self._inspector_instance.bind.execute(sql.text(sql_str), owner=schema)
|
||||
|
||||
return [
|
||||
self._inspector_instance.dialect.normalize_name(row[0]) for row in cursor
|
||||
self._inspector_instance.dialect.normalize_name(row[0])
|
||||
or _raise_err(ValueError(f"Invalid table name: {row[0]}"))
|
||||
for row in cursor
|
||||
]
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
|
||||
@ -27,6 +27,7 @@ from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from sqlalchemy.sql import sqltypes as types
|
||||
from sqlalchemy.types import TypeDecorator, TypeEngine
|
||||
|
||||
from datahub.configuration.common import AllowDenyPattern
|
||||
from datahub.emitter.mce_builder import (
|
||||
@ -329,7 +330,7 @@ class SqlWorkUnit(MetadataWorkUnit):
|
||||
pass
|
||||
|
||||
|
||||
_field_type_mapping: Dict[Type[types.TypeEngine], Type] = {
|
||||
_field_type_mapping: Dict[Type[TypeEngine], Type] = {
|
||||
types.Integer: NumberTypeClass,
|
||||
types.Numeric: NumberTypeClass,
|
||||
types.Boolean: BooleanTypeClass,
|
||||
@ -367,30 +368,28 @@ _field_type_mapping: Dict[Type[types.TypeEngine], Type] = {
|
||||
# assigns the NullType by default. We want to carry this warning through.
|
||||
types.NullType: NullTypeClass,
|
||||
}
|
||||
_known_unknown_field_types: Set[Type[types.TypeEngine]] = {
|
||||
_known_unknown_field_types: Set[Type[TypeEngine]] = {
|
||||
types.Interval,
|
||||
types.CLOB,
|
||||
}
|
||||
|
||||
|
||||
def register_custom_type(
|
||||
tp: Type[types.TypeEngine], output: Optional[Type] = None
|
||||
) -> None:
|
||||
def register_custom_type(tp: Type[TypeEngine], output: Optional[Type] = None) -> None:
|
||||
if output:
|
||||
_field_type_mapping[tp] = output
|
||||
else:
|
||||
_known_unknown_field_types.add(tp)
|
||||
|
||||
|
||||
class _CustomSQLAlchemyDummyType(types.TypeDecorator):
|
||||
class _CustomSQLAlchemyDummyType(TypeDecorator):
|
||||
impl = types.LargeBinary
|
||||
|
||||
|
||||
def make_sqlalchemy_type(name: str) -> Type[types.TypeEngine]:
|
||||
def make_sqlalchemy_type(name: str) -> Type[TypeEngine]:
|
||||
# This usage of type() dynamically constructs a class.
|
||||
# See https://stackoverflow.com/a/15247202/5004662 and
|
||||
# https://docs.python.org/3/library/functions.html#type.
|
||||
sqlalchemy_type: Type[types.TypeEngine] = type(
|
||||
sqlalchemy_type: Type[TypeEngine] = type(
|
||||
name,
|
||||
(_CustomSQLAlchemyDummyType,),
|
||||
{
|
||||
|
||||
@ -4,15 +4,12 @@ from textwrap import dedent
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
# This import verifies that the dependencies are available.
|
||||
import trino.sqlalchemy # noqa: F401
|
||||
from pydantic.fields import Field
|
||||
from sqlalchemy import exc, sql
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql import sqltypes
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from trino.exceptions import TrinoQueryError
|
||||
from trino.sqlalchemy import datatype, error
|
||||
from trino.sqlalchemy.dialect import TrinoDialect
|
||||
|
||||
@ -3,13 +3,12 @@ import dataclasses
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set
|
||||
|
||||
from pydantic.fields import Field
|
||||
from pydantic.main import BaseModel
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine.result import ResultProxy, RowProxy
|
||||
|
||||
import datahub.emitter.mce_builder as builder
|
||||
from datahub.configuration.source_common import EnvBasedSourceConfigBase
|
||||
@ -39,6 +38,13 @@ from datahub.metadata.schema_classes import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from sqlalchemy.engine import Row # type: ignore
|
||||
except ImportError:
|
||||
# See https://github.com/python/mypy/issues/1153.
|
||||
from sqlalchemy.engine.result import RowProxy as Row # type: ignore
|
||||
|
||||
REDSHIFT_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
|
||||
@ -266,7 +272,7 @@ class RedshiftUsageSource(Source):
|
||||
logger.debug(f"sql_alchemy_url = {url}")
|
||||
return create_engine(url, **self.config.options)
|
||||
|
||||
def _should_process_row(self, row: RowProxy) -> bool:
|
||||
def _should_process_row(self, row: "Row") -> bool:
|
||||
# Check that mandatory properties are present first.
|
||||
missing_props: List[str] = [
|
||||
prop
|
||||
@ -294,10 +300,13 @@ class RedshiftUsageSource(Source):
|
||||
def _gen_access_events_from_history_query(
|
||||
self, query: str, engine: Engine
|
||||
) -> Iterable[RedshiftAccessEvent]:
|
||||
results: ResultProxy = engine.execute(query)
|
||||
for row in results: # type: RowProxy
|
||||
results = engine.execute(query)
|
||||
for row in results:
|
||||
if not self._should_process_row(row):
|
||||
continue
|
||||
if hasattr(row, "_asdict"):
|
||||
# Compatibility with sqlalchemy 1.4.x.
|
||||
row = row._asdict()
|
||||
access_event = RedshiftAccessEvent(**dict(row.items()))
|
||||
# Replace database name with the alias name if one is provided in the config.
|
||||
if self.config.database_alias:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -56,6 +58,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.events.metadata import ChangeTyp
|
||||
from datahub.metadata.schema_classes import PartitionSpecClass, PartitionTypeClass
|
||||
from datahub.utilities.sql_parser import DefaultSQLParser
|
||||
|
||||
assert MARKUPSAFE_PATCHED
|
||||
logger = logging.getLogger(__name__)
|
||||
if os.getenv("DATAHUB_DEBUG", False):
|
||||
handler = logging.StreamHandler(stream=sys.stdout)
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
try:
|
||||
import markupsafe
|
||||
|
||||
# This monkeypatch hack is required for markupsafe>=2.1.0 and older versions of Jinja2.
|
||||
# Changelog: https://markupsafe.palletsprojects.com/en/2.1.x/changes/#version-2-1-0
|
||||
# Example discussion: https://github.com/aws/aws-sam-cli/issues/3661.
|
||||
markupsafe.soft_unicode = markupsafe.soft_str # type: ignore[attr-defined]
|
||||
|
||||
MARKUPSAFE_PATCHED = True
|
||||
except ImportError:
|
||||
MARKUPSAFE_PATCHED = False
|
||||
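Call sites elsewhere in this commit import this shim before anything that may pull in Jinja2 and then assert the flag; the isort known_future_library entry added in pyproject.toml keeps the import sorted to the very top. The pattern looks like:

    from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED

    # ... imports that may transitively import Jinja2 (e.g. great_expectations) ...

    assert MARKUPSAFE_PATCHED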
@ -7,7 +7,7 @@ import random
|
||||
import string
|
||||
import threading
|
||||
import unittest.mock
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, cast
|
||||
|
||||
import greenlet
|
||||
import sqlalchemy
|
||||
@ -39,7 +39,8 @@ class _RowProxyFake(collections.OrderedDict):
|
||||
|
||||
|
||||
class _ResultProxyFake:
|
||||
# This imitates the interface provided by sqlalchemy.engine.result.ResultProxy.
|
||||
# This imitates the interface provided by sqlalchemy.engine.result.ResultProxy (sqlalchemy 1.3.x)
|
||||
# or sqlalchemy.engine.Result (1.4.x).
|
||||
# Adapted from https://github.com/rajivsarvepalli/mock-alchemy/blob/2eba95588e7693aab973a6d60441d2bc3c4ea35d/src/mock_alchemy/mocking.py#L213
|
||||
|
||||
def __init__(self, result: List[_RowProxyFake]) -> None:
|
||||
@ -363,7 +364,11 @@ class SQLAlchemyQueryCombiner:
|
||||
*query_future.multiparams,
|
||||
**query_future.params,
|
||||
)
|
||||
query_future.res = res
|
||||
|
||||
# The actual execute method returns a CursorResult on SQLAlchemy 1.4.x
|
||||
# and a ResultProxy on SQLAlchemy 1.3.x. Both interfaces are shimmed
|
||||
# by _ResultProxyFake.
|
||||
query_future.res = cast(_ResultProxyFake, res)
|
||||
except Exception as e:
|
||||
query_future.exc = e
|
||||
finally:
|
||||
|
||||
metadata-ingestion/src/datahub_provider/_airflow_compat.py (new file, 25 lines)
@ -0,0 +1,25 @@
|
||||
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
|
||||
|
||||
from airflow.models.baseoperator import BaseOperator
|
||||
|
||||
try:
|
||||
from airflow.models.mappedoperator import MappedOperator
|
||||
from airflow.models.operator import Operator
|
||||
except ModuleNotFoundError:
|
||||
Operator = BaseOperator # type: ignore
|
||||
MappedOperator = None # type: ignore
|
||||
|
||||
try:
|
||||
from airflow.sensors.external_task import ExternalTaskSensor
|
||||
except ImportError:
|
||||
from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore
|
||||
|
||||
assert MARKUPSAFE_PATCHED
|
||||
|
||||
__all__ = [
|
||||
"MARKUPSAFE_PATCHED",
|
||||
"Operator",
|
||||
"BaseOperator",
|
||||
"MappedOperator",
|
||||
"ExternalTaskSensor",
|
||||
]
|
||||
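Downstream modules import from this shim first (again pinned to the top by the isort known_future_library entry) and use Operator in type hints that must cover both plain and mapped operators. A minimal usage sketch (describe_task is hypothetical; the real call sites are the lineage and generator changes below):

    from datahub_provider._airflow_compat import Operator

    def describe_task(task: "Operator") -> str:
        # On Airflow 2.3+ Operator covers BaseOperator and MappedOperator;
        # on older versions the shim aliases it to BaseOperator.
        return task.task_id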
@ -1,3 +1,5 @@
|
||||
from datahub_provider._airflow_compat import Operator
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
@ -10,7 +12,6 @@ from datahub_provider.entities import _Entity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from airflow import DAG
|
||||
from airflow.models.baseoperator import BaseOperator
|
||||
from airflow.models.dagrun import DagRun
|
||||
from airflow.models.taskinstance import TaskInstance
|
||||
|
||||
@ -47,7 +48,7 @@ class DatahubBasicLineageConfig(ConfigModel):
|
||||
|
||||
def send_lineage_to_datahub(
|
||||
config: DatahubBasicLineageConfig,
|
||||
operator: "BaseOperator",
|
||||
operator: "Operator",
|
||||
inlets: List[_Entity],
|
||||
outlets: List[_Entity],
|
||||
context: Dict,
|
||||
@ -56,7 +57,7 @@ def send_lineage_to_datahub(
|
||||
return
|
||||
|
||||
dag: "DAG" = context["dag"]
|
||||
task: "BaseOperator" = context["task"]
|
||||
task: "Operator" = context["task"]
|
||||
ti: "TaskInstance" = context["task_instance"]
|
||||
|
||||
hook = config.make_emitter_hook()
|
||||
|
||||
metadata-ingestion/src/datahub_provider/_plugin.py (new file, 321 lines)
@ -0,0 +1,321 @@
|
||||
from datahub_provider._airflow_compat import Operator
|
||||
|
||||
import contextlib
|
||||
import traceback
|
||||
from typing import Any, Iterable, List
|
||||
|
||||
from airflow.configuration import conf
|
||||
from airflow.lineage import PIPELINE_OUTLETS
|
||||
from airflow.models.baseoperator import BaseOperator
|
||||
from airflow.plugins_manager import AirflowPlugin
|
||||
from airflow.utils.module_loading import import_string
|
||||
from cattr import structure
|
||||
|
||||
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
|
||||
from datahub_provider.client.airflow_generator import AirflowGenerator
|
||||
from datahub_provider.hooks.datahub import DatahubGenericHook
|
||||
from datahub_provider.lineage.datahub import DatahubLineageConfig
|
||||
|
||||
|
||||
def get_lineage_config() -> DatahubLineageConfig:
|
||||
"""Load the lineage config from airflow.cfg."""
|
||||
|
||||
enabled = conf.get("datahub", "enabled", fallback=True)
|
||||
datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default")
|
||||
cluster = conf.get("datahub", "cluster", fallback="prod")
|
||||
graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True)
|
||||
capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True)
|
||||
capture_ownership_info = conf.get(
|
||||
"datahub", "capture_ownership_info", fallback=True
|
||||
)
|
||||
capture_executions = conf.get("datahub", "capture_executions", fallback=True)
|
||||
return DatahubLineageConfig(
|
||||
enabled=enabled,
|
||||
datahub_conn_id=datahub_conn_id,
|
||||
cluster=cluster,
|
||||
graceful_exceptions=graceful_exceptions,
|
||||
capture_ownership_info=capture_ownership_info,
|
||||
capture_tags_info=capture_tags_info,
|
||||
capture_executions=capture_executions,
|
||||
)
|
||||
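# Illustrative airflow.cfg section matching the keys read above (values are examples;
# the fallbacks in get_lineage_config() apply when a key is absent):
#
#   [datahub]
#   enabled = true
#   conn_id = datahub_rest_default
#   cluster = prod
#   capture_tags_info = true
#   capture_ownership_info = true
#   capture_executions = true
#   graceful_exceptions = true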
|
||||
|
||||
def _task_inlets(operator: "Operator") -> List:
|
||||
# Airflow 2.4 dropped _inlets in favor of inlets; on earlier versions we still have to read _inlets.
|
||||
if hasattr(operator, "_inlets"):
|
||||
return operator._inlets # type: ignore[attr-defined, union-attr]
|
||||
return operator.inlets
|
||||
|
||||
|
||||
def _task_outlets(operator: "Operator") -> List:
|
||||
# Airflow 2.4 dropped _outlets in favor of outlets; on earlier versions we still have to read _outlets.
|
||||
# We have to use _outlets because outlets is empty in Airflow < 2.4.0
|
||||
if hasattr(operator, "_outlets"):
|
||||
return operator._outlets # type: ignore[attr-defined, union-attr]
|
||||
return operator.outlets
|
||||
|
||||
|
||||
def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]:
|
||||
# TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae
|
||||
# in Airflow 2.4.
|
||||
# TODO: ignore/handle airflow's dataset type in our lineage
|
||||
|
||||
inlets: List[Any] = []
|
||||
task_inlets = _task_inlets(task)
|
||||
# From Airflow 2.3 this should be AbstractOperator, but for compatibility reasons we use BaseOperator.
|
||||
if isinstance(task_inlets, (str, BaseOperator)):
|
||||
inlets = [
|
||||
task_inlets,
|
||||
]
|
||||
|
||||
if task_inlets and isinstance(task_inlets, list):
|
||||
inlets = []
|
||||
task_ids = (
|
||||
{o for o in task_inlets if isinstance(o, str)}
|
||||
.union(op.task_id for op in task_inlets if isinstance(op, BaseOperator))
|
||||
.intersection(task.get_flat_relative_ids(upstream=True))
|
||||
)
|
||||
|
||||
from airflow.lineage import AUTO
|
||||
|
||||
# pick up unique direct upstream task_ids if AUTO is specified
|
||||
if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets:
|
||||
print("Picking up unique direct upstream task_ids as AUTO is specified")
|
||||
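# Note: A.union(A.symmetric_difference(B)) == A.union(B), so this simply adds
# every direct upstream task id to the set.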
task_ids = task_ids.union(
|
||||
task_ids.symmetric_difference(task.upstream_task_ids)
|
||||
)
|
||||
|
||||
inlets = task.xcom_pull(
|
||||
context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS
|
||||
)
|
||||
|
||||
# re-instantiate the obtained inlets
|
||||
inlets = [
|
||||
structure(item["data"], import_string(item["type_name"]))
|
||||
# _get_instance(structure(item, Metadata))
|
||||
for sublist in inlets
|
||||
if sublist
|
||||
for item in sublist
|
||||
]
|
||||
|
||||
for inlet in task_inlets:
|
||||
if type(inlet) != str:
|
||||
inlets.append(inlet)
|
||||
|
||||
return inlets
|
||||
|
||||
|
||||
def datahub_task_status_callback(context, status):
|
||||
ti = context["ti"]
|
||||
task: "BaseOperator" = ti.task
|
||||
dag = context["dag"]
|
||||
|
||||
# This code is from the original airflow lineage code ->
|
||||
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
|
||||
inlets = get_inlets_from_task(task, context)
|
||||
|
||||
emitter = (
|
||||
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
|
||||
.get_underlying_hook()
|
||||
.make_emitter()
|
||||
)
|
||||
|
||||
dataflow = AirflowGenerator.generate_dataflow(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
dataflow.emit(emitter)
|
||||
|
||||
task.log.info(f"Emitted Datahub DataFlow: {dataflow}")
|
||||
|
||||
datajob = AirflowGenerator.generate_datajob(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
task=task,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
|
||||
for inlet in inlets:
|
||||
datajob.inlets.append(inlet.urn)
|
||||
|
||||
task_outlets = _task_outlets(task)
|
||||
for outlet in task_outlets:
|
||||
datajob.outlets.append(outlet.urn)
|
||||
|
||||
task.log.info(f"Emitted Datahub dataJob: {datajob}")
|
||||
datajob.emit(emitter)
|
||||
|
||||
if context["_datahub_config"].capture_executions:
|
||||
dpi = AirflowGenerator.run_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag=dag,
|
||||
dag_run=context["dag_run"],
|
||||
datajob=datajob,
|
||||
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
|
||||
)
|
||||
|
||||
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
|
||||
|
||||
dpi = AirflowGenerator.complete_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag_run=context["dag_run"],
|
||||
result=status,
|
||||
dag=dag,
|
||||
datajob=datajob,
|
||||
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
|
||||
)
|
||||
task.log.info(f"Emitted Completed Data Process Instance: {dpi}")
|
||||
|
||||
|
||||
def datahub_pre_execution(context):
|
||||
ti = context["ti"]
|
||||
task: "BaseOperator" = ti.task
|
||||
dag = context["dag"]
|
||||
|
||||
task.log.info("Running Datahub pre_execute method")
|
||||
|
||||
emitter = (
|
||||
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
|
||||
.get_underlying_hook()
|
||||
.make_emitter()
|
||||
)
|
||||
|
||||
# This code is from the original airflow lineage code ->
|
||||
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
|
||||
inlets = get_inlets_from_task(task, context)
|
||||
|
||||
datajob = AirflowGenerator.generate_datajob(
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
task=context["ti"].task,
|
||||
dag=dag,
|
||||
capture_tags=context["_datahub_config"].capture_tags_info,
|
||||
capture_owner=context["_datahub_config"].capture_ownership_info,
|
||||
)
|
||||
|
||||
for inlet in inlets:
|
||||
datajob.inlets.append(inlet.urn)
|
||||
|
||||
task_outlets = _task_outlets(task)
|
||||
|
||||
for outlet in task_outlets:
|
||||
datajob.outlets.append(outlet.urn)
|
||||
|
||||
datajob.emit(emitter)
|
||||
task.log.info(f"Emitting Datahub DataJob: {datajob}")
|
||||
|
||||
if context["_datahub_config"].capture_executions:
|
||||
dpi = AirflowGenerator.run_datajob(
|
||||
emitter=emitter,
|
||||
cluster=context["_datahub_config"].cluster,
|
||||
ti=context["ti"],
|
||||
dag=dag,
|
||||
dag_run=context["dag_run"],
|
||||
datajob=datajob,
|
||||
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
|
||||
)
|
||||
|
||||
task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}")
|
||||
|
||||
|
||||
def _wrap_pre_execution(pre_execution):
|
||||
def custom_pre_execution(context):
|
||||
config = get_lineage_config()
|
||||
context["_datahub_config"] = config
|
||||
datahub_pre_execution(context)
|
||||
|
||||
# Call original policy
|
||||
if pre_execution:
|
||||
pre_execution(context)
|
||||
|
||||
return custom_pre_execution
|
||||
|
||||
|
||||
def _wrap_on_failure_callback(on_failure_callback):
|
||||
def custom_on_failure_callback(context):
|
||||
config = get_lineage_config()
|
||||
context["_datahub_config"] = config
|
||||
try:
|
||||
datahub_task_status_callback(context, status=InstanceRunResult.FAILURE)
|
||||
except Exception as e:
|
||||
if not config.graceful_exceptions:
|
||||
raise e
|
||||
else:
|
||||
print(f"Exception: {traceback.format_exc()}")
|
||||
|
||||
# Call original policy
|
||||
if on_failure_callback:
|
||||
on_failure_callback(context)
|
||||
|
||||
return custom_on_failure_callback
|
||||
|
||||
|
||||
def _wrap_on_success_callback(on_success_callback):
|
||||
def custom_on_success_callback(context):
|
||||
config = get_lineage_config()
|
||||
context["_datahub_config"] = config
|
||||
try:
|
||||
datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS)
|
||||
except Exception as e:
|
||||
if not config.graceful_exceptions:
|
||||
raise e
|
||||
else:
|
||||
print(f"Exception: {traceback.format_exc()}")
|
||||
|
||||
if on_success_callback:
|
||||
on_success_callback(context)
|
||||
|
||||
return custom_on_success_callback
|
||||
|
||||
|
||||
def task_policy(task: BaseOperator) -> None:
|
||||
print(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
|
||||
# task.add_inlets(["auto"])
|
||||
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
|
||||
task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback)
|
||||
task.on_success_callback = _wrap_on_success_callback(task.on_success_callback)
|
||||
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
|
||||
|
||||
|
||||
def _wrap_task_policy(policy):
|
||||
if policy and hasattr(policy, "_task_policy_patched_by"):
|
||||
return policy
|
||||
|
||||
def custom_task_policy(task):
|
||||
policy(task)
|
||||
task_policy(task)
|
||||
|
||||
setattr(custom_task_policy, "_task_policy_patched_by", "datahub_plugin")
|
||||
return custom_task_policy
|
||||
|
||||
|
||||
def _patch_policy(settings):
|
||||
if hasattr(settings, "task_policy"):
|
||||
datahub_task_policy = _wrap_task_policy(settings.task_policy)
|
||||
settings.task_policy = datahub_task_policy
|
||||
|
||||
|
||||
def _patch_datahub_policy():
|
||||
print("Patching datahub policy")
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import airflow_local_settings
|
||||
|
||||
_patch_policy(airflow_local_settings)
|
||||
|
||||
from airflow.models.dagbag import settings
|
||||
|
||||
_patch_policy(settings)
|
||||
|
||||
|
||||
_patch_datahub_policy()
|
||||
|
||||
|
||||
class DatahubPlugin(AirflowPlugin):
|
||||
name = "datahub_plugin"
|
||||
@ -1,4 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
|
||||
from datahub_provider._airflow_compat import BaseOperator, ExternalTaskSensor, Operator
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
|
||||
|
||||
from airflow.configuration import conf
|
||||
|
||||
@ -13,16 +15,22 @@ from datahub.utilities.urns.data_job_urn import DataJobUrn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from airflow import DAG
|
||||
from airflow.models import BaseOperator, DagRun, TaskInstance
|
||||
from airflow.models import DagRun, TaskInstance
|
||||
|
||||
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
|
||||
from datahub.emitter.rest_emitter import DatahubRestEmitter
|
||||
|
||||
|
||||
def _task_downstream_task_ids(operator: "Operator") -> Set[str]:
|
||||
if hasattr(operator, "downstream_task_ids"):
|
||||
return operator.downstream_task_ids
|
||||
return operator._downstream_task_id # type: ignore[attr-defined,union-attr]
|
||||
|
||||
|
||||
class AirflowGenerator:
|
||||
@staticmethod
|
||||
def _get_dependencies(
|
||||
task: "BaseOperator", dag: "DAG", flow_urn: DataFlowUrn
|
||||
task: "Operator", dag: "DAG", flow_urn: DataFlowUrn
|
||||
) -> List[DataJobUrn]:
|
||||
|
||||
# resolve URNs for upstream nodes in subdags upstream of the current task.
|
||||
@ -47,7 +55,7 @@ class AirflowGenerator:
|
||||
)
|
||||
|
||||
# if subdag task is a leaf task, then link it as an upstream task
|
||||
if len(upstream_subdag_task._downstream_task_ids) == 0:
|
||||
if len(_task_downstream_task_ids(upstream_subdag_task)) == 0:
|
||||
upstream_subdag_task_urns.append(upstream_subdag_task_urn)
|
||||
|
||||
# resolve URNs for upstream nodes that trigger the subdag containing the current task.
|
||||
@ -59,7 +67,7 @@ class AirflowGenerator:
|
||||
if (
|
||||
dag.is_subdag
|
||||
and dag.parent_dag is not None
|
||||
and len(task._upstream_task_ids) == 0
|
||||
and len(task.upstream_task_ids) == 0
|
||||
):
|
||||
|
||||
# filter through the parent dag's tasks and find the subdag trigger(s)
|
||||
@ -83,7 +91,7 @@ class AirflowGenerator:
|
||||
)
|
||||
|
||||
# if the task triggers the subdag, link it to this node in the subdag
|
||||
if subdag_task_id in upstream_task._downstream_task_ids:
|
||||
if subdag_task_id in _task_downstream_task_ids(upstream_task):
|
||||
upstream_subdag_triggers.append(upstream_task_urn)
|
||||
|
||||
# If the operator is an ExternalTaskSensor then we set the remote task as upstream.
|
||||
@ -91,8 +99,6 @@ class AirflowGenerator:
|
||||
# jobflow to another jobflow.
|
||||
external_task_upstreams = []
|
||||
if task.task_type == "ExternalTaskSensor":
|
||||
from airflow.sensors.external_task_sensor import ExternalTaskSensor
|
||||
|
||||
task = cast(ExternalTaskSensor, task)
|
||||
if hasattr(task, "external_task_id") and task.external_task_id is not None:
|
||||
external_task_upstreams = [
|
||||
@ -173,7 +179,11 @@ class AirflowGenerator:
|
||||
return data_flow
|
||||
|
||||
@staticmethod
|
||||
def _get_description(task: "BaseOperator") -> Optional[str]:
|
||||
def _get_description(task: "Operator") -> Optional[str]:
|
||||
if not isinstance(task, BaseOperator):
|
||||
# TODO: Get docs for mapped operators.
|
||||
return None
|
||||
|
||||
if hasattr(task, "doc") and task.doc:
|
||||
return task.doc
|
||||
elif hasattr(task, "doc_md") and task.doc_md:
|
||||
@ -189,9 +199,9 @@ class AirflowGenerator:
|
||||
@staticmethod
|
||||
def generate_datajob(
|
||||
cluster: str,
|
||||
task: "BaseOperator",
|
||||
task: "Operator",
|
||||
dag: "DAG",
|
||||
set_dependendecies: bool = True,
|
||||
set_dependencies: bool = True,
|
||||
capture_owner: bool = True,
|
||||
capture_tags: bool = True,
|
||||
) -> DataJob:
|
||||
@ -200,7 +210,7 @@ class AirflowGenerator:
|
||||
:param cluster: str
|
||||
:param task: TaskInstance
|
||||
:param dag: DAG
|
||||
:param set_dependendecies: bool - whether to extract dependencies from airflow task
|
||||
:param set_dependencies: bool - whether to extract dependencies from airflow task
|
||||
:param capture_owner: bool - whether to extract owner from airflow task
|
||||
:param capture_tags: bool - whether to set tags automatically from airflow task
|
||||
:return: DataJob - returns the generated DataJob object
@ -209,6 +219,8 @@ class AirflowGenerator:
orchestrator="airflow", env=cluster, flow_id=dag.dag_id
)
datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)

# TODO add support for MappedOperator
datajob.description = AirflowGenerator._get_description(task)

job_property_bag: Dict[str, str] = {}
@ -228,6 +240,11 @@ class AirflowGenerator:
"task_id",
"trigger_rule",
"wait_for_downstream",
# In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids
"downstream_task_ids",
# In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions.
"inlets",
"outlets",
]

for key in allowed_task_keys:
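The loop body is cut off by the hunk; a rough sketch (not the exact code from this file) of how an allow-list like this can tolerate both the old private attribute names and the public ones:

# Sketch only: copy whichever variant of the attribute the running Airflow
# version actually exposes into the job property bag.
for key in allowed_task_keys:
    if hasattr(task, key):
        job_property_bag[key] = repr(getattr(task, key))
    elif hasattr(task, f"_{key}"):  # older Airflow: private attribute
        job_property_bag[key] = repr(getattr(task, f"_{key}"))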
@ -244,7 +261,7 @@ class AirflowGenerator:
if capture_tags and dag.tags:
datajob.tags.update(dag.tags)

if set_dependendecies:
if set_dependencies:
datajob.upstream_urns.extend(
AirflowGenerator._get_dependencies(
task=task, dag=dag, flow_urn=datajob.flow_urn
@ -256,7 +273,7 @@ class AirflowGenerator:
@staticmethod
def create_datajob_instance(
cluster: str,
task: "BaseOperator",
task: "Operator",
dag: "DAG",
data_job: Optional[DataJob] = None,
) -> DataProcessInstance:
@ -282,8 +299,10 @@ class AirflowGenerator:
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)

if start_timestamp_millis is None:
assert dag_run.execution_date
start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000)

assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)

# This property only exists in Airflow2
@ -335,6 +354,7 @@ class AirflowGenerator:
assert dag_run.dag
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)

assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
if end_timestamp_millis is None:
if dag_run.end_date is None:
@ -375,6 +395,7 @@ class AirflowGenerator:
if datajob is None:
datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag)

assert dag_run.run_id
dpi = DataProcessInstance.from_datajob(
datajob=datajob,
id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}",
@ -9,6 +9,8 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
)

if TYPE_CHECKING:
from airflow.models.connection import Connection

from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig
@ -51,12 +53,14 @@ class DatahubRestHook(BaseHook):
}

def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]:
conn = self.get_connection(self.datahub_rest_conn_id)
conn: "Connection" = self.get_connection(self.datahub_rest_conn_id)

host = conn.host
if host is None:
raise AirflowException("host parameter is required")
password = conn.password
timeout_sec = conn.extra_dejson.get("timeout_sec")
return (host, conn.password, timeout_sec)
return (host, password, timeout_sec)

def make_emitter(self) -> "DatahubRestEmitter":
import datahub.emitter.rest_emitter
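For reference, `_get_config` reads its values from a standard Airflow Connection. A minimal sketch of such a connection; the conn_id, host, and conn_type below are placeholders/assumptions rather than values from this commit:

# Hypothetical connection definition matching the fields _get_config reads:
# host, password (optional token), and a timeout_sec entry in the JSON extras.
from airflow.models import Connection

datahub_rest_conn = Connection(
    conn_id="datahub_rest_default",
    conn_type="datahub_rest",
    host="http://localhost:8080",
    password=None,  # set an access token here if the server requires one
    extra='{"timeout_sec": 5}',
)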
@ -0,0 +1,33 @@
# On SQLAlchemy 1.4.x, the mypy plugin is built-in.
# However, with SQLAlchemy 1.3.x, it requires the sqlalchemy-stubs package and hence has a separate import.
# This file serves as a thin shim layer that directs mypy to the appropriate plugin implementation.
try:
from mypy.semanal import SemanticAnalyzer
from sqlalchemy.ext.mypy.plugin import plugin

# On SQLAlchemy >=1.4, <=1.4.29, the mypy plugin is incompatible with newer versions of mypy.
# See https://github.com/sqlalchemy/sqlalchemy/commit/aded8b11d9eccbd1f2b645a94338e34a3d234bc9
# and https://github.com/sqlalchemy/sqlalchemy/issues/7496.
# To fix this, we need to patch the mypy plugin interface.
#
# We cannot set a min version of SQLAlchemy because of the bigquery SQLAlchemy package.
# See https://github.com/googleapis/python-bigquery-sqlalchemy/issues/385.
_named_type_original = SemanticAnalyzer.named_type
_named_type_translations = {
"__builtins__.object": "builtins.object",
"__builtins__.str": "builtins.str",
"__builtins__.list": "builtins.list",
"__sa_Mapped": "sqlalchemy.orm.attributes.Mapped",
}

def _named_type_shim(self, fullname, *args, **kwargs):
if fullname in _named_type_translations:
fullname = _named_type_translations[fullname]

return _named_type_original(self, fullname, *args, **kwargs)

SemanticAnalyzer.named_type = _named_type_shim # type: ignore
except ModuleNotFoundError:
from sqlmypy import plugin # type: ignore[no-redef]

__all__ = ["plugin"]
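For orientation only, the kind of declarative mapping this mypy plugin is used to type-check looks roughly like the following; the model is hypothetical and uses the SQLAlchemy 1.4 import path:

# Hypothetical model; the shim above only affects how mypy resolves these types.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class IngestionRun(Base):
    __tablename__ = "ingestion_run"

    id = Column(Integer, primary_key=True)
    status = Column(String, nullable=False)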
@ -1,3 +1,5 @@
from datahub_provider._airflow_compat import MARKUPSAFE_PATCHED

import datetime
import json
import os
@ -13,19 +15,19 @@ import packaging.version
import pytest
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models import DAG, Connection, DagBag, DagRun, TaskInstance
from airflow.operators.dummy import DummyOperator
from airflow.utils.dates import days_ago

try:
from airflow.operators.dummy import DummyOperator
except ModuleNotFoundError:
from airflow.operators.dummy_operator import DummyOperator

import datahub.emitter.mce_builder as builder
from datahub_provider import get_provider_info
from datahub_provider.entities import Dataset
from datahub_provider.hooks.datahub import DatahubKafkaHook, DatahubRestHook
from datahub_provider.operators.datahub import DatahubEmitterOperator

assert MARKUPSAFE_PATCHED

pytestmark = pytest.mark.airflow

# Approach suggested by https://stackoverflow.com/a/11887885/5004662.
AIRFLOW_VERSION = packaging.version.parse(airflow.version.version)

@ -73,6 +75,10 @@ def test_airflow_provider_info():
assert get_provider_info()


@pytest.mark.skipif(
AIRFLOW_VERSION < packaging.version.parse("2.0.0"),
reason="the examples use list-style lineage, which is only supported on Airflow 2.x",
)
def test_dags_load_with_no_errors(pytestconfig):
airflow_examples_folder = (
pytestconfig.rootpath / "src/datahub_provider/example_dags"
@ -99,6 +105,7 @@ def patch_airflow_connection(conn: Connection) -> Iterator[Connection]:
@mock.patch("datahub.emitter.rest_emitter.DatahubRestEmitter", autospec=True)
def test_datahub_rest_hook(mock_emitter):
with patch_airflow_connection(datahub_rest_connection_config) as config:
assert config.conn_id
hook = DatahubRestHook(config.conn_id)
hook.emit_mces([lineage_mce])
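`patch_airflow_connection` is defined earlier in this test module; it is assumed to be a small context manager along these lines (a sketch, not the exact fixture):

# Sketch: make the hook's get_connection() return the in-memory Connection
# instead of looking it up in the Airflow metadata database.
import contextlib
from typing import Iterator
from unittest import mock

from airflow.models import Connection


@contextlib.contextmanager
def patch_airflow_connection(conn: Connection) -> Iterator[Connection]:
    with mock.patch(
        "datahub_provider.hooks.datahub.BaseHook.get_connection", return_value=conn
    ):
        yield conn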

@ -112,6 +119,7 @@ def test_datahub_rest_hook_with_timeout(mock_emitter):
with patch_airflow_connection(
datahub_rest_connection_config_with_timeout
) as config:
assert config.conn_id
hook = DatahubRestHook(config.conn_id)
hook.emit_mces([lineage_mce])

@ -123,6 +131,7 @@ def test_datahub_rest_hook_with_timeout(mock_emitter):
@mock.patch("datahub.emitter.kafka_emitter.DatahubKafkaEmitter", autospec=True)
def test_datahub_kafka_hook(mock_emitter):
with patch_airflow_connection(datahub_kafka_connection_config) as config:
assert config.conn_id
hook = DatahubKafkaHook(config.conn_id)
hook.emit_mces([lineage_mce])

@ -135,6 +144,7 @@ def test_datahub_kafka_hook(mock_emitter):
@mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.emit_mces")
def test_datahub_lineage_operator(mock_emit):
with patch_airflow_connection(datahub_rest_connection_config) as config:
assert config.conn_id
task = DatahubEmitterOperator(
task_id="emit_lineage",
datahub_conn_id=config.conn_id,
@ -331,6 +341,7 @@ def test_lineage_backend(mock_emit, inlets, outlets):
)
@mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.make_emitter")
def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
# TODO: Merge this code into the test above to reduce duplication.
DEFAULT_DATE = datetime.datetime(2020, 5, 17)
mock_emitter = Mock()
mock_emit.return_value = mock_emitter
@ -375,10 +386,6 @@ def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE)
# Ignoring type here because DagRun state is just a string at Airflow 1
dag_run = DagRun(state="success", run_id=f"scheduled_{DEFAULT_DATE}") # type: ignore
ti.dag_run = dag_run
ti.start_date = datetime.datetime.utcnow()
ti.execution_date = DEFAULT_DATE

else:
from airflow.utils.state import DagRunState

@ -386,9 +393,10 @@ def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
dag_run = DagRun(
state=DagRunState.SUCCESS, run_id=f"scheduled_{DEFAULT_DATE}"
)
ti.dag_run = dag_run
ti.start_date = datetime.datetime.utcnow()
ti.execution_date = DEFAULT_DATE

ti.dag_run = dag_run # type: ignore
ti.start_date = datetime.datetime.utcnow()
ti.execution_date = DEFAULT_DATE

ctx1 = {
"dag": dag,

@ -65,9 +65,8 @@ def test_athena_get_table_properties():

mock_cursor = mock.MagicMock()
mock_inspector = mock.MagicMock()
mock_inspector.engine.return_value = mock.MagicMock()
mock_inspector.dialect._raw_connection.return_value = mock_cursor
mock_inspector.dialect._raw_connection().cursor()._get_table_metadata.return_value = AthenaTableMetadata(
mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor
mock_cursor._get_table_metadata.return_value = AthenaTableMetadata(
response=table_metadata
)
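The reworked mocks mirror the public SQLAlchemy 1.4 route to a raw cursor; roughly the call chain below, which is a sketch of the assumed access pattern rather than the actual Athena source code:

# Sketch: on SQLAlchemy 1.4 the cursor is obtained through the public
# Engine.raw_connection() rather than the dialect's private _raw_connection().
def _get_athena_cursor(inspector):
    return inspector.engine.raw_connection().cursor()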