mirror of https://github.com/open-metadata/OpenMetadata.git
synced 2025-10-31 02:29:03 +00:00
commit 3c527ca83b: MINOR: Fix Databricks DLT Pipeline Lineage to Track Table * fix tests * add support for s3 pipeline lineage as well
883 lines | 32 KiB | Python

#  Copyright 2025 Collate
#  Licensed under the Collate Community License, Version 1.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Unit tests for Databricks Kafka parser
"""

import unittest

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_dependencies,
    extract_dlt_table_names,
    extract_kafka_sources,
    get_pipeline_libraries,
)


class TestKafkaParser(unittest.TestCase):
    """Test cases for Kafka configuration parsing"""

    def test_basic_kafka_readstream(self):
        """Test basic Kafka readStream pattern"""
        source_code = """
        df = spark.readStream \\
            .format("kafka") \\
            .option("kafka.bootstrap.servers", "broker1:9092") \\
            .option("subscribe", "events_topic") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].bootstrap_servers, "broker1:9092")
        self.assertEqual(configs[0].topics, ["events_topic"])

    def test_multiple_topics(self):
        """Test comma-separated topics"""
        source_code = """
        spark.readStream.format("kafka") \\
            .option("subscribe", "topic1,topic2,topic3") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])

    def test_topics_option(self):
        """Test 'topics' option instead of 'subscribe'"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("topics", "single_topic") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["single_topic"])

    def test_group_id_prefix(self):
        """Test groupIdPrefix extraction"""
        source_code = """
        spark.readStream.format("kafka") \\
            .option("kafka.bootstrap.servers", "localhost:9092") \\
            .option("subscribe", "test_topic") \\
            .option("groupIdPrefix", "dlt-pipeline-123") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].group_id_prefix, "dlt-pipeline-123")

    def test_multiple_kafka_sources(self):
        """Test multiple Kafka sources in same file"""
        source_code = """
        # First stream
        df1 = spark.readStream.format("kafka") \\
            .option("subscribe", "topic_a") \\
            .load()

        # Second stream
        df2 = spark.readStream.format("kafka") \\
            .option("subscribe", "topic_b") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 2)
        topics = [c.topics[0] for c in configs]
        self.assertIn("topic_a", topics)
        self.assertIn("topic_b", topics)

    def test_single_quotes(self):
        """Test single quotes in options"""
        source_code = """
        df = spark.readStream.format('kafka') \\
            .option('kafka.bootstrap.servers', 'broker:9092') \\
            .option('subscribe', 'my_topic') \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].bootstrap_servers, "broker:9092")
        self.assertEqual(configs[0].topics, ["my_topic"])

    def test_mixed_quotes(self):
        """Test mixed single and double quotes"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option('subscribe', "topic_mixed") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic_mixed"])

    def test_compact_format(self):
        """Test compact single-line format"""
        source_code = """
        df = spark.readStream.format("kafka").option("subscribe", "compact_topic").load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["compact_topic"])

    def test_no_kafka_sources(self):
        """Test code with no Kafka sources"""
        source_code = """
        df = spark.read.parquet("/data/path")
        df.write.format("delta").save("/output")
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 0)

    def test_partial_kafka_config(self):
        """Test Kafka source with only topics (no brokers)"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", "topic_only") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertIsNone(configs[0].bootstrap_servers)
        self.assertEqual(configs[0].topics, ["topic_only"])

    def test_malformed_kafka_incomplete(self):
        """Test incomplete Kafka configuration doesn't crash"""
        source_code = """
        df = spark.readStream.format("kafka")
        # No .load() - malformed
        """
        configs = extract_kafka_sources(source_code)
        # Should return empty list, not crash
        self.assertEqual(len(configs), 0)

    def test_special_characters_in_topic(self):
        """Test topics with special characters"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", "topic-with-dashes_and_underscores.dots") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic-with-dashes_and_underscores.dots"])

    def test_whitespace_variations(self):
        """Test various whitespace patterns"""
        source_code = """
        df=spark.readStream.format(  "kafka"  ).option(  "subscribe"  ,  "topic"  ).load(  )
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic"])

    def test_case_insensitive_format(self):
        """Test case insensitive Kafka format"""
        source_code = """
        df = spark.readStream.format("KAFKA") \\
            .option("subscribe", "topic_upper") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic_upper"])

    def test_dlt_decorator_pattern(self):
        """Test DLT table decorator pattern"""
        source_code = """
        import dlt

        @dlt.table
        def bronze_events():
            return spark.readStream \\
                .format("kafka") \\
                .option("kafka.bootstrap.servers", "kafka:9092") \\
                .option("subscribe", "raw_events") \\
                .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["raw_events"])

    def test_multiline_with_comments(self):
        """Test code with inline comments"""
        source_code = """
        df = (spark.readStream
            .format("kafka")  # Using Kafka source
            .option("kafka.bootstrap.servers", "broker:9092")  # Broker config
            .option("subscribe", "commented_topic")  # Topic name
            .load())  # Load the stream
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["commented_topic"])

    def test_empty_source_code(self):
        """Test empty source code"""
        configs = extract_kafka_sources("")
        self.assertEqual(len(configs), 0)

    def test_null_source_code(self):
        """Test None source code doesn't crash"""
        configs = extract_kafka_sources(None)
        self.assertEqual(len(configs), 0)

    def test_topics_with_whitespace(self):
        """Test topics with surrounding whitespace are trimmed"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", " topic1 , topic2 , topic3 ") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])

    def test_variable_topic_reference(self):
        """Test Kafka config with variable reference for topic"""
        source_code = """
        TOPIC = "events_topic"
        df = spark.readStream.format("kafka") \\
            .option("subscribe", TOPIC) \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["events_topic"])

    def test_real_world_dlt_pattern(self):
        """Test real-world DLT pattern with variables"""
        source_code = """
        import dlt
        from pyspark.sql.functions import *

        TOPIC = "tracker-events"
        KAFKA_BROKER = spark.conf.get("KAFKA_SERVER")

        raw_kafka_events = (spark.readStream
            .format("kafka")
            .option("subscribe", TOPIC)
            .option("kafka.bootstrap.servers", KAFKA_BROKER)
            .option("startingOffsets", "earliest")
            .load()
        )

        @dlt.table(table_properties={"pipelines.reset.allowed":"false"})
        def kafka_bronze():
            return raw_kafka_events
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["tracker-events"])

    def test_multiple_variable_topics(self):
        """Test multiple topics defined as variables"""
        source_code = """
        TOPIC_A = "orders"
        TOPIC_B = "payments"

        df = spark.readStream.format("kafka") \\
            .option("subscribe", TOPIC_A) \\
            .load()

        df2 = spark.readStream.format("kafka") \\
            .option("topics", TOPIC_B) \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 2)
        topics = [c.topics[0] for c in configs]
        self.assertIn("orders", topics)
        self.assertIn("payments", topics)

    def test_variable_not_defined(self):
        """Test variable reference without definition"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", UNDEFINED_TOPIC) \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        # Topic variable is never defined, so no usable Kafka source is returned
        self.assertEqual(len(configs), 0)


class TestPipelineLibraries(unittest.TestCase):
    """Test cases for pipeline library extraction"""

    def test_notebook_library(self):
        """Test notebook library extraction"""
        pipeline_config = {
            "libraries": [{"notebook": {"path": "/Workspace/dlt/bronze_pipeline"}}]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 1)
        self.assertEqual(libraries[0], "/Workspace/dlt/bronze_pipeline")

    def test_file_library(self):
        """Test file library extraction"""
        pipeline_config = {
            "libraries": [{"file": {"path": "/Workspace/scripts/etl.py"}}]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 1)
        self.assertEqual(libraries[0], "/Workspace/scripts/etl.py")

    def test_mixed_libraries(self):
        """Test mixed notebook and file libraries"""
        pipeline_config = {
            "libraries": [
                {"notebook": {"path": "/nb1"}},
                {"file": {"path": "/file1.py"}},
                {"notebook": {"path": "/nb2"}},
            ]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 3)
        self.assertIn("/nb1", libraries)
        self.assertIn("/file1.py", libraries)
        self.assertIn("/nb2", libraries)

    def test_empty_libraries(self):
        """Test empty libraries list"""
        pipeline_config = {"libraries": []}
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 0)

    def test_missing_libraries_key(self):
        """Test missing libraries key"""
        pipeline_config = {}
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 0)

    def test_library_with_no_path(self):
        """Test library entry with no path"""
        pipeline_config = {"libraries": [{"notebook": {}}, {"file": {}}]}
        libraries = get_pipeline_libraries(pipeline_config)
        # Should skip entries without paths
        self.assertEqual(len(libraries), 0)

    def test_unsupported_library_type(self):
        """Test unsupported library types are skipped"""
        pipeline_config = {
            "libraries": [
                {"jar": {"path": "/lib.jar"}},  # Not supported
                {"notebook": {"path": "/nb"}},  # Supported
                {"whl": {"path": "/wheel.whl"}},  # Not supported
            ]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 1)
        self.assertEqual(libraries[0], "/nb")


class TestDLTTableExtraction(unittest.TestCase):
    """Test cases for DLT table name extraction"""

    def test_literal_table_name(self):
        """Test DLT table with literal string name"""
        source_code = """
        import dlt

        @dlt.table(name="user_events_bronze")
        def bronze_events():
            return spark.readStream.format("kafka").load()
        """
        table_names = extract_dlt_table_names(source_code)
        self.assertEqual(len(table_names), 1)
        self.assertEqual(table_names[0], "user_events_bronze")

    def test_table_name_with_comment(self):
        """Test DLT table decorator with comment parameter"""
        source_code = """
        @dlt.table(
            name="my_table",
            comment="This is a test table"
        )
        def my_function():
            return df
        """
        table_names = extract_dlt_table_names(source_code)
        self.assertEqual(len(table_names), 1)
        self.assertEqual(table_names[0], "my_table")

    def test_multiple_dlt_tables(self):
        """Test multiple DLT table decorators"""
        source_code = """
        @dlt.table(name="bronze_table")
        def bronze():
            return spark.readStream.format("kafka").load()

        @dlt.table(name="silver_table")
        def silver():
            return dlt.read("bronze_table")
        """
        table_names = extract_dlt_table_names(source_code)
        self.assertEqual(len(table_names), 2)
        self.assertIn("bronze_table", table_names)
        self.assertIn("silver_table", table_names)

    def test_function_name_pattern(self):
        """Test DLT table with function call for name"""
        source_code = """
        entity_name = "moneyRequest"

        @dlt.table(name=materializer.generate_event_log_table_name())
        def event_log():
            return df
        """
        table_names = extract_dlt_table_names(source_code)
        # Should infer from entity_name variable
        self.assertEqual(len(table_names), 1)
        self.assertEqual(table_names[0], "moneyRequest")

    def test_no_name_parameter(self):
        """Test DLT table decorator without name parameter"""
        source_code = """
        @dlt.table(comment="No name specified")
        def my_function():
            return df
        """
        table_names = extract_dlt_table_names(source_code)
        # Should return empty list when no name found
        self.assertEqual(len(table_names), 0)

    def test_mixed_case_decorator(self):
        """Test case insensitive DLT decorator"""
        source_code = """
        @DLT.TABLE(name="CasedTable")
        def func():
            return df
        """
        table_names = extract_dlt_table_names(source_code)
        self.assertEqual(len(table_names), 1)
        self.assertEqual(table_names[0], "CasedTable")

    def test_empty_source_code(self):
        """Test empty source code"""
        table_names = extract_dlt_table_names("")
        self.assertEqual(len(table_names), 0)

    def test_null_source_code(self):
        """Test None source code doesn't crash"""
        table_names = extract_dlt_table_names(None)
        self.assertEqual(len(table_names), 0)


class TestKafkaFallbackPatterns(unittest.TestCase):
    """Test cases for Kafka fallback extraction patterns"""

    def test_topic_variable_fallback(self):
        """Test fallback to topic_name variable when no explicit Kafka pattern"""
        source_code = """
        topic_name = "dev.ern.cashout.moneyRequest_v1"
        entity_name = "moneyRequest"

        # Kafka reading is abstracted in helper class
        materializer = KafkaMaterializer(topic_name, entity_name)
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["dev.ern.cashout.moneyRequest_v1"])
        self.assertIsNone(configs[0].bootstrap_servers)

    def test_subject_variable_fallback(self):
        """Test fallback to subject_name variable"""
        source_code = """
        subject_name = "user-events"
        # Using helper class
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["user-events"])

    def test_stream_variable_fallback(self):
        """Test fallback to stream variable"""
        source_code = """
        stream_topic = "payment-stream"
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["payment-stream"])

    def test_topic_with_dots(self):
        """Test topic names with dots (namespace pattern)"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", "pre-prod.earnin.customer-experience.messages") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(
            configs[0].topics, ["pre-prod.earnin.customer-experience.messages"]
        )

    def test_multiple_topic_variables(self):
        """Test multiple topic variables"""
        source_code = """
        topic_name = "events_v1"
        stream_topic = "metrics_v1"
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        # Should find both topics
        self.assertEqual(len(configs[0].topics), 2)
        self.assertIn("events_v1", configs[0].topics)
        self.assertIn("metrics_v1", configs[0].topics)

    def test_no_fallback_when_explicit_kafka(self):
        """Test that fallback is not used when explicit Kafka pattern exists"""
        source_code = """
        topic_name = "fallback_topic"

        df = spark.readStream.format("kafka") \\
            .option("subscribe", "explicit_topic") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        # Should find the explicit Kafka source, not the fallback
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["explicit_topic"])

    def test_lowercase_variables(self):
        """Test lowercase variable names are captured"""
        source_code = """
        topic_name = "lowercase_topic"
        TOPIC_NAME = "uppercase_topic"
        """
        configs = extract_kafka_sources(source_code)
        # Should find both
        self.assertEqual(len(configs), 1)
        self.assertEqual(len(configs[0].topics), 2)

    def test_real_world_abstracted_pattern(self):
        """Test real-world pattern with abstracted Kafka reading"""
        source_code = """
        import dlt
        from shared.materializers import KafkaMaterializer

        topic_name = "dev.ern.cashout.moneyRequest_v1"
        entity_name = "moneyRequest"

        materializer = KafkaMaterializer(
            topic_name=topic_name,
            entity_name=entity_name,
            spark=spark
        )

        @dlt.table(name=materializer.generate_event_log_table_name())
        def event_log():
            return materializer.read_stream()
        """
        # Test Kafka extraction
        kafka_configs = extract_kafka_sources(source_code)
        self.assertEqual(len(kafka_configs), 1)
        self.assertEqual(kafka_configs[0].topics, ["dev.ern.cashout.moneyRequest_v1"])

        # Test DLT table extraction
        table_names = extract_dlt_table_names(source_code)
        self.assertEqual(len(table_names), 1)
        self.assertEqual(table_names[0], "moneyRequest")


class TestDLTTableDependencies(unittest.TestCase):
    """Test cases for DLT table dependency extraction"""

    def test_bronze_silver_pattern(self):
        """Test table dependency pattern with Kafka source and downstream table"""
        source_code = """
        import dlt
        from pyspark.sql.functions import *

        @dlt.table(name="orders_bronze")
        def bronze():
            return spark.readStream.format("kafka").option("subscribe", "orders").load()

        @dlt.table(name="orders_silver")
        def silver():
            return dlt.read_stream("orders_bronze").select("*")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 2)

        bronze = next(d for d in deps if d.table_name == "orders_bronze")
        self.assertTrue(bronze.reads_from_kafka)
        self.assertEqual(bronze.depends_on, [])

        silver = next(d for d in deps if d.table_name == "orders_silver")
        self.assertFalse(silver.reads_from_kafka)
        self.assertEqual(silver.depends_on, ["orders_bronze"])

    def test_kafka_view_bronze_silver(self):
        """Test Kafka view with multi-tier table dependencies"""
        source_code = """
        import dlt

        @dlt.view()
        def kafka_source():
            return spark.readStream.format("kafka").option("subscribe", "topic").load()

        @dlt.table(name="bronze_table")
        def bronze():
            return dlt.read_stream("kafka_source")

        @dlt.table(name="silver_table")
        def silver():
            return dlt.read_stream("bronze_table")
        """
        deps = extract_dlt_table_dependencies(source_code)

        bronze = next((d for d in deps if d.table_name == "bronze_table"), None)
        self.assertIsNotNone(bronze)
        self.assertEqual(bronze.depends_on, ["kafka_source"])
        self.assertFalse(bronze.reads_from_kafka)

        silver = next((d for d in deps if d.table_name == "silver_table"), None)
        self.assertIsNotNone(silver)
        self.assertEqual(silver.depends_on, ["bronze_table"])
        self.assertFalse(silver.reads_from_kafka)

    def test_multiple_dependencies(self):
        """Test table with multiple source dependencies"""
        source_code = """
        @dlt.table(name="source1")
        def s1():
            return spark.readStream.format("kafka").load()

        @dlt.table(name="source2")
        def s2():
            return spark.readStream.format("kafka").load()

        @dlt.table(name="merged")
        def merge():
            df1 = dlt.read_stream("source1")
            df2 = dlt.read_stream("source2")
            return df1.union(df2)
        """
        deps = extract_dlt_table_dependencies(source_code)

        merged = next((d for d in deps if d.table_name == "merged"), None)
        self.assertIsNotNone(merged)
        self.assertEqual(sorted(merged.depends_on), ["source1", "source2"])
        self.assertFalse(merged.reads_from_kafka)

    def test_no_dependencies(self):
        """Test table with no dependencies (reads from file)"""
        source_code = """
        @dlt.table(name="static_data")
        def static():
            return spark.read.parquet("/path/to/data")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertEqual(deps[0].table_name, "static_data")
        self.assertEqual(deps[0].depends_on, [])
        self.assertFalse(deps[0].reads_from_kafka)

    def test_function_name_as_table_name(self):
        """Test using function name when no explicit name in decorator"""
        source_code = """
        @dlt.table()
        def my_bronze_table():
            return spark.readStream.format("kafka").load()
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertEqual(deps[0].table_name, "my_bronze_table")
        self.assertTrue(deps[0].reads_from_kafka)

    def test_empty_source_code(self):
        """Test empty source code returns empty list"""
        deps = extract_dlt_table_dependencies("")
        self.assertEqual(len(deps), 0)

    def test_real_world_pattern(self):
        """Test the exact pattern from user's example"""
        source_code = """
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *

TOPIC = "orders"

@dlt.view(comment="Kafka source")
def kafka_orders_source():
    return (
        spark.readStream
        .format("kafka")
        .option("subscribe", TOPIC)
        .load()
    )

@dlt.table(name="orders_bronze")
def orders_bronze():
    return dlt.read_stream("kafka_orders_source").select("*")

@dlt.table(name="orders_silver")
def orders_silver():
    return dlt.read_stream("orders_bronze").select("*")
        """
        deps = extract_dlt_table_dependencies(source_code)

        # Should find bronze and silver (kafka_orders_source is a view, not a table)
        table_deps = [
            d for d in deps if d.table_name in ["orders_bronze", "orders_silver"]
        ]
        self.assertEqual(len(table_deps), 2)

        bronze = next((d for d in table_deps if d.table_name == "orders_bronze"), None)
        self.assertIsNotNone(bronze)
        self.assertEqual(bronze.depends_on, ["kafka_orders_source"])
        self.assertFalse(bronze.reads_from_kafka)

        silver = next((d for d in table_deps if d.table_name == "orders_silver"), None)
        self.assertIsNotNone(silver)
        self.assertEqual(silver.depends_on, ["orders_bronze"])
        self.assertFalse(silver.reads_from_kafka)


class TestS3SourceDetection(unittest.TestCase):
    """Test cases for S3 source detection"""

    def test_s3_json_source(self):
        """Test detecting S3 source with spark.read.json()"""
        source_code = """
        @dlt.view()
        def s3_source():
            return spark.read.json("s3://mybucket/data/")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertTrue(deps[0].reads_from_s3)
        self.assertEqual(deps[0].s3_locations, ["s3://mybucket/data/"])

    def test_s3_parquet_source(self):
        """Test detecting S3 source with spark.read.parquet()"""
        source_code = """
        @dlt.table(name="parquet_table")
        def load_parquet():
            return spark.read.parquet("s3a://bucket/path/file.parquet")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertTrue(deps[0].reads_from_s3)
        self.assertIn("s3a://bucket/path/file.parquet", deps[0].s3_locations)

    def test_s3_with_options(self):
        """Test S3 source with options"""
        source_code = """
        @dlt.view()
        def external_source():
            return (
                spark.read
                    .option("multiline", "true")
                    .json("s3://test-firehose-con-bucket/firehose_data/")
            )
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertTrue(deps[0].reads_from_s3)
        self.assertEqual(
            deps[0].s3_locations, ["s3://test-firehose-con-bucket/firehose_data/"]
        )

    def test_s3_format_load(self):
        """Test S3 with format().load() pattern"""
        source_code = """
        @dlt.table(name="csv_data")
        def load_csv():
            return spark.read.format("csv").option("header", "true").load("s3://bucket/data.csv")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertTrue(deps[0].reads_from_s3)
        self.assertIn("s3://bucket/data.csv", deps[0].s3_locations)

    def test_dlt_read_batch(self):
        """Test dlt.read() for batch table dependencies"""
        source_code = """
        @dlt.table(name="bronze")
        def bronze_table():
            return dlt.read("source_view")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 1)
        self.assertEqual(deps[0].depends_on, ["source_view"])

    def test_mixed_dlt_read_and_read_stream(self):
        """Test both dlt.read() and dlt.read_stream() in same pipeline"""
        source_code = """
        @dlt.table(name="batch_table")
        def batch():
            return dlt.read("source1")

        @dlt.table(name="stream_table")
        def stream():
            return dlt.read_stream("source2")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 2)

        batch = next(d for d in deps if d.table_name == "batch_table")
        self.assertEqual(batch.depends_on, ["source1"])

        stream = next(d for d in deps if d.table_name == "stream_table")
        self.assertEqual(stream.depends_on, ["source2"])

    def test_user_s3_example(self):
        """Test the user's exact S3 example"""
        source_code = """
        import dlt
        from pyspark.sql.functions import col

        @dlt.view(comment="External source data from S3")
        def external_source():
            return (
                spark.read
                    .option("multiline", "true")
                    .json("s3://test-firehose-con-bucket/firehose_data/")
            )

        @dlt.table(name="bronze_firehose_data")
        def bronze_firehose_data():
            return dlt.read("external_source")

        @dlt.table(name="silver_firehose_data")
        def silver_firehose_data():
            return dlt.read("bronze_firehose_data")
        """
        deps = extract_dlt_table_dependencies(source_code)
        self.assertEqual(len(deps), 3)

        # Verify external_source
        external = next((d for d in deps if d.table_name == "external_source"), None)
        self.assertIsNotNone(external)
        self.assertTrue(external.reads_from_s3)
        self.assertIn(
            "s3://test-firehose-con-bucket/firehose_data/", external.s3_locations
        )

        # Verify bronze
        bronze = next((d for d in deps if d.table_name == "bronze_firehose_data"), None)
        self.assertIsNotNone(bronze)
        self.assertEqual(bronze.depends_on, ["external_source"])
        self.assertFalse(bronze.reads_from_s3)

        # Verify silver
        silver = next((d for d in deps if d.table_name == "silver_firehose_data"), None)
        self.assertIsNotNone(silver)
        self.assertEqual(silver.depends_on, ["bronze_firehose_data"])
        self.assertFalse(silver.reads_from_s3)


if __name__ == "__main__":
    unittest.main()