feat(ingest): loosen sqlalchemy dep & support airflow 2.3+ (#6204)

Co-authored-by: Tamas Nemeth <treff7es@gmail.com>
Harshal Sheth 2022-11-11 15:04:36 -05:00 committed by GitHub
parent 6c42064332
commit 3e907ab0d1
30 changed files with 598 additions and 454 deletions


@ -38,6 +38,11 @@ jobs:
"testIntegrationBatch1",
"testSlowIntegration",
]
include:
- python-version: "3.7"
extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0"
- python-version: "3.10"
extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow>=2.4.0"
fail-fast: false
steps:
- uses: actions/checkout@v3
@ -50,8 +55,8 @@ jobs:
hadoop-version: "3.2"
- name: Install dependencies
run: ./metadata-ingestion/scripts/install_deps.sh
- name: Run metadata-ingestion tests
run: ./gradlew :metadata-ingestion:build :metadata-ingestion:${{ matrix.command }}
- name: Run metadata-ingestion tests (extras ${{ matrix.extraPythonRequirement }})
run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion:${{ matrix.command }}
- name: pip freeze show list installed
if: always()
run: source metadata-ingestion/venv/bin/activate && pip freeze


@ -7,8 +7,10 @@ ext {
venv_name = 'venv'
}
def pip_install_command = "${venv_name}/bin/pip install -e ../../metadata-ingestion"
task checkPythonVersion(type: Exec) {
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 6)'
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)'
}
task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
@ -20,7 +22,7 @@ task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
task installPackage(type: Exec, dependsOn: environmentSetup) {
inputs.file file('setup.py')
outputs.dir("${venv_name}")
commandLine "${venv_name}/bin/pip", 'install', '-e', '.'
commandLine 'bash', '-x', '-c', "${pip_install_command} -e ."
}
task install(dependsOn: [installPackage])
@ -30,7 +32,7 @@ task installDev(type: Exec, dependsOn: [install]) {
outputs.dir("${venv_name}")
outputs.file("${venv_name}/.build_install_dev_sentinel")
commandLine 'bash', '-x', '-c',
"${venv_name}/bin/pip install -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
"${pip_install_command} -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
}
task lint(type: Exec, dependsOn: installDev) {
@ -65,7 +67,7 @@ task installDevTest(type: Exec, dependsOn: [installDev]) {
outputs.dir("${venv_name}")
outputs.file("${venv_name}/.build_install_dev_test_sentinel")
commandLine 'bash', '-x', '-c',
"${venv_name}/bin/pip install -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
"${pip_install_command} -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
}
def testFile = hasProperty('testFile') ? testFile : 'unknown'


@ -1,9 +1,9 @@
import os
import pathlib
from typing import Dict, Set
import setuptools
package_metadata: dict = {}
with open("./src/datahub_airflow_plugin/__init__.py") as fp:
exec(fp.read(), package_metadata)
@ -23,9 +23,7 @@ base_requirements = {
"typing-inspect",
"pydantic>=1.5.1",
"apache-airflow >= 2.0.2",
"acryl-datahub[airflow] >= 0.8.36",
# Pinned dependencies to make dependency resolution faster.
"sqlalchemy==1.3.24",
f"acryl-datahub[airflow] == {package_metadata['__version__']}",
}
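
Note: the plugin now pins acryl-datahub[airflow] to its own version instead of a floor version plus a separate sqlalchemy==1.3.24 pin. A minimal sketch of that pattern, with the version string inlined for illustration (setup.py actually reads it from ./src/datahub_airflow_plugin/__init__.py):

package_metadata: dict = {}
# setup.py exec's the plugin's __init__.py to pick up __version__ without importing
# the not-yet-installed package; a literal is used here so the sketch runs standalone.
exec('__version__ = "0.0.1.dev0"', package_metadata)

# Pin the companion package to exactly the plugin's own version.
requirement = f"acryl-datahub[airflow] == {package_metadata['__version__']}"
print(requirement)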


@ -1,362 +1,4 @@
import contextlib
import traceback
from typing import Any, Iterable
import attr
from airflow.configuration import conf
from airflow.lineage import PIPELINE_OUTLETS
from airflow.models.baseoperator import BaseOperator
from airflow.plugins_manager import AirflowPlugin
from airflow.utils.module_loading import import_string
from cattr import structure
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub_provider.client.airflow_generator import AirflowGenerator
from datahub_provider.hooks.datahub import DatahubGenericHook
from datahub_provider.lineage.datahub import DatahubLineageConfig
def get_lineage_config() -> DatahubLineageConfig:
"""Load the lineage config from airflow.cfg."""
enabled = conf.get("datahub", "enabled", fallback=True)
datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default")
cluster = conf.get("datahub", "cluster", fallback="prod")
graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True)
capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True)
capture_ownership_info = conf.get(
"datahub", "capture_ownership_info", fallback=True
)
capture_executions = conf.get("datahub", "capture_executions", fallback=True)
return DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=datahub_conn_id,
cluster=cluster,
graceful_exceptions=graceful_exceptions,
capture_ownership_info=capture_ownership_info,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
)
def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]:
inlets = []
if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore
inlets = [
task._inlets,
]
if task._inlets and isinstance(task._inlets, list):
inlets = []
task_ids = (
{o for o in task._inlets if isinstance(o, str)}
.union(op.task_id for op in task._inlets if isinstance(op, BaseOperator))
.intersection(task.get_flat_relative_ids(upstream=True))
)
from airflow.lineage import AUTO
# pick up unique direct upstream task_ids if AUTO is specified
if AUTO.upper() in task._inlets or AUTO.lower() in task._inlets:
print("Picking up unique direct upstream task_ids as AUTO is specified")
task_ids = task_ids.union(
task_ids.symmetric_difference(task.upstream_task_ids)
)
inlets = task.xcom_pull(
context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS
)
# re-instantiate the obtained inlets
inlets = [
structure(item["data"], import_string(item["type_name"]))
# _get_instance(structure(item, Metadata))
for sublist in inlets
if sublist
for item in sublist
]
for inlet in task._inlets:
if type(inlet) != str:
inlets.append(inlet)
return inlets
def datahub_on_failure_callback(context, *args, **kwargs):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
dataflow = AirflowGenerator.generate_dataflow(
cluster=context["_datahub_config"].cluster,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
dataflow.emit(emitter)
task.log.info(f"Emitted Datahub DataFlow: {dataflow}")
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=context["ti"].task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
for outlet in task._outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitted Datahub DataJob: {datajob}")
datajob.emit(emitter)
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag_run=context["dag_run"],
result=InstanceRunResult.FAILURE,
dag=dag,
datajob=datajob,
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
)
task.log.info(f"Emitted Completed Datahub Dataprocess Instance: {dpi}")
def datahub_on_success_callback(context, *args, **kwargs):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
dataflow = AirflowGenerator.generate_dataflow(
cluster=context["_datahub_config"].cluster,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
dataflow.emit(emitter)
task.log.info(f"Emitted Datahub DataFlow: {dataflow}")
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
# We have to use _outlets because outlets is empty
for outlet in task._outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitted Datahub dataJob: {datajob}")
datajob.emit(emitter)
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag_run=context["dag_run"],
result=InstanceRunResult.SUCCESS,
dag=dag,
datajob=datajob,
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
)
task.log.info(f"Emitted Completed Data Process Instance: {dpi}")
def datahub_pre_execution(context):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
task.log.info("Running Datahub pre_execute method")
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=context["ti"].task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
for outlet in task._outlets:
datajob.outlets.append(outlet.urn)
datajob.emit(emitter)
task.log.info(f"Emitting Datahub DataJob: {datajob}")
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}")
def _wrap_pre_execution(pre_execution):
def custom_pre_execution(context):
config = get_lineage_config()
context["_datahub_config"] = config
datahub_pre_execution(context)
# Call original policy
if pre_execution:
pre_execution(context)
return custom_pre_execution
def _wrap_on_failure_callback(on_failure_callback):
def custom_on_failure_callback(context):
config = get_lineage_config()
context["_datahub_config"] = config
try:
datahub_on_failure_callback(context)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
# Call original policy
if on_failure_callback:
on_failure_callback(context)
return custom_on_failure_callback
def _wrap_on_success_callback(on_success_callback):
def custom_on_success_callback(context):
config = get_lineage_config()
context["_datahub_config"] = config
try:
datahub_on_success_callback(context)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
if on_success_callback:
on_success_callback(context)
return custom_on_success_callback
def task_policy(task: BaseOperator) -> None:
print(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
# task.add_inlets(["auto"])
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback)
task.on_success_callback = _wrap_on_success_callback(task.on_success_callback)
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
def _wrap_task_policy(policy):
if policy and hasattr(policy, "_task_policy_patched_by"):
return policy
def custom_task_policy(task):
policy(task)
task_policy(task)
setattr(custom_task_policy, "_task_policy_patched_by", "datahub_plugin")
return custom_task_policy
def _patch_policy(settings):
print("Patching datahub policy")
if hasattr(settings, "task_policy"):
datahub_task_policy = _wrap_task_policy(settings.task_policy)
settings.task_policy = datahub_task_policy
def _patch_datahub_policy():
with contextlib.suppress(ImportError):
import airflow_local_settings
_patch_policy(airflow_local_settings)
from airflow.models.dagbag import settings
_patch_policy(settings)
_patch_datahub_policy()
class DatahubPlugin(AirflowPlugin):
name = "datahub_plugin"
# This package serves as a shim, but the actual implementation lives in datahub_provider
# from the acryl-datahub package. We leave this shim here to avoid breaking existing
# Airflow installs.
from datahub_provider._plugin import DatahubPlugin # noqa: F401


@ -7,8 +7,12 @@ ext {
venv_name = 'venv'
}
if (!project.hasProperty("extra_pip_requirements")) {
ext.extra_pip_requirements = ""
}
task checkPythonVersion(type: Exec) {
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 6)'
commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)'
}
task environmentSetup(type: Exec, dependsOn: checkPythonVersion) {
@ -24,7 +28,7 @@ task runPreFlightScript(type: Exec, dependsOn: environmentSetup) {
task installPackage(type: Exec, dependsOn: runPreFlightScript) {
inputs.file file('setup.py')
outputs.dir("${venv_name}")
commandLine "${venv_name}/bin/pip", 'install', '-e', '.'
commandLine 'bash', '-x', '-c', "${venv_name}/bin/pip install -e . ${extra_pip_requirements}"
}
task codegen(type: Exec, dependsOn: [environmentSetup, installPackage, ':metadata-events:mxe-schemas:build']) {
@ -40,7 +44,7 @@ task installDev(type: Exec, dependsOn: [install]) {
outputs.dir("${venv_name}")
outputs.file("${venv_name}/.build_install_dev_sentinel")
commandLine 'bash', '-x', '-c',
"${venv_name}/bin/pip install -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel"
"${venv_name}/bin/pip install -e .[dev] ${extra_pip_requirements} && touch ${venv_name}/.build_install_dev_sentinel"
}
@ -67,15 +71,21 @@ task lint(type: Exec, dependsOn: installDev) {
*/
commandLine 'bash', '-c',
"find ${venv_name}/lib -path *airflow/_vendor/connexion/spec.py -exec sed -i.bak -e '169,169s/ # type: List\\[str\\]//g' {} \\; && " +
"source ${venv_name}/bin/activate && set -x && black --check --diff src/ tests/ examples/ && isort --check --diff src/ tests/ examples/ && flake8 --count --statistics src/ tests/ examples/ && mypy src/ tests/ examples/"
"source ${venv_name}/bin/activate && set -x && " +
"./scripts/install-sqlalchemy-stubs.sh && " +
"black --check --diff src/ tests/ examples/ && " +
"isort --check --diff src/ tests/ examples/ && " +
"flake8 --count --statistics src/ tests/ examples/ && " +
"mypy --show-traceback --show-error-codes src/ tests/ examples/"
}
task lintFix(type: Exec, dependsOn: installDev) {
commandLine 'bash', '-c',
"source ${venv_name}/bin/activate && set -x && " +
"./scripts/install-sqlalchemy-stubs.sh && " +
"black src/ tests/ examples/ && " +
"isort src/ tests/ examples/ && " +
"flake8 src/ tests/ examples/ && " +
"mypy src/ tests/ examples/"
"mypy --show-traceback --show-error-codes src/ tests/ examples/"
}
task testQuick(type: Exec, dependsOn: installDev) {
@ -92,7 +102,7 @@ task installDevTest(type: Exec, dependsOn: [install]) {
outputs.dir("${venv_name}")
outputs.file("${venv_name}/.build_install_dev_test_sentinel")
commandLine 'bash', '-c',
"${venv_name}/bin/pip install -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel"
"${venv_name}/bin/pip install -e .[dev,integration-tests] ${extra_pip_requirements} && touch ${venv_name}/.build_install_dev_test_sentinel"
}
def testFile = hasProperty('testFile') ? testFile : 'unknown'


@ -1 +1 @@
As a SQL-based service, the Athena integration is also supported by our SQL profiler. See here for more details on configuration.
As a SQL-based service, the Oracle integration is also supported by our SQL profiler. See here for more details on configuration.


@ -14,6 +14,7 @@ target-version = ['py36', 'py37', 'py38']
[tool.isort]
combine_as_imports = true
indent = ' '
known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub_provider._airflow_compat']
profile = 'black'
sections = 'FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER'
skip_glob = 'src/datahub/metadata'
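
Note: registering the compat shims under known_future_library tells isort to keep them above every other import, which matters because the markupsafe monkeypatch must run before anything imports Jinja2. A minimal sketch of the resulting import order in a module that uses the shim (assumes acryl-datahub, which ships the shim, is installed):

from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED

import logging

# The assert documents the dependency on the patch and keeps linters from treating
# the import as unused.
assert MARKUPSAFE_PATCHED

logger = logging.getLogger(__name__)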


@ -0,0 +1,28 @@
#!/bin/bash
set -euo pipefail
# ASSUMPTION: this script is run from inside the venv.
SQLALCHEMY_VERSION=$(python -c 'import sqlalchemy; print(sqlalchemy.__version__)')
if [[ $SQLALCHEMY_VERSION == 1.3.* ]]; then
ENSURE_NOT_INSTALLED=sqlalchemy2-stubs
ENSURE_INSTALLED=sqlalchemy-stubs
elif [[ $SQLALCHEMY_VERSION == 1.4.* ]]; then
ENSURE_NOT_INSTALLED=sqlalchemy-stubs
ENSURE_INSTALLED=sqlalchemy2-stubs
else
echo "Unsupported SQLAlchemy version: $SQLALCHEMY_VERSION"
exit 1
fi
FORCE_REINSTALL=""
if pip show $ENSURE_NOT_INSTALLED >/dev/null 2>&1 ; then
pip uninstall --yes $ENSURE_NOT_INSTALLED
FORCE_REINSTALL="--force-reinstall"
fi
if [ -n "$FORCE_REINSTALL" ] || ! pip show $ENSURE_INSTALLED >/dev/null 2>&1 ; then
pip install $FORCE_REINSTALL $ENSURE_INSTALLED
fi
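
The script above branches on the installed SQLAlchemy version to pick the matching stubs package. A rough Python equivalent of that version gate, for illustration only (the repo does this in bash):

import sqlalchemy

# Mirror the shell script's branching: only the 1.3.x and 1.4.x lines are supported.
major_minor = ".".join(sqlalchemy.__version__.split(".")[:2])
if major_minor == "1.3":
    ensure_installed, ensure_not_installed = "sqlalchemy-stubs", "sqlalchemy2-stubs"
elif major_minor == "1.4":
    ensure_installed, ensure_not_installed = "sqlalchemy2-stubs", "sqlalchemy-stubs"
else:
    raise RuntimeError(f"Unsupported SQLAlchemy version: {sqlalchemy.__version__}")
print(f"install {ensure_installed}, remove {ensure_not_installed}")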


@ -22,7 +22,7 @@ ban-relative-imports = true
[mypy]
plugins =
sqlmypy,
./tests/test_helpers/sqlalchemy_mypy_plugin.py,
pydantic.mypy
exclude = ^(venv|build|dist)/
ignore_missing_imports = yes
@ -55,6 +55,7 @@ disallow_untyped_defs = yes
asyncio_mode = auto
addopts = --cov=src --cov-report term-missing --cov-config setup.cfg --strict-markers
markers =
airflow: marks tests related to airflow (deselect with '-m not airflow')
slow_unit: marks tests to only run slow unit tests (deselect with '-m not slow_unit')
integration: marks tests to only run in integration (deselect with '-m "not integration"')
integration_batch_1: mark tests to only run in batch 1 of integration tests. This is done mainly for parallelisation (deselect with '-m not integration_batch_1')
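
The new airflow marker is applied module-wide in the provider tests further down in this diff (pytestmark = pytest.mark.airflow). A minimal, hypothetical usage sketch:

import pytest

# Marks every test in the module; deselect with `pytest -m "not airflow"`.
pytestmark = pytest.mark.airflow

def test_marker_applies() -> None:
    assert True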


@ -55,6 +55,10 @@ framework_common = {
"click-spinner",
}
rest_common = {
"requests",
}
kafka_common = {
# The confluent_kafka package provides a number of pre-built wheels for
# various platforms and architectures. However, it does not provide wheels
@ -103,7 +107,7 @@ kafka_protobuf = {
sql_common = {
# Required for all SQL sources.
"sqlalchemy==1.3.24",
"sqlalchemy>=1.3.24, <2",
# Required for SQL profiling.
"great-expectations>=0.15.12",
# GE added handling for higher version of jinja2
@ -147,6 +151,12 @@ bigquery_common = {
"more-itertools>=8.12.0",
}
clickhouse_common = {
# Clickhouse 0.1.8 requires SQLAlchemy 1.3.x, while the newer versions
# allow SQLAlchemy 1.4.x.
"clickhouse-sqlalchemy>=0.1.8",
}
redshift_common = {
"sqlalchemy-redshift",
"psycopg2-binary",
@ -214,10 +224,12 @@ databricks_cli = {
plugins: Dict[str, Set[str]] = {
# Sink plugins.
"datahub-kafka": kafka_common,
"datahub-rest": {"requests"},
"datahub-rest": rest_common,
# Integrations.
"airflow": {
"apache-airflow >= 2.0.2",
*rest_common,
*kafka_common,
},
"circuit-breaker": {
"gql>=3.3.0",
@ -239,12 +251,8 @@ plugins: Dict[str, Set[str]] = {
"sqllineage==1.3.6",
"sql_metadata",
}, # deprecated, but keeping the extra for backwards compatibility
"clickhouse": sql_common | {"clickhouse-sqlalchemy==0.1.8"},
"clickhouse-usage": sql_common
| usage_common
| {
"clickhouse-sqlalchemy==0.1.8",
},
"clickhouse": sql_common | clickhouse_common,
"clickhouse-usage": sql_common | usage_common | clickhouse_common,
"datahub-lineage-file": set(),
"datahub-business-glossary": set(),
"delta-lake": {*data_lake_profiling, *delta_lake},
@ -339,7 +347,6 @@ all_exclude_plugins: Set[str] = {
mypy_stubs = {
"types-dataclasses",
"sqlalchemy-stubs",
"types-pkg_resources",
"types-six",
"types-python-dateutil",


@ -1,3 +1,5 @@
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
import collections
import concurrent.futures
import contextlib
@ -51,6 +53,7 @@ from datahub.utilities.sqlalchemy_query_combiner import (
get_query_columns,
)
assert MARKUPSAFE_PATCHED
logger: logging.Logger = logging.getLogger(__name__)
P = ParamSpec("P")


@ -276,7 +276,8 @@ class ConfluentJDBCSourceConnector:
url_instance = make_url(url)
source_platform = get_platform_from_sqlalchemy_uri(str(url_instance))
database_name = url_instance.database
db_connection_url = f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{url_instance.database}"
assert database_name
db_connection_url = f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{database_name}"
topic_prefix = self.connector_manifest.config.get("topic.prefix", None)
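
The new assert exists because make_url can return a URL whose database attribute is None; asserting narrows the Optional for mypy and fails fast on malformed connector URLs. A small illustration with a hypothetical connection string:

from sqlalchemy.engine.url import make_url

url_instance = make_url("postgresql://db-host:5432/mydb")  # hypothetical URL
database_name = url_instance.database
# .database is Optional[str]; a URL without a path component would leave it None.
assert database_name
db_connection_url = (
    f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{database_name}"
)
print(db_connection_url)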


@ -109,7 +109,7 @@ class AthenaSource(SQLAlchemySource):
self, inspector: Inspector, schema: str, table: str
) -> Tuple[Optional[str], Dict[str, str], Optional[str]]:
if not self.cursor:
self.cursor = inspector.dialect._raw_connection(inspector.engine).cursor()
self.cursor = inspector.engine.raw_connection().cursor()
assert self.cursor
# Unfortunately, properties can only be retrieved through private methods, as they are not exposed
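
The switch away from inspector.dialect._raw_connection matters because that private helper changed between SQLAlchemy 1.3 and 1.4, whereas Engine.raw_connection() is public API on both. A minimal sketch with a stand-in SQLite engine (Athena itself uses pyathena):

from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite://")  # stand-in engine for illustration
inspector = inspect(engine)
# Engine.raw_connection() returns a DBAPI connection on both 1.3.x and 1.4.x.
cursor = inspector.engine.raw_connection().cursor()
cursor.execute("SELECT 1")
print(cursor.fetchone())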


@ -792,8 +792,11 @@ class BigQuerySource(SQLAlchemySource):
# Bigquery only supports one partition column
# https://stackoverflow.com/questions/62886213/adding-multiple-partitioned-columns-to-bigquery-table-from-sql-query
row = result.fetchone()
if row and hasattr(row, "_asdict"):
# Compat with sqlalchemy 1.4 Row type.
row = row._asdict()
if row:
return BigQueryPartitionColumn(**row)
return BigQueryPartitionColumn(**row.items())
return None
def get_shard_from_table(self, table: str) -> Tuple[str, Optional[str]]:
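
The hasattr guard above handles the change in fetchone()'s return type: SQLAlchemy 1.4 returns a Row that exposes _asdict(), while 1.3's RowProxy does not. A small illustration of the same pattern against an in-memory SQLite database:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    row = conn.execute(text("SELECT 'ds' AS column_name, 'DAY' AS data_type")).fetchone()

# On 1.4.x convert the Row to a plain dict; on 1.3.x the RowProxy already supports
# key-based access and ** unpacking.
if row is not None and hasattr(row, "_asdict"):
    row = row._asdict()
print(row["column_name"], row["data_type"])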


@ -11,7 +11,6 @@ from pydantic.fields import Field
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import ResultProxy, RowProxy
from datahub.configuration.common import AllowDenyPattern
from datahub.ingestion.api.common import PipelineContext
@ -135,7 +134,7 @@ class SQLServerSource(SQLAlchemySource):
def _populate_table_descriptions(self, conn: Connection, db_name: str) -> None:
# see https://stackoverflow.com/questions/5953330/how-do-i-map-the-id-in-sys-extended-properties-to-an-object-name
# also see https://www.mssqltips.com/sqlservertip/5384/working-with-sql-server-extended-properties/
table_metadata: ResultProxy = conn.execute(
table_metadata = conn.execute(
"""
SELECT
SCHEMA_NAME(T.SCHEMA_ID) AS schema_name,
@ -149,13 +148,13 @@ class SQLServerSource(SQLAlchemySource):
AND EP.CLASS = 1
"""
)
for row in table_metadata: # type: RowProxy
for row in table_metadata:
self.table_descriptions[
f"{db_name}.{row['schema_name']}.{row['table_name']}"
] = row["table_description"]
def _populate_column_descriptions(self, conn: Connection, db_name: str) -> None:
column_metadata: RowProxy = conn.execute(
column_metadata = conn.execute(
"""
SELECT
SCHEMA_NAME(T.SCHEMA_ID) AS schema_name,
@ -172,7 +171,7 @@ class SQLServerSource(SQLAlchemySource):
AND EP.CLASS = 1
"""
)
for row in column_metadata: # type: RowProxy
for row in column_metadata:
self.column_descriptions[
f"{db_name}.{row['schema_name']}.{row['table_name']}.{row['column_name']}"
] = row["column_description"]


@ -1,5 +1,5 @@
import logging
from typing import Any, Iterable, List, Optional, Tuple, cast
from typing import Any, Iterable, List, NoReturn, Optional, Tuple, cast
from unittest.mock import patch
# This import verifies that the dependencies are available.
@ -35,6 +35,10 @@ extra_oracle_types = {
assert OracleDialect.ischema_names
def _raise_err(exc: Exception) -> NoReturn:
raise exc
def output_type_handler(cursor, name, defaultType, size, precision, scale):
"""Add CLOB and BLOB support to Oracle connection."""
@ -94,7 +98,9 @@ class OracleInspectorObjectWrapper:
s = "SELECT username FROM dba_users ORDER BY username"
cursor = self._inspector_instance.bind.execute(s)
return [
self._inspector_instance.dialect.normalize_name(row[0]) for row in cursor
self._inspector_instance.dialect.normalize_name(row[0])
or _raise_err(ValueError(f"Invalid schema name: {row[0]}"))
for row in cursor
]
def get_table_names(self, schema: str = None, order_by: str = None) -> List[str]:
@ -121,7 +127,9 @@ class OracleInspectorObjectWrapper:
cursor = self._inspector_instance.bind.execute(sql.text(sql_str), owner=schema)
return [
self._inspector_instance.dialect.normalize_name(row[0]) for row in cursor
self._inspector_instance.dialect.normalize_name(row[0])
or _raise_err(ValueError(f"Invalid table name: {row[0]}"))
for row in cursor
]
def __getattr__(self, item: str) -> Any:
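
The _raise_err helper above is typed NoReturn so it can be used inside a comprehension: normalize_name may return None, and `value or _raise_err(...)` both raises on the None case and lets mypy narrow each element to str. A self-contained sketch of the idiom:

from typing import List, NoReturn, Optional

def _raise_err(exc: Exception) -> NoReturn:
    raise exc

def normalize(name: Optional[str]) -> Optional[str]:
    # Stand-in for dialect.normalize_name, which may return None.
    return name.lower() if name else None

raw_names = ["SCOTT", "HR"]
names: List[str] = [
    normalize(n) or _raise_err(ValueError(f"Invalid schema name: {n}"))
    for n in raw_names
]
print(names)  # ['scott', 'hr']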


@ -27,6 +27,7 @@ from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql import sqltypes as types
from sqlalchemy.types import TypeDecorator, TypeEngine
from datahub.configuration.common import AllowDenyPattern
from datahub.emitter.mce_builder import (
@ -329,7 +330,7 @@ class SqlWorkUnit(MetadataWorkUnit):
pass
_field_type_mapping: Dict[Type[types.TypeEngine], Type] = {
_field_type_mapping: Dict[Type[TypeEngine], Type] = {
types.Integer: NumberTypeClass,
types.Numeric: NumberTypeClass,
types.Boolean: BooleanTypeClass,
@ -367,30 +368,28 @@ _field_type_mapping: Dict[Type[types.TypeEngine], Type] = {
# assigns the NullType by default. We want to carry this warning through.
types.NullType: NullTypeClass,
}
_known_unknown_field_types: Set[Type[types.TypeEngine]] = {
_known_unknown_field_types: Set[Type[TypeEngine]] = {
types.Interval,
types.CLOB,
}
def register_custom_type(
tp: Type[types.TypeEngine], output: Optional[Type] = None
) -> None:
def register_custom_type(tp: Type[TypeEngine], output: Optional[Type] = None) -> None:
if output:
_field_type_mapping[tp] = output
else:
_known_unknown_field_types.add(tp)
class _CustomSQLAlchemyDummyType(types.TypeDecorator):
class _CustomSQLAlchemyDummyType(TypeDecorator):
impl = types.LargeBinary
def make_sqlalchemy_type(name: str) -> Type[types.TypeEngine]:
def make_sqlalchemy_type(name: str) -> Type[TypeEngine]:
# This usage of type() dynamically constructs a class.
# See https://stackoverflow.com/a/15247202/5004662 and
# https://docs.python.org/3/library/functions.html#type.
sqlalchemy_type: Type[types.TypeEngine] = type(
sqlalchemy_type: Type[TypeEngine] = type(
name,
(_CustomSQLAlchemyDummyType,),
{


@ -4,15 +4,12 @@ from textwrap import dedent
from typing import Any, Dict, List, Optional
import sqlalchemy
# This import verifies that the dependencies are available.
import trino.sqlalchemy # noqa: F401
from pydantic.fields import Field
from sqlalchemy import exc, sql
from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import TypeEngine
from trino.exceptions import TrinoQueryError
from trino.sqlalchemy import datatype, error
from trino.sqlalchemy.dialect import TrinoDialect


@ -3,13 +3,12 @@ import dataclasses
import logging
import time
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set
from pydantic.fields import Field
from pydantic.main import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.engine.result import ResultProxy, RowProxy
import datahub.emitter.mce_builder as builder
from datahub.configuration.source_common import EnvBasedSourceConfigBase
@ -39,6 +38,13 @@ from datahub.metadata.schema_classes import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
try:
from sqlalchemy.engine import Row # type: ignore
except ImportError:
# See https://github.com/python/mypy/issues/1153.
from sqlalchemy.engine.result import RowProxy as Row # type: ignore
REDSHIFT_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
@ -266,7 +272,7 @@ class RedshiftUsageSource(Source):
logger.debug(f"sql_alchemy_url = {url}")
return create_engine(url, **self.config.options)
def _should_process_row(self, row: RowProxy) -> bool:
def _should_process_row(self, row: "Row") -> bool:
# Check that mandatory properties are present first.
missing_props: List[str] = [
prop
@ -294,10 +300,13 @@ class RedshiftUsageSource(Source):
def _gen_access_events_from_history_query(
self, query: str, engine: Engine
) -> Iterable[RedshiftAccessEvent]:
results: ResultProxy = engine.execute(query)
for row in results: # type: RowProxy
results = engine.execute(query)
for row in results:
if not self._should_process_row(row):
continue
if hasattr(row, "_asdict"):
# Compatibility with sqlalchemy 1.4.x.
row = row._asdict()
access_event = RedshiftAccessEvent(**dict(row.items()))
# Replace database name with the alias name if one is provided in the config.
if self.config.database_alias:


@ -1,3 +1,5 @@
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
import json
import logging
import os
@ -56,6 +58,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.events.metadata import ChangeTyp
from datahub.metadata.schema_classes import PartitionSpecClass, PartitionTypeClass
from datahub.utilities.sql_parser import DefaultSQLParser
assert MARKUPSAFE_PATCHED
logger = logging.getLogger(__name__)
if os.getenv("DATAHUB_DEBUG", False):
handler = logging.StreamHandler(stream=sys.stdout)


@ -0,0 +1,11 @@
try:
import markupsafe
# This monkeypatch hack is required for markupsafe>=2.1.0 and older versions of Jinja2.
# Changelog: https://markupsafe.palletsprojects.com/en/2.1.x/changes/#version-2-1-0
# Example discussion: https://github.com/aws/aws-sam-cli/issues/3661.
markupsafe.soft_unicode = markupsafe.soft_str # type: ignore[attr-defined]
MARKUPSAFE_PATCHED = True
except ImportError:
MARKUPSAFE_PATCHED = False
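
The shim's observable effect, in isolation: markupsafe 2.1 removed soft_unicode, but older Jinja2 releases still import it at module load time, so the alias is restored before anything pulls in Jinja2. An illustrative standalone equivalent:

import markupsafe

# Restore the alias only when it is missing (markupsafe >= 2.1).
if not hasattr(markupsafe, "soft_unicode"):
    markupsafe.soft_unicode = markupsafe.soft_str  # type: ignore[attr-defined]

print(markupsafe.soft_unicode("1 < 2"))  # plain, unescaped string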


@ -7,7 +7,7 @@ import random
import string
import threading
import unittest.mock
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, cast
import greenlet
import sqlalchemy
@ -39,7 +39,8 @@ class _RowProxyFake(collections.OrderedDict):
class _ResultProxyFake:
# This imitates the interface provided by sqlalchemy.engine.result.ResultProxy.
# This imitates the interface provided by sqlalchemy.engine.result.ResultProxy (sqlalchemy 1.3.x)
# or sqlalchemy.engine.Result (1.4.x).
# Adapted from https://github.com/rajivsarvepalli/mock-alchemy/blob/2eba95588e7693aab973a6d60441d2bc3c4ea35d/src/mock_alchemy/mocking.py#L213
def __init__(self, result: List[_RowProxyFake]) -> None:
@ -363,7 +364,11 @@ class SQLAlchemyQueryCombiner:
*query_future.multiparams,
**query_future.params,
)
query_future.res = res
# The actual execute method returns a CursorResult on SQLAlchemy 1.4.x
# and a ResultProxy on SQLAlchemy 1.3.x. Both interfaces are shimmed
# by _ResultProxyFake.
query_future.res = cast(_ResultProxyFake, res)
except Exception as e:
query_future.exc = e
finally:


@ -0,0 +1,25 @@
from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED
from airflow.models.baseoperator import BaseOperator
try:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
except ModuleNotFoundError:
Operator = BaseOperator # type: ignore
MappedOperator = None # type: ignore
try:
from airflow.sensors.external_task import ExternalTaskSensor
except ImportError:
from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore
assert MARKUPSAFE_PATCHED
__all__ = [
"MARKUPSAFE_PATCHED",
"Operator",
"BaseOperator",
"MappedOperator",
"ExternalTaskSensor",
]
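
A hypothetical consumer of the compat module, showing why MappedOperator may be None and has to be guarded before isinstance checks (assumes the provider package is installed):

from datahub_provider._airflow_compat import BaseOperator, MappedOperator, Operator

def describe(task: "Operator") -> str:
    # MappedOperator is None on Airflow < 2.3, so guard before the isinstance check.
    if MappedOperator is not None and isinstance(task, MappedOperator):
        return f"mapped task {task.task_id}"
    assert isinstance(task, BaseOperator)
    return f"task {task.task_id}"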


@ -1,3 +1,5 @@
from datahub_provider._airflow_compat import Operator
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List
@ -10,7 +12,6 @@ from datahub_provider.entities import _Entity
if TYPE_CHECKING:
from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
@ -47,7 +48,7 @@ class DatahubBasicLineageConfig(ConfigModel):
def send_lineage_to_datahub(
config: DatahubBasicLineageConfig,
operator: "BaseOperator",
operator: "Operator",
inlets: List[_Entity],
outlets: List[_Entity],
context: Dict,
@ -56,7 +57,7 @@ def send_lineage_to_datahub(
return
dag: "DAG" = context["dag"]
task: "BaseOperator" = context["task"]
task: "Operator" = context["task"]
ti: "TaskInstance" = context["task_instance"]
hook = config.make_emitter_hook()


@ -0,0 +1,321 @@
from datahub_provider._airflow_compat import Operator
import contextlib
import traceback
from typing import Any, Iterable, List
from airflow.configuration import conf
from airflow.lineage import PIPELINE_OUTLETS
from airflow.models.baseoperator import BaseOperator
from airflow.plugins_manager import AirflowPlugin
from airflow.utils.module_loading import import_string
from cattr import structure
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub_provider.client.airflow_generator import AirflowGenerator
from datahub_provider.hooks.datahub import DatahubGenericHook
from datahub_provider.lineage.datahub import DatahubLineageConfig
def get_lineage_config() -> DatahubLineageConfig:
"""Load the lineage config from airflow.cfg."""
enabled = conf.get("datahub", "enabled", fallback=True)
datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default")
cluster = conf.get("datahub", "cluster", fallback="prod")
graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True)
capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True)
capture_ownership_info = conf.get(
"datahub", "capture_ownership_info", fallback=True
)
capture_executions = conf.get("datahub", "capture_executions", fallback=True)
return DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=datahub_conn_id,
cluster=cluster,
graceful_exceptions=graceful_exceptions,
capture_ownership_info=capture_ownership_info,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
)
def _task_inlets(operator: "Operator") -> List:
# From Airflow 2.4 onwards, _inlets is removed and inlets is used consistently; on earlier versions we have to fall back to _inlets.
if hasattr(operator, "_inlets"):
return operator._inlets # type: ignore[attr-defined, union-attr]
return operator.inlets
def _task_outlets(operator: "Operator") -> List:
# From Airflow 2.4 onwards, _outlets is removed and outlets is used consistently.
# On earlier versions we have to use _outlets, because outlets is empty in Airflow < 2.4.0.
if hasattr(operator, "_outlets"):
return operator._outlets # type: ignore[attr-defined, union-attr]
return operator.outlets
def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]:
# TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae
# in Airflow 2.4.
# TODO: ignore/handle airflow's dataset type in our lineage
inlets: List[Any] = []
task_inlets = _task_inlets(task)
# From Airflow 2.3 this should be AbstractOperator, but for compatibility reasons we stick with BaseOperator.
if isinstance(task_inlets, (str, BaseOperator)):
inlets = [
task_inlets,
]
if task_inlets and isinstance(task_inlets, list):
inlets = []
task_ids = (
{o for o in task_inlets if isinstance(o, str)}
.union(op.task_id for op in task_inlets if isinstance(op, BaseOperator))
.intersection(task.get_flat_relative_ids(upstream=True))
)
from airflow.lineage import AUTO
# pick up unique direct upstream task_ids if AUTO is specified
if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets:
print("Picking up unique direct upstream task_ids as AUTO is specified")
task_ids = task_ids.union(
task_ids.symmetric_difference(task.upstream_task_ids)
)
inlets = task.xcom_pull(
context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS
)
# re-instantiate the obtained inlets
inlets = [
structure(item["data"], import_string(item["type_name"]))
# _get_instance(structure(item, Metadata))
for sublist in inlets
if sublist
for item in sublist
]
for inlet in task_inlets:
if type(inlet) != str:
inlets.append(inlet)
return inlets
def datahub_task_status_callback(context, status):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
dataflow = AirflowGenerator.generate_dataflow(
cluster=context["_datahub_config"].cluster,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
dataflow.emit(emitter)
task.log.info(f"Emitted Datahub DataFlow: {dataflow}")
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
task_outlets = _task_outlets(task)
for outlet in task_outlets:
datajob.outlets.append(outlet.urn)
task.log.info(f"Emitted Datahub dataJob: {datajob}")
datajob.emit(emitter)
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}")
dpi = AirflowGenerator.complete_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag_run=context["dag_run"],
result=status,
dag=dag,
datajob=datajob,
end_timestamp_millis=int(ti.end_date.timestamp() * 1000),
)
task.log.info(f"Emitted Completed Data Process Instance: {dpi}")
def datahub_pre_execution(context):
ti = context["ti"]
task: "BaseOperator" = ti.task
dag = context["dag"]
task.log.info("Running Datahub pre_execute method")
emitter = (
DatahubGenericHook(context["_datahub_config"].datahub_conn_id)
.get_underlying_hook()
.make_emitter()
)
# This code is from the original airflow lineage code ->
# https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py
inlets = get_inlets_from_task(task, context)
datajob = AirflowGenerator.generate_datajob(
cluster=context["_datahub_config"].cluster,
task=context["ti"].task,
dag=dag,
capture_tags=context["_datahub_config"].capture_tags_info,
capture_owner=context["_datahub_config"].capture_ownership_info,
)
for inlet in inlets:
datajob.inlets.append(inlet.urn)
task_outlets = _task_outlets(task)
for outlet in task_outlets:
datajob.outlets.append(outlet.urn)
datajob.emit(emitter)
task.log.info(f"Emitting Datahub DataJob: {datajob}")
if context["_datahub_config"].capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=emitter,
cluster=context["_datahub_config"].cluster,
ti=context["ti"],
dag=dag,
dag_run=context["dag_run"],
datajob=datajob,
start_timestamp_millis=int(ti.start_date.timestamp() * 1000),
)
task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}")
def _wrap_pre_execution(pre_execution):
def custom_pre_execution(context):
config = get_lineage_config()
context["_datahub_config"] = config
datahub_pre_execution(context)
# Call original policy
if pre_execution:
pre_execution(context)
return custom_pre_execution
def _wrap_on_failure_callback(on_failure_callback):
def custom_on_failure_callback(context):
config = get_lineage_config()
context["_datahub_config"] = config
try:
datahub_task_status_callback(context, status=InstanceRunResult.FAILURE)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
# Call original policy
if on_failure_callback:
on_failure_callback(context)
return custom_on_failure_callback
def _wrap_on_success_callback(on_success_callback):
def custom_on_success_callback(context):
config = get_lineage_config()
context["_datahub_config"] = config
try:
datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS)
except Exception as e:
if not config.graceful_exceptions:
raise e
else:
print(f"Exception: {traceback.format_exc()}")
if on_success_callback:
on_success_callback(context)
return custom_on_success_callback
def task_policy(task: BaseOperator) -> None:
print(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
# task.add_inlets(["auto"])
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback)
task.on_success_callback = _wrap_on_success_callback(task.on_success_callback)
# task.pre_execute = _wrap_pre_execution(task.pre_execute)
def _wrap_task_policy(policy):
if policy and hasattr(policy, "_task_policy_patched_by"):
return policy
def custom_task_policy(task):
policy(task)
task_policy(task)
setattr(custom_task_policy, "_task_policy_patched_by", "datahub_plugin")
return custom_task_policy
def _patch_policy(settings):
if hasattr(settings, "task_policy"):
datahub_task_policy = _wrap_task_policy(settings.task_policy)
settings.task_policy = datahub_task_policy
def _patch_datahub_policy():
print("Patching datahub policy")
with contextlib.suppress(ImportError):
import airflow_local_settings
_patch_policy(airflow_local_settings)
from airflow.models.dagbag import settings
_patch_policy(settings)
_patch_datahub_policy()
class DatahubPlugin(AirflowPlugin):
name = "datahub_plugin"


@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from datahub_provider._airflow_compat import BaseOperator, ExternalTaskSensor, Operator
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
from airflow.configuration import conf
@ -13,16 +15,22 @@ from datahub.utilities.urns.data_job_urn import DataJobUrn
if TYPE_CHECKING:
from airflow import DAG
from airflow.models import BaseOperator, DagRun, TaskInstance
from airflow.models import DagRun, TaskInstance
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
def _task_downstream_task_ids(operator: "Operator") -> Set[str]:
if hasattr(operator, "downstream_task_ids"):
return operator.downstream_task_ids
return operator._downstream_task_id # type: ignore[attr-defined,union-attr]
class AirflowGenerator:
@staticmethod
def _get_dependencies(
task: "BaseOperator", dag: "DAG", flow_urn: DataFlowUrn
task: "Operator", dag: "DAG", flow_urn: DataFlowUrn
) -> List[DataJobUrn]:
# resolve URNs for upstream nodes in subdags upstream of the current task.
@ -47,7 +55,7 @@ class AirflowGenerator:
)
# if subdag task is a leaf task, then link it as an upstream task
if len(upstream_subdag_task._downstream_task_ids) == 0:
if len(_task_downstream_task_ids(upstream_subdag_task)) == 0:
upstream_subdag_task_urns.append(upstream_subdag_task_urn)
# resolve URNs for upstream nodes that trigger the subdag containing the current task.
@ -59,7 +67,7 @@ class AirflowGenerator:
if (
dag.is_subdag
and dag.parent_dag is not None
and len(task._upstream_task_ids) == 0
and len(task.upstream_task_ids) == 0
):
# filter through the parent dag's tasks and find the subdag trigger(s)
@ -83,7 +91,7 @@ class AirflowGenerator:
)
# if the task triggers the subdag, link it to this node in the subdag
if subdag_task_id in upstream_task._downstream_task_ids:
if subdag_task_id in _task_downstream_task_ids(upstream_task):
upstream_subdag_triggers.append(upstream_task_urn)
# If the operator is an ExternalTaskSensor then we set the remote task as upstream.
@ -91,8 +99,6 @@ class AirflowGenerator:
# jobflow to another jobflow.
external_task_upstreams = []
if task.task_type == "ExternalTaskSensor":
from airflow.sensors.external_task_sensor import ExternalTaskSensor
task = cast(ExternalTaskSensor, task)
if hasattr(task, "external_task_id") and task.external_task_id is not None:
external_task_upstreams = [
@ -173,7 +179,11 @@ class AirflowGenerator:
return data_flow
@staticmethod
def _get_description(task: "BaseOperator") -> Optional[str]:
def _get_description(task: "Operator") -> Optional[str]:
if not isinstance(task, BaseOperator):
# TODO: Get docs for mapped operators.
return None
if hasattr(task, "doc") and task.doc:
return task.doc
elif hasattr(task, "doc_md") and task.doc_md:
@ -189,9 +199,9 @@ class AirflowGenerator:
@staticmethod
def generate_datajob(
cluster: str,
task: "BaseOperator",
task: "Operator",
dag: "DAG",
set_dependendecies: bool = True,
set_dependencies: bool = True,
capture_owner: bool = True,
capture_tags: bool = True,
) -> DataJob:
@ -200,7 +210,7 @@ class AirflowGenerator:
:param cluster: str
:param task: TaskInstance
:param dag: DAG
:param set_dependendecies: bool - whether to extract dependencies from airflow task
:param set_dependencies: bool - whether to extract dependencies from airflow task
:param capture_owner: bool - whether to extract owner from airflow task
:param capture_tags: bool - whether to set tags automatically from airflow task
:return: DataJob - returns the generated DataJob object
@ -209,6 +219,8 @@ class AirflowGenerator:
orchestrator="airflow", env=cluster, flow_id=dag.dag_id
)
datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
# TODO add support for MappedOperator
datajob.description = AirflowGenerator._get_description(task)
job_property_bag: Dict[str, str] = {}
@ -228,6 +240,11 @@ class AirflowGenerator:
"task_id",
"trigger_rule",
"wait_for_downstream",
# In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids
"downstream_task_ids",
# In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions.
"inlets",
"outlets",
]
for key in allowed_task_keys:
@ -244,7 +261,7 @@ class AirflowGenerator:
if capture_tags and dag.tags:
datajob.tags.update(dag.tags)
if set_dependendecies:
if set_dependencies:
datajob.upstream_urns.extend(
AirflowGenerator._get_dependencies(
task=task, dag=dag, flow_urn=datajob.flow_urn
@ -256,7 +273,7 @@ class AirflowGenerator:
@staticmethod
def create_datajob_instance(
cluster: str,
task: "BaseOperator",
task: "Operator",
dag: "DAG",
data_job: Optional[DataJob] = None,
) -> DataProcessInstance:
@ -282,8 +299,10 @@ class AirflowGenerator:
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)
if start_timestamp_millis is None:
assert dag_run.execution_date
start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000)
assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
# This property only exists in Airflow2
@ -335,6 +354,7 @@ class AirflowGenerator:
assert dag_run.dag
dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag)
assert dag_run.run_id
dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id)
if end_timestamp_millis is None:
if dag_run.end_date is None:
@ -375,6 +395,7 @@ class AirflowGenerator:
if datajob is None:
datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag)
assert dag_run.run_id
dpi = DataProcessInstance.from_datajob(
datajob=datajob,
id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}",


@ -9,6 +9,8 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
)
if TYPE_CHECKING:
from airflow.models.connection import Connection
from datahub.emitter.kafka_emitter import DatahubKafkaEmitter
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig
@ -51,12 +53,14 @@ class DatahubRestHook(BaseHook):
}
def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]:
conn = self.get_connection(self.datahub_rest_conn_id)
conn: "Connection" = self.get_connection(self.datahub_rest_conn_id)
host = conn.host
if host is None:
raise AirflowException("host parameter is required")
password = conn.password
timeout_sec = conn.extra_dejson.get("timeout_sec")
return (host, conn.password, timeout_sec)
return (host, password, timeout_sec)
def make_emitter(self) -> "DatahubRestEmitter":
import datahub.emitter.rest_emitter


@ -0,0 +1,33 @@
# On SQLAlchemy 1.4.x, the mypy plugin is built-in.
# However, with SQLAlchemy 1.3.x, it requires the sqlalchemy-stubs package and hence has a separate import.
# This file serves as a thin shim layer that directs mypy to the appropriate plugin implementation.
try:
from mypy.semanal import SemanticAnalyzer
from sqlalchemy.ext.mypy.plugin import plugin
# On SQLAlchemy >=1.4, <=1.4.29, the mypy plugin is incompatible with newer versions of mypy.
# See https://github.com/sqlalchemy/sqlalchemy/commit/aded8b11d9eccbd1f2b645a94338e34a3d234bc9
# and https://github.com/sqlalchemy/sqlalchemy/issues/7496.
# To fix this, we need to patch the mypy plugin interface.
#
# We cannot set a min version of SQLAlchemy because of the bigquery SQLAlchemy package.
# See https://github.com/googleapis/python-bigquery-sqlalchemy/issues/385.
_named_type_original = SemanticAnalyzer.named_type
_named_type_translations = {
"__builtins__.object": "builtins.object",
"__builtins__.str": "builtins.str",
"__builtins__.list": "builtins.list",
"__sa_Mapped": "sqlalchemy.orm.attributes.Mapped",
}
def _named_type_shim(self, fullname, *args, **kwargs):
if fullname in _named_type_translations:
fullname = _named_type_translations[fullname]
return _named_type_original(self, fullname, *args, **kwargs)
SemanticAnalyzer.named_type = _named_type_shim # type: ignore
except ModuleNotFoundError:
from sqlmypy import plugin # type: ignore[no-redef]
__all__ = ["plugin"]


@ -1,3 +1,5 @@
from datahub_provider._airflow_compat import MARKUPSAFE_PATCHED
import datetime
import json
import os
@ -13,19 +15,19 @@ import packaging.version
import pytest
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models import DAG, Connection, DagBag, DagRun, TaskInstance
from airflow.operators.dummy import DummyOperator
from airflow.utils.dates import days_ago
try:
from airflow.operators.dummy import DummyOperator
except ModuleNotFoundError:
from airflow.operators.dummy_operator import DummyOperator
import datahub.emitter.mce_builder as builder
from datahub_provider import get_provider_info
from datahub_provider.entities import Dataset
from datahub_provider.hooks.datahub import DatahubKafkaHook, DatahubRestHook
from datahub_provider.operators.datahub import DatahubEmitterOperator
assert MARKUPSAFE_PATCHED
pytestmark = pytest.mark.airflow
# Approach suggested by https://stackoverflow.com/a/11887885/5004662.
AIRFLOW_VERSION = packaging.version.parse(airflow.version.version)
@ -73,6 +75,10 @@ def test_airflow_provider_info():
assert get_provider_info()
@pytest.mark.skipif(
AIRFLOW_VERSION < packaging.version.parse("2.0.0"),
reason="the examples use list-style lineage, which is only supported on Airflow 2.x",
)
def test_dags_load_with_no_errors(pytestconfig):
airflow_examples_folder = (
pytestconfig.rootpath / "src/datahub_provider/example_dags"
@ -99,6 +105,7 @@ def patch_airflow_connection(conn: Connection) -> Iterator[Connection]:
@mock.patch("datahub.emitter.rest_emitter.DatahubRestEmitter", autospec=True)
def test_datahub_rest_hook(mock_emitter):
with patch_airflow_connection(datahub_rest_connection_config) as config:
assert config.conn_id
hook = DatahubRestHook(config.conn_id)
hook.emit_mces([lineage_mce])
@ -112,6 +119,7 @@ def test_datahub_rest_hook_with_timeout(mock_emitter):
with patch_airflow_connection(
datahub_rest_connection_config_with_timeout
) as config:
assert config.conn_id
hook = DatahubRestHook(config.conn_id)
hook.emit_mces([lineage_mce])
@ -123,6 +131,7 @@ def test_datahub_rest_hook_with_timeout(mock_emitter):
@mock.patch("datahub.emitter.kafka_emitter.DatahubKafkaEmitter", autospec=True)
def test_datahub_kafka_hook(mock_emitter):
with patch_airflow_connection(datahub_kafka_connection_config) as config:
assert config.conn_id
hook = DatahubKafkaHook(config.conn_id)
hook.emit_mces([lineage_mce])
@ -135,6 +144,7 @@ def test_datahub_kafka_hook(mock_emitter):
@mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.emit_mces")
def test_datahub_lineage_operator(mock_emit):
with patch_airflow_connection(datahub_rest_connection_config) as config:
assert config.conn_id
task = DatahubEmitterOperator(
task_id="emit_lineage",
datahub_conn_id=config.conn_id,
@ -331,6 +341,7 @@ def test_lineage_backend(mock_emit, inlets, outlets):
)
@mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.make_emitter")
def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
# TODO: Merge this code into the test above to reduce duplication.
DEFAULT_DATE = datetime.datetime(2020, 5, 17)
mock_emitter = Mock()
mock_emit.return_value = mock_emitter
@ -375,10 +386,6 @@ def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE)
# Ignoring type here because DagRun state is just a string in Airflow 1
dag_run = DagRun(state="success", run_id=f"scheduled_{DEFAULT_DATE}") # type: ignore
ti.dag_run = dag_run
ti.start_date = datetime.datetime.utcnow()
ti.execution_date = DEFAULT_DATE
else:
from airflow.utils.state import DagRunState
@ -386,9 +393,10 @@ def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
dag_run = DagRun(
state=DagRunState.SUCCESS, run_id=f"scheduled_{DEFAULT_DATE}"
)
ti.dag_run = dag_run
ti.start_date = datetime.datetime.utcnow()
ti.execution_date = DEFAULT_DATE
ti.dag_run = dag_run # type: ignore
ti.start_date = datetime.datetime.utcnow()
ti.execution_date = DEFAULT_DATE
ctx1 = {
"dag": dag,


@ -65,9 +65,8 @@ def test_athena_get_table_properties():
mock_cursor = mock.MagicMock()
mock_inspector = mock.MagicMock()
mock_inspector.engine.return_value = mock.MagicMock()
mock_inspector.dialect._raw_connection.return_value = mock_cursor
mock_inspector.dialect._raw_connection().cursor()._get_table_metadata.return_value = AthenaTableMetadata(
mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor
mock_cursor._get_table_metadata.return_value = AthenaTableMetadata(
response=table_metadata
)