import json
from contextlib import contextmanager
from typing import Iterator
from unittest import mock

import airflow.configuration
import pytest
from airflow.models import Connection, DagBag

import datahub.emitter.mce_builder as builder
from datahub.ingestion.graph.config import ClientMode
from datahub_airflow_plugin import get_provider_info
from datahub_airflow_plugin.entities import Dataset, Urn
from datahub_airflow_plugin.hooks.datahub import DatahubKafkaHook, DatahubRestHook
from datahub_airflow_plugin.operators.datahub import DatahubEmitterOperator

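# Shared fixture: a lineage MCE with two BigQuery upstreams feeding one
# downstream dataset, reused by the hook and operator tests below.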
lineage_mce = builder.make_lineage_mce(
    [
        builder.make_dataset_urn("bigquery", "upstream1"),
        builder.make_dataset_urn("bigquery", "upstream2"),
    ],
    builder.make_dataset_urn("bigquery", "downstream1"),
)

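# Airflow Connection fixtures consumed by the hook tests via
# patch_airflow_connection below.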
datahub_rest_connection_config = Connection(
    conn_id="datahub_rest_test",
    conn_type="datahub_rest",
    host="http://test_host:8080",
    extra=None,
)
datahub_rest_connection_config_with_timeout = Connection(
    conn_id="datahub_rest_test",
    conn_type="datahub_rest",
    host="http://test_host:8080",
    extra=json.dumps({"timeout_sec": 5}),
)

datahub_kafka_connection_config = Connection(
    conn_id="datahub_kafka_test",
    conn_type="datahub_kafka",
    host="test_broker:9092",
    extra=json.dumps(
        {
            "connection": {
                "producer_config": {},
                "schema_registry_url": "http://localhost:8081",
            }
        }
    ),
)


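# Runs once per module: load Airflow's test configuration so the tests
# don't depend on a real airflow.cfg.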
def setup_module(module):
    airflow.configuration.conf.load_test_config()


def test_airflow_provider_info():
    assert get_provider_info()


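# Smoke-test the example DAGs shipped with the plugin: the bag must import
# cleanly and be non-empty.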
@pytest.mark.filterwarnings("ignore:.*is deprecated.*")
def test_dags_load_with_no_errors(pytestconfig: pytest.Config) -> None:
    airflow_examples_folder = (
        pytestconfig.rootpath / "src/datahub_airflow_plugin/example_dags"
    )

    # Note: the .airflowignore file skips the snowflake DAG.
    dag_bag = DagBag(dag_folder=str(airflow_examples_folder), include_examples=False)

    import_errors = dag_bag.import_errors

    assert len(import_errors) == 0
    assert dag_bag.size() > 0


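# Patches BaseHook.get_connection to return the given Connection, so the
# hooks resolve test connections without touching an Airflow metadata DB.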
@contextmanager
def patch_airflow_connection(conn: Connection) -> Iterator[Connection]:
    # The return type should really be ContextManager, but mypy doesn't like that.
    # See https://stackoverflow.com/questions/49733699/python-type-hints-and-context-managers#comment106444758_58349659.
    with mock.patch(
        "datahub_provider.hooks.datahub.BaseHook.get_connection", return_value=conn
    ):
        yield conn


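# The REST hook should construct a DataHubRestEmitter for the configured
# host and emit each MCE through it.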
@mock.patch("datahub.emitter.rest_emitter.DataHubRestEmitter", autospec=True)
def test_datahub_rest_hook(mock_emitter):
    with patch_airflow_connection(datahub_rest_connection_config) as config:
        assert config.conn_id
        hook = DatahubRestHook(config.conn_id)
        hook.emit_mces([lineage_mce])

        mock_emitter.assert_called_once_with(
            config.host,
            None,
            client_mode=ClientMode.INGESTION,
            datahub_component="airflow-plugin",
        )
        instance = mock_emitter.return_value
        instance.emit.assert_called_with(lineage_mce)


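# Extra connection config (here timeout_sec) should be passed through to
# the emitter as keyword arguments.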
@mock.patch("datahub.emitter.rest_emitter.DataHubRestEmitter", autospec=True)
def test_datahub_rest_hook_with_timeout(mock_emitter):
    with patch_airflow_connection(
        datahub_rest_connection_config_with_timeout
    ) as config:
        assert config.conn_id
        hook = DatahubRestHook(config.conn_id)
        hook.emit_mces([lineage_mce])

        mock_emitter.assert_called_once_with(
            config.host,
            None,
            timeout_sec=5,
            client_mode=ClientMode.INGESTION,
            datahub_component="airflow-plugin",
        )
        instance = mock_emitter.return_value
        instance.emit.assert_called_with(lineage_mce)


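# The Kafka hook should emit via DatahubKafkaEmitter and flush exactly once.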
@mock.patch("datahub.emitter.kafka_emitter.DatahubKafkaEmitter", autospec=True)
def test_datahub_kafka_hook(mock_emitter):
    with patch_airflow_connection(datahub_kafka_connection_config) as config:
        assert config.conn_id
        hook = DatahubKafkaHook(config.conn_id)
        hook.emit_mces([lineage_mce])

        mock_emitter.assert_called_once()
        instance = mock_emitter.return_value
        instance.emit.assert_called()
        instance.flush.assert_called_once()


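# The emitter operator should forward its MCEs to DatahubRestHook.emit
# when executed.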
@mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.emit")
def test_datahub_lineage_operator(mock_emit):
    with patch_airflow_connection(datahub_rest_connection_config) as config:
        assert config.conn_id
        task = DatahubEmitterOperator(
            task_id="emit_lineage",
            datahub_conn_id=config.conn_id,
            mces=[
                builder.make_lineage_mce(
                    [
                        builder.make_dataset_urn("snowflake", "mydb.schema.tableA"),
                        builder.make_dataset_urn("snowflake", "mydb.schema.tableB"),
                    ],
                    builder.make_dataset_urn("snowflake", "mydb.schema.tableC"),
                )
            ],
        )
        task.execute(None)

        mock_emit.assert_called()


@pytest.mark.parametrize(
    "hook",
    [
        DatahubRestHook,
        DatahubKafkaHook,
    ],
)
def test_hook_airflow_ui(hook):
    # Simply ensure that these run without issue. These will also show up
    # in the Airflow UI, where it will be even more clear if something
    # is wrong.
    hook.get_connection_form_widgets()
    hook.get_ui_field_behaviour()


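# The Dataset/Urn entity helpers should produce or round-trip valid URNs
# and reject unsupported entity types.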
def test_entities():
    assert (
        Dataset("snowflake", "mydb.schema.tableConsumed").urn
        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
    )

    assert (
        Urn(
            "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
        ).urn
        == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
    )

    assert (
        Urn("urn:li:dataJob:(urn:li:dataFlow:(airflow,testDag,PROD),testTask)").urn
        == "urn:li:dataJob:(urn:li:dataFlow:(airflow,testDag,PROD),testTask)"
    )

    assert (
        Urn(
            "urn:li:dataJob:(urn:li:dataFlow:(airflow,platform.testDag,PROD),testTask)"
        ).urn
        == "urn:li:dataJob:(urn:li:dataFlow:(airflow,platform.testDag,PROD),testTask)"
    )

    with pytest.raises(ValueError, match="invalid"):
        Urn("not a URN")

    with pytest.raises(
        ValueError, match="only supports datasets and upstream datajobs"
    ):
        Urn("urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)")