OpenMetadata/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_parser.py
2025-10-11 22:25:43 +02:00

592 lines
21 KiB
Python

# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for Databricks Kafka parser
"""
import unittest
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
extract_dlt_table_names,
extract_kafka_sources,
get_pipeline_libraries,
)
class TestKafkaParser(unittest.TestCase):
"""Test cases for Kafka configuration parsing"""
def test_basic_kafka_readstream(self):
"""Test basic Kafka readStream pattern"""
source_code = """
df = spark.readStream \\
.format("kafka") \\
.option("kafka.bootstrap.servers", "broker1:9092") \\
.option("subscribe", "events_topic") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].bootstrap_servers, "broker1:9092")
self.assertEqual(configs[0].topics, ["events_topic"])
def test_multiple_topics(self):
"""Test comma-separated topics"""
source_code = """
spark.readStream.format("kafka") \\
.option("subscribe", "topic1,topic2,topic3") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])
def test_topics_option(self):
"""Test 'topics' option instead of 'subscribe'"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("topics", "single_topic") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["single_topic"])
def test_group_id_prefix(self):
"""Test groupIdPrefix extraction"""
source_code = """
spark.readStream.format("kafka") \\
.option("kafka.bootstrap.servers", "localhost:9092") \\
.option("subscribe", "test_topic") \\
.option("groupIdPrefix", "dlt-pipeline-123") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].group_id_prefix, "dlt-pipeline-123")
def test_multiple_kafka_sources(self):
"""Test multiple Kafka sources in same file"""
source_code = """
# First stream
df1 = spark.readStream.format("kafka") \\
.option("subscribe", "topic_a") \\
.load()
# Second stream
df2 = spark.readStream.format("kafka") \\
.option("subscribe", "topic_b") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 2)
topics = [c.topics[0] for c in configs]
self.assertIn("topic_a", topics)
self.assertIn("topic_b", topics)
def test_single_quotes(self):
"""Test single quotes in options"""
source_code = """
df = spark.readStream.format('kafka') \\
.option('kafka.bootstrap.servers', 'broker:9092') \\
.option('subscribe', 'my_topic') \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].bootstrap_servers, "broker:9092")
self.assertEqual(configs[0].topics, ["my_topic"])
def test_mixed_quotes(self):
"""Test mixed single and double quotes"""
source_code = """
df = spark.readStream.format("kafka") \\
.option('subscribe', "topic_mixed") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic_mixed"])
def test_compact_format(self):
"""Test compact single-line format"""
source_code = """
df = spark.readStream.format("kafka").option("subscribe", "compact_topic").load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["compact_topic"])
def test_no_kafka_sources(self):
"""Test code with no Kafka sources"""
source_code = """
df = spark.read.parquet("/data/path")
df.write.format("delta").save("/output")
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 0)
def test_partial_kafka_config(self):
"""Test Kafka source with only topics (no brokers)"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", "topic_only") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertIsNone(configs[0].bootstrap_servers)
self.assertEqual(configs[0].topics, ["topic_only"])
def test_malformed_kafka_incomplete(self):
"""Test incomplete Kafka configuration doesn't crash"""
source_code = """
df = spark.readStream.format("kafka")
# No .load() - malformed
"""
configs = extract_kafka_sources(source_code)
# Should return empty list, not crash
self.assertEqual(len(configs), 0)
def test_special_characters_in_topic(self):
"""Test topics with special characters"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", "topic-with-dashes_and_underscores.dots") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic-with-dashes_and_underscores.dots"])
def test_whitespace_variations(self):
"""Test various whitespace patterns"""
source_code = """
df=spark.readStream.format( "kafka" ).option( "subscribe" , "topic" ).load( )
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic"])
def test_case_insensitive_format(self):
"""Test case insensitive Kafka format"""
source_code = """
df = spark.readStream.format("KAFKA") \\
.option("subscribe", "topic_upper") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic_upper"])
def test_dlt_decorator_pattern(self):
"""Test DLT table decorator pattern"""
source_code = """
import dlt
@dlt.table
def bronze_events():
return spark.readStream \\
.format("kafka") \\
.option("kafka.bootstrap.servers", "kafka:9092") \\
.option("subscribe", "raw_events") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["raw_events"])
def test_multiline_with_comments(self):
"""Test code with inline comments"""
source_code = """
df = (spark.readStream
.format("kafka") # Using Kafka source
.option("kafka.bootstrap.servers", "broker:9092") # Broker config
.option("subscribe", "commented_topic") # Topic name
.load()) # Load the stream
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["commented_topic"])
def test_empty_source_code(self):
"""Test empty source code"""
configs = extract_kafka_sources("")
self.assertEqual(len(configs), 0)
def test_null_source_code(self):
"""Test None source code doesn't crash"""
configs = extract_kafka_sources(None)
self.assertEqual(len(configs), 0)
def test_topics_with_whitespace(self):
"""Test topics with surrounding whitespace are trimmed"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", " topic1 , topic2 , topic3 ") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])
def test_variable_topic_reference(self):
"""Test Kafka config with variable reference for topic"""
source_code = """
TOPIC = "events_topic"
df = spark.readStream.format("kafka") \\
.option("subscribe", TOPIC) \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["events_topic"])
def test_real_world_dlt_pattern(self):
"""Test real-world DLT pattern with variables"""
source_code = """
import dlt
from pyspark.sql.functions import *
TOPIC = "tracker-events"
KAFKA_BROKER = spark.conf.get("KAFKA_SERVER")
raw_kafka_events = (spark.readStream
.format("kafka")
.option("subscribe", TOPIC)
.option("kafka.bootstrap.servers", KAFKA_BROKER)
.option("startingOffsets", "earliest")
.load()
)
@dlt.table(table_properties={"pipelines.reset.allowed":"false"})
def kafka_bronze():
return raw_kafka_events
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["tracker-events"])
def test_multiple_variable_topics(self):
"""Test multiple topics defined as variables"""
source_code = """
TOPIC_A = "orders"
TOPIC_B = "payments"
df = spark.readStream.format("kafka") \\
.option("subscribe", TOPIC_A) \\
.load()
df2 = spark.readStream.format("kafka") \\
.option("topics", TOPIC_B) \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 2)
topics = [c.topics[0] for c in configs]
self.assertIn("orders", topics)
self.assertIn("payments", topics)
def test_variable_not_defined(self):
"""Test variable reference without definition"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", UNDEFINED_TOPIC) \\
.load()
"""
configs = extract_kafka_sources(source_code)
# Should still find Kafka source but with empty topics
self.assertEqual(len(configs), 0)
class TestPipelineLibraries(unittest.TestCase):
"""Test cases for pipeline library extraction"""
def test_notebook_library(self):
"""Test notebook library extraction"""
pipeline_config = {
"libraries": [{"notebook": {"path": "/Workspace/dlt/bronze_pipeline"}}]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 1)
self.assertEqual(libraries[0], "/Workspace/dlt/bronze_pipeline")
def test_file_library(self):
"""Test file library extraction"""
pipeline_config = {
"libraries": [{"file": {"path": "/Workspace/scripts/etl.py"}}]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 1)
self.assertEqual(libraries[0], "/Workspace/scripts/etl.py")
def test_mixed_libraries(self):
"""Test mixed notebook and file libraries"""
pipeline_config = {
"libraries": [
{"notebook": {"path": "/nb1"}},
{"file": {"path": "/file1.py"}},
{"notebook": {"path": "/nb2"}},
]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 3)
self.assertIn("/nb1", libraries)
self.assertIn("/file1.py", libraries)
self.assertIn("/nb2", libraries)
def test_empty_libraries(self):
"""Test empty libraries list"""
pipeline_config = {"libraries": []}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 0)
def test_missing_libraries_key(self):
"""Test missing libraries key"""
pipeline_config = {}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 0)
def test_library_with_no_path(self):
"""Test library entry with no path"""
pipeline_config = {"libraries": [{"notebook": {}}, {"file": {}}]}
libraries = get_pipeline_libraries(pipeline_config)
# Should skip entries without paths
self.assertEqual(len(libraries), 0)
def test_unsupported_library_type(self):
"""Test unsupported library types are skipped"""
pipeline_config = {
"libraries": [
{"jar": {"path": "/lib.jar"}}, # Not supported
{"notebook": {"path": "/nb"}}, # Supported
{"whl": {"path": "/wheel.whl"}}, # Not supported
]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 1)
self.assertEqual(libraries[0], "/nb")
class TestDLTTableExtraction(unittest.TestCase):
"""Test cases for DLT table name extraction"""
def test_literal_table_name(self):
"""Test DLT table with literal string name"""
source_code = """
import dlt
@dlt.table(name="user_events_bronze")
def bronze_events():
return spark.readStream.format("kafka").load()
"""
table_names = extract_dlt_table_names(source_code)
self.assertEqual(len(table_names), 1)
self.assertEqual(table_names[0], "user_events_bronze")
def test_table_name_with_comment(self):
"""Test DLT table decorator with comment parameter"""
source_code = """
@dlt.table(
name="my_table",
comment="This is a test table"
)
def my_function():
return df
"""
table_names = extract_dlt_table_names(source_code)
self.assertEqual(len(table_names), 1)
self.assertEqual(table_names[0], "my_table")
def test_multiple_dlt_tables(self):
"""Test multiple DLT table decorators"""
source_code = """
@dlt.table(name="bronze_table")
def bronze():
return spark.readStream.format("kafka").load()
@dlt.table(name="silver_table")
def silver():
return dlt.read("bronze_table")
"""
table_names = extract_dlt_table_names(source_code)
self.assertEqual(len(table_names), 2)
self.assertIn("bronze_table", table_names)
self.assertIn("silver_table", table_names)
def test_function_name_pattern(self):
"""Test DLT table with function call for name"""
source_code = """
entity_name = "moneyRequest"
@dlt.table(name=materializer.generate_event_log_table_name())
def event_log():
return df
"""
table_names = extract_dlt_table_names(source_code)
# Should infer from entity_name variable
self.assertEqual(len(table_names), 1)
self.assertEqual(table_names[0], "moneyRequest")
def test_no_name_parameter(self):
"""Test DLT table decorator without name parameter"""
source_code = """
@dlt.table(comment="No name specified")
def my_function():
return df
"""
table_names = extract_dlt_table_names(source_code)
# Should return empty list when no name found
self.assertEqual(len(table_names), 0)
def test_mixed_case_decorator(self):
"""Test case insensitive DLT decorator"""
source_code = """
@DLT.TABLE(name="CasedTable")
def func():
return df
"""
table_names = extract_dlt_table_names(source_code)
self.assertEqual(len(table_names), 1)
self.assertEqual(table_names[0], "CasedTable")
def test_empty_source_code(self):
"""Test empty source code"""
table_names = extract_dlt_table_names("")
self.assertEqual(len(table_names), 0)
def test_null_source_code(self):
"""Test None source code doesn't crash"""
table_names = extract_dlt_table_names(None)
self.assertEqual(len(table_names), 0)
class TestKafkaFallbackPatterns(unittest.TestCase):
"""Test cases for Kafka fallback extraction patterns"""
def test_topic_variable_fallback(self):
"""Test fallback to topic_name variable when no explicit Kafka pattern"""
source_code = """
topic_name = "dev.ern.cashout.moneyRequest_v1"
entity_name = "moneyRequest"
# Kafka reading is abstracted in helper class
materializer = KafkaMaterializer(topic_name, entity_name)
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["dev.ern.cashout.moneyRequest_v1"])
self.assertIsNone(configs[0].bootstrap_servers)
def test_subject_variable_fallback(self):
"""Test fallback to subject_name variable"""
source_code = """
subject_name = "user-events"
# Using helper class
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["user-events"])
def test_stream_variable_fallback(self):
"""Test fallback to stream variable"""
source_code = """
stream_topic = "payment-stream"
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["payment-stream"])
def test_topic_with_dots(self):
"""Test topic names with dots (namespace pattern)"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", "pre-prod.earnin.customer-experience.messages") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(
configs[0].topics, ["pre-prod.earnin.customer-experience.messages"]
)
def test_multiple_topic_variables(self):
"""Test multiple topic variables"""
source_code = """
topic_name = "events_v1"
stream_topic = "metrics_v1"
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
# Should find both topics
self.assertEqual(len(configs[0].topics), 2)
self.assertIn("events_v1", configs[0].topics)
self.assertIn("metrics_v1", configs[0].topics)
def test_no_fallback_when_explicit_kafka(self):
"""Test that fallback is not used when explicit Kafka pattern exists"""
source_code = """
topic_name = "fallback_topic"
df = spark.readStream.format("kafka") \\
.option("subscribe", "explicit_topic") \\
.load()
"""
configs = extract_kafka_sources(source_code)
# Should find the explicit Kafka source, not the fallback
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["explicit_topic"])
def test_lowercase_variables(self):
"""Test lowercase variable names are captured"""
source_code = """
topic_name = "lowercase_topic"
TOPIC_NAME = "uppercase_topic"
"""
configs = extract_kafka_sources(source_code)
# Should find both
self.assertEqual(len(configs), 1)
self.assertEqual(len(configs[0].topics), 2)
def test_real_world_abstracted_pattern(self):
"""Test real-world pattern with abstracted Kafka reading"""
source_code = """
import dlt
from shared.materializers import KafkaMaterializer
topic_name = "dev.ern.cashout.moneyRequest_v1"
entity_name = "moneyRequest"
materializer = KafkaMaterializer(
topic_name=topic_name,
entity_name=entity_name,
spark=spark
)
@dlt.table(name=materializer.generate_event_log_table_name())
def event_log():
return materializer.read_stream()
"""
# Test Kafka extraction
kafka_configs = extract_kafka_sources(source_code)
self.assertEqual(len(kafka_configs), 1)
self.assertEqual(kafka_configs[0].topics, ["dev.ern.cashout.moneyRequest_v1"])
# Test DLT table extraction
table_names = extract_dlt_table_names(source_code)
self.assertEqual(len(table_names), 1)
self.assertEqual(table_names[0], "moneyRequest")
if __name__ == "__main__":
unittest.main()