Allow/deny patterns for kafka source

This commit is contained in:
Harshal Sheth 2021-02-11 22:48:20 -08:00 committed by Shirshanka Das
parent df3e3da45b
commit d483d23fd7
3 changed files with 43 additions and 23 deletions

View File

@ -3,6 +3,11 @@ source:
type: "kafka"
kafka:
connection.bootstrap: "localhost:9092"
topic_patterns:
allow:
- ".*"
deny:
- "^_.+" # deny all topics that start with an underscore
sink:
type: "datahub-kafka"

View File

@ -2,11 +2,10 @@ import logging
from gometa.configuration import ConfigModel
from gometa.configuration.kafka import KafkaConsumerConnectionConfig
from gometa.ingestion.api.source import Source, SourceReport
from typing import Iterable, List, Dict, Any
from typing import Iterable, List, Dict
from dataclasses import dataclass, field
import confluent_kafka
from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient
import re
from gometa.ingestion.source.metadata_common import MetadataWorkUnit
import time
@ -20,6 +19,7 @@ from gometa.metadata.com.linkedin.pegasus2avro.schema import (
KafkaSchema,
SchemaField,
)
from gometa.configuration.common import AllowDenyPattern
from gometa.metadata.com.linkedin.pegasus2avro.common import AuditStamp, Status
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class KafkaSourceConfig(ConfigModel):
connection: KafkaConsumerConnectionConfig = KafkaConsumerConnectionConfig()
topic: str = ".*" # default is wildcard subscription
topic_patterns: AllowDenyPattern = AllowDenyPattern(allow=[".*"], deny=["^_.*"])
@dataclass
@ -57,18 +57,16 @@ class KafkaSourceReport(SourceReport):
@dataclass
class KafkaSource(Source):
source_config: KafkaSourceConfig
topic_pattern: Any # actually re.Pattern
consumer: confluent_kafka.Consumer
report: KafkaSourceReport
def __init__(self, config: KafkaSourceConfig, ctx: PipelineContext):
super().__init__(ctx)
self.source_config = config
self.topic_pattern = re.compile(self.source_config.topic)
self.consumer = confluent_kafka.Consumer(
{
'group.id': 'test',
'bootstrap.servers': self.source_config.connection.bootstrap,
"group.id": "test",
"bootstrap.servers": self.source_config.connection.bootstrap,
**self.source_config.connection.consumer_config,
}
)
@ -87,10 +85,9 @@ class KafkaSource(Source):
for t in topics:
self.report.report_topic_scanned(t)
# TODO: topics config should support allow and deny patterns
if re.fullmatch(self.topic_pattern, t) and not t.startswith("_"):
if self.source_config.topic_patterns.allowed(t):
mce = self._extract_record(t)
wu = MetadataWorkUnit(id=f'kafka-{t}', mce=mce)
wu = MetadataWorkUnit(id=f"kafka-{t}", mce=mce)
self.report.report_workunit(wu)
yield wu
else:
@ -123,7 +120,7 @@ class KafkaSource(Source):
# Parse the schema
fields: List[SchemaField] = []
if has_schema and schema.schema_type == 'AVRO':
if has_schema and schema.schema_type == "AVRO":
fields = schema_util.avro_schema_to_mce_fields(schema.schema_str)
elif has_schema:
self.report.report_warning(

View File

@ -10,8 +10,11 @@ from unittest.mock import patch, MagicMock
class KafkaSourceTest(unittest.TestCase):
@patch("gometa.ingestion.source.kafka.confluent_kafka.Consumer")
def test_kafka_source_configuration(self, mock_kafka):
ctx = PipelineContext(run_id='test')
kafka_source = KafkaSource.create({'connection': {'bootstrap': 'foobar:9092'}}, ctx)
ctx = PipelineContext(run_id="test")
kafka_source = KafkaSource.create(
{"connection": {"bootstrap": "foobar:9092"}}, ctx
)
kafka_source.close()
assert mock_kafka.call_count == 1
@patch("gometa.ingestion.source.kafka.confluent_kafka.Consumer")
@ -21,13 +24,15 @@ class KafkaSourceTest(unittest.TestCase):
mock_cluster_metadata.topics = ["foobar", "bazbaz"]
mock_kafka_instance.list_topics.return_value = mock_cluster_metadata
ctx = PipelineContext(run_id='test')
kafka_source = KafkaSource.create({'connection': {'bootstrap': 'localhost:9092'}}, ctx)
ctx = PipelineContext(run_id="test")
kafka_source = KafkaSource.create(
{"connection": {"bootstrap": "localhost:9092"}}, ctx
)
workunits = []
for w in kafka_source.get_workunits():
workunits.append(w)
first_mce = workunits[0].get_metadata()['mce']
first_mce = workunits[0].get_metadata()["mce"]
assert isinstance(first_mce, MetadataChangeEvent)
mock_kafka.assert_called_once()
mock_kafka_instance.list_topics.assert_called_once()
@ -40,9 +45,14 @@ class KafkaSourceTest(unittest.TestCase):
mock_cluster_metadata.topics = ["test", "foobar", "bazbaz"]
mock_kafka_instance.list_topics.return_value = mock_cluster_metadata
ctx = PipelineContext(run_id='test1')
kafka_source = KafkaSource.create({'topic': 'test', 'connection': {'bootstrap': 'localhost:9092'}}, ctx)
assert kafka_source.source_config.topic == "test"
ctx = PipelineContext(run_id="test1")
kafka_source = KafkaSource.create(
{
"topic_patterns": {"allow": ["test"]},
"connection": {"bootstrap": "localhost:9092"},
},
ctx,
)
workunits = [w for w in kafka_source.get_workunits()]
mock_kafka.assert_called_once()
@ -50,15 +60,23 @@ class KafkaSourceTest(unittest.TestCase):
assert len(workunits) == 1
mock_cluster_metadata.topics = ["test", "test2", "bazbaz"]
ctx = PipelineContext(run_id='test2')
kafka_source = KafkaSource.create({'topic': 'test.*', 'connection': {'bootstrap': 'localhost:9092'}}, ctx)
ctx = PipelineContext(run_id="test2")
kafka_source = KafkaSource.create(
{
"topic_patterns": {"allow": ["test.*"]},
"connection": {"bootstrap": "localhost:9092"},
},
ctx,
)
workunits = [w for w in kafka_source.get_workunits()]
assert len(workunits) == 2
@patch("gometa.ingestion.source.kafka.confluent_kafka.Consumer")
def test_close(self, mock_kafka):
mock_kafka_instance = mock_kafka.return_value
ctx = PipelineContext(run_id='test')
kafka_source = KafkaSource.create({'topic': 'test', 'connection': {'bootstrap': 'localhost:9092'}}, ctx)
ctx = PipelineContext(run_id="test")
kafka_source = KafkaSource.create(
{"topic": "test", "connection": {"bootstrap": "localhost:9092"}}, ctx
)
kafka_source.close()
assert mock_kafka_instance.close.call_count == 1