fix(ingestion/kafka): OAuth callback execution (#11900)

sid-acryl 2024-11-22 13:08:23 +05:30 committed by GitHub
parent dac80fb7e1
commit 86b8175627
15 changed files with 183 additions and 26 deletions

View File

@@ -102,7 +102,29 @@ source:
    connection:
      bootstrap: "broker:9092"
      schema_registry_url: http://localhost:8081
```
### OAuth Callback
The OAuth callback function can be set up using `config.connection.consumer_config.oauth_cb`.
You need to specify a Python function reference in the format `<python-module>:<function-name>`.
For example, with the configuration `oauth:create_token`, `create_token` is a function defined in `oauth.py`, and `oauth.py` must be accessible on the PYTHONPATH. A sketch of what such a callback might look like appears after the example below.
```YAML
source:
  type: "kafka"
  config:
    # Set the custom schema registry implementation class
    schema_registry_class: "datahub.ingestion.source.confluent_schema_registry.ConfluentSchemaRegistry"
    # Coordinates
    connection:
      bootstrap: "broker:9092"
      schema_registry_url: http://localhost:8081
      consumer_config:
        security.protocol: "SASL_PLAINTEXT"
        sasl.mechanism: "OAUTHBEARER"
        oauth_cb: "oauth:create_token"

# sink configs
```
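
The referenced function must satisfy the `oauth_cb` contract of the underlying confluent-kafka client: it is invoked by the client and must return a `(token, expiry)` tuple. As a rough illustration (not prescribed by this commit; the token value and expiry are placeholders), `oauth.py` could look like this:

```python
# oauth.py -- minimal sketch of an oauth_cb implementation.
from typing import Any, Tuple


def create_token(*args: Any, **kwargs: Any) -> Tuple[str, float]:
    # A real callback would fetch a bearer token from your identity provider here.
    token = "dummy-bearer-token"
    # The second tuple element is the token expiry; see the confluent-kafka docs
    # for its exact semantics.
    expiry = 3600.0
    return token, expiry
```

With `oauth.py` on the PYTHONPATH, the `oauth:create_token` reference above resolves to this function at ingestion time.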

View File

@@ -741,8 +741,8 @@ entry_points = {
"hive = datahub.ingestion.source.sql.hive:HiveSource",
"hive-metastore = datahub.ingestion.source.sql.hive_metastore:HiveMetastoreSource",
"json-schema = datahub.ingestion.source.schema.json_schema:JsonSchemaSource",
"kafka = datahub.ingestion.source.kafka:KafkaSource",
"kafka-connect = datahub.ingestion.source.kafka_connect:KafkaConnectSource",
"kafka = datahub.ingestion.source.kafka.kafka:KafkaSource",
"kafka-connect = datahub.ingestion.source.kafka.kafka_connect:KafkaConnectSource",
"ldap = datahub.ingestion.source.ldap:LDAPSource",
"looker = datahub.ingestion.source.looker.looker_source:LookerDashboardSource",
"lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource",

View File

@@ -1,6 +1,7 @@
from pydantic import Field, validator
from datahub.configuration.common import ConfigModel
from datahub.configuration.common import ConfigModel, ConfigurationError
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig
from datahub.configuration.validate_host_port import validate_host_port
@@ -36,6 +37,16 @@ class KafkaConsumerConnectionConfig(_KafkaConnectionConfig):
description="Extra consumer config serialized as JSON. These options will be passed into Kafka's DeserializingConsumer. See https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#deserializingconsumer and https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md .",
)
@validator("consumer_config")
@classmethod
def resolve_callback(cls, value: dict) -> dict:
if CallableConsumerConfig.is_callable_config(value):
try:
value = CallableConsumerConfig(value).callable_config()
except Exception as e:
raise ConfigurationError(e)
return value
class KafkaProducerConnectionConfig(_KafkaConnectionConfig):
"""Configuration class for holding connectivity information for Kafka producers"""

View File

@@ -0,0 +1,35 @@
import logging
from typing import Any, Dict, Optional
from datahub.ingestion.api.registry import import_path
logger = logging.getLogger(__name__)
class CallableConsumerConfig:
CALLBACK_ATTRIBUTE: str = "oauth_cb"
def __init__(self, config: Dict[str, Any]):
self._config = config
self._resolve_oauth_callback()
def callable_config(self) -> Dict[str, Any]:
return self._config
@staticmethod
def is_callable_config(config: Dict[str, Any]) -> bool:
return CallableConsumerConfig.CALLBACK_ATTRIBUTE in config
def get_call_back_attribute(self) -> Optional[str]:
return self._config.get(CallableConsumerConfig.CALLBACK_ATTRIBUTE)
def _resolve_oauth_callback(self) -> None:
if not self.get_call_back_attribute():
return
call_back = self.get_call_back_attribute()
assert call_back  # to silence the linter
# Set the callback
self._config[CallableConsumerConfig.CALLBACK_ATTRIBUTE] = import_path(call_back)
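
To illustrate the intended use of this helper (a hedged sketch; it assumes a module `oauth` exposing a `create_token` function is importable, as in the integration test added further down):

```python
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig

raw_config = {
    "security.protocol": "SASL_PLAINTEXT",
    "sasl.mechanism": "OAUTHBEARER",
    "oauth_cb": "oauth:create_token",
}

if CallableConsumerConfig.is_callable_config(raw_config):
    resolved = CallableConsumerConfig(raw_config).callable_config()
    # The "oauth:create_token" string has been replaced by the imported function,
    # which is the form confluent_kafka.Consumer expects for oauth_cb.
    assert callable(resolved["oauth_cb"])
```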

View File

@@ -16,8 +16,10 @@ from confluent_kafka.schema_registry.schema_registry_client import (
from datahub.ingestion.extractor import protobuf_util, schema_util
from datahub.ingestion.extractor.json_schema_util import JsonSchemaTranslator
from datahub.ingestion.extractor.protobuf_util import ProtobufSchema
from datahub.ingestion.source.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
KafkaSchemaRegistryBase,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
KafkaSchema,
SchemaField,

View File

@@ -18,6 +18,7 @@ from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient
from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.kafka import KafkaConsumerConnectionConfig
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig
from datahub.configuration.source_common import (
DatasetSourceConfigMixin,
LowerCaseDatasetUrnConfigMixin,
@@ -49,7 +50,9 @@ from datahub.ingestion.api.source import (
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
KafkaSchemaRegistryBase,
)
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
StaleEntityRemovalSourceReport,
@@ -143,7 +146,7 @@ class KafkaSourceConfig(
def get_kafka_consumer(
connection: KafkaConsumerConnectionConfig,
) -> confluent_kafka.Consumer:
return confluent_kafka.Consumer(
consumer = confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": connection.bootstrap,
@@ -151,6 +154,13 @@ def get_kafka_consumer(
}
)
if CallableConsumerConfig.is_callable_config(connection.consumer_config):
# As per the documentation, we need to explicitly call the poll method to make sure the OAuth callback gets executed
# https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration
consumer.poll(timeout=30)
return consumer
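
The explicit `poll()` is the key part of this change: constructing the `Consumer` alone does not execute the OAuth callback, so the source polls once up front to force a token fetch before the consumer is used. A standalone sketch of the same pattern (broker address and callback module are placeholders, not values from this commit):

```python
import confluent_kafka

from oauth import create_token  # hypothetical module providing the callback

consumer = confluent_kafka.Consumer(
    {
        "group.id": "test",
        "bootstrap.servers": "broker:9092",
        "security.protocol": "SASL_PLAINTEXT",
        "sasl.mechanism": "OAUTHBEARER",
        # When building the config dict in code, oauth_cb can be the callable itself.
        "oauth_cb": create_token,
    }
)
# Drive the client event loop once so librdkafka invokes oauth_cb and obtains a token.
consumer.poll(timeout=30)
```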
@dataclass
class KafkaSourceReport(StaleEntityRemovalSourceReport):

View File

@@ -0,0 +1,20 @@
run_id: kafka-test
source:
  type: kafka
  config:
    connection:
      bootstrap: "localhost:29092"
      schema_registry_url: "http://localhost:28081"
      consumer_config:
        security.protocol: "SASL_PLAINTEXT"
        sasl.mechanism: "OAUTHBEARER"
        oauth_cb: "oauth:create_token"
    domain:
      "urn:li:domain:sales":
        allow:
          - "key_value_topic"
sink:
  type: file
  config:
    filename: "./kafka_mces.json"

View File

@@ -0,0 +1,14 @@
import logging
from typing import Any, Tuple
logger = logging.getLogger(__name__)
MESSAGE: str = "OAuth token `create_token` callback"
def create_token(*args: Any, **kwargs: Any) -> Tuple[str, int]:
logger.warning(MESSAGE)
return (
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJjbGllbnRfaWQiOiJrYWZrYV9jbGllbnQiLCJleHAiOjE2OTg3NjYwMDB9.dummy_sig_abcdef123456",
3600,
)

View File

@@ -1,10 +1,14 @@
import logging
import subprocess
import pytest
import yaml
from freezegun import freeze_time
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.kafka import KafkaSource
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.kafka.kafka import KafkaSource
from tests.integration.kafka import oauth # type: ignore
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.docker_helpers import wait_for_port
@@ -99,3 +103,36 @@ def test_kafka_test_connection(mock_kafka_service, config_dict, is_success):
SourceCapability.SCHEMA_METADATA: "Failed to establish a new connection"
},
)
@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_kafka_oauth_callback(
mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time
):
# Run the metadata ingestion pipeline.
config_file = (test_resources_dir / "kafka_to_file_oauth.yml").resolve()
log_file = tmp_path / "kafka_oauth_message.log"
file_handler = logging.FileHandler(
str(log_file)
)  # Add a file handler so the logged output can be validated below
logging.getLogger().addHandler(file_handler)
recipe: dict = {}
with open(config_file) as fp:
recipe = yaml.safe_load(fp)
pipeline = Pipeline.create(recipe)
pipeline.run()
is_found: bool = False
with open(log_file, "r") as file:
for line_number, line in enumerate(file, 1):
if oauth.MESSAGE in line:
is_found = True
break
assert is_found

View File

@@ -33,7 +33,9 @@ pytestmark = pytest.mark.random_order(disabled=True)
class TestPipeline:
@patch("confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.KafkaSource.get_workunits", autospec=True)
@patch(
"datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
)
@patch("datahub.ingestion.sink.console.ConsoleSink.close", autospec=True)
@freeze_time(FROZEN_TIME)
def test_configure(self, mock_sink, mock_source, mock_consumer):
@@ -198,7 +200,9 @@ class TestPipeline:
assert pipeline.ctx.graph.config.token == pipeline.config.sink.config["token"]
@freeze_time(FROZEN_TIME)
@patch("datahub.ingestion.source.kafka.KafkaSource.get_workunits", autospec=True)
@patch(
"datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
)
def test_configure_with_file_sink_does_not_init_graph(self, mock_source, tmp_path):
pipeline = Pipeline.create(
{

View File

@@ -8,7 +8,7 @@ from confluent_kafka.schema_registry.schema_registry_client import (
)
from datahub.ingestion.source.confluent_schema_registry import ConfluentSchemaRegistry
from datahub.ingestion.source.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig, KafkaSourceReport
class ConfluentSchemaRegistryTest(unittest.TestCase):

View File

@@ -23,7 +23,7 @@ from datahub.emitter.mce_builder import (
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.kafka import KafkaSource, KafkaSourceConfig
from datahub.ingestion.source.kafka.kafka import KafkaSource, KafkaSourceConfig
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.schema_classes import (
BrowsePathsClass,
@@ -38,11 +38,13 @@ from datahub.metadata.schema_classes import (
@pytest.fixture
def mock_admin_client():
with patch("datahub.ingestion.source.kafka.AdminClient", autospec=True) as mock:
with patch(
"datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True
) as mock:
yield mock
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_configuration(mock_kafka):
ctx = PipelineContext(run_id="test")
kafka_source = KafkaSource(
@@ -53,7 +55,7 @@ def test_kafka_source_configuration(mock_kafka):
assert mock_kafka.call_count == 1
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
mock_cluster_metadata = MagicMock()
@@ -74,7 +76,7 @@ def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client):
assert len(workunits) == 4
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_topic_pattern(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
mock_cluster_metadata = MagicMock()
@@ -108,7 +110,7 @@ def test_kafka_source_workunits_topic_pattern(mock_kafka, mock_admin_client):
assert len(workunits) == 4
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_client):
PLATFORM_INSTANCE = "kafka_cluster"
PLATFORM = "kafka"
@@ -160,7 +162,7 @@ def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_client):
assert f"/prod/{PLATFORM}/{PLATFORM_INSTANCE}" in browse_path_aspects[0].paths
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_no_platform_instance(mock_kafka, mock_admin_client):
PLATFORM = "kafka"
TOPIC_NAME = "test"
@@ -204,7 +206,7 @@ def test_kafka_source_workunits_no_platform_instance(mock_kafka, mock_admin_client):
assert f"/prod/{PLATFORM}" in browse_path_aspects[0].paths
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_close(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
ctx = PipelineContext(run_id="test")
@@ -223,7 +225,7 @@ def test_close(mock_kafka, mock_admin_client):
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_schema_registry_subject_name_strategies(
mock_kafka_consumer, mock_schema_registry_client, mock_admin_client
):
@@ -415,7 +417,7 @@ def test_kafka_source_workunits_schema_registry_subject_name_strategies(
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_ignore_warnings_on_schema_type(
mock_kafka_consumer,
mock_schema_registry_client,
@@ -483,8 +485,8 @@ def test_kafka_ignore_warnings_on_schema_type(
assert kafka_source.report.warnings
@patch("datahub.ingestion.source.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_succeeds_with_admin_client_init_error(
mock_kafka, mock_kafka_admin_client
):
@@ -513,8 +515,8 @@ def test_kafka_source_succeeds_with_admin_client_init_error(
assert len(workunits) == 2
@patch("datahub.ingestion.source.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_succeeds_with_describe_configs_error(
mock_kafka, mock_kafka_admin_client
):
@@ -550,7 +552,7 @@ def test_kafka_source_succeeds_with_describe_configs_error(
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_topic_meta_mappings(
mock_kafka_consumer, mock_schema_registry_client, mock_admin_client
):