From 3c527ca83b6f09d977fab3ffda64e40b764aa22e Mon Sep 17 00:00:00 2001 From: Mayur Singal <39544459+ulixius9@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:24:01 +0530 Subject: [PATCH] MINOR: Fix Databricks DLT Pipeline Lineage to Track Table (#23888) * MINOR: Fix Databricks DLT Pipeline Lineage to Track Table * fix tests * add support for s3 pipeline lineage as well --- .../databrickspipeline/kafka_parser.py | 154 +++++++++ .../pipeline/databrickspipeline/metadata.py | 323 ++++++++++++++---- .../pipeline/test_databricks_kafka_lineage.py | 6 +- .../pipeline/test_databricks_kafka_parser.py | 291 ++++++++++++++++ 4 files changed, 702 insertions(+), 72 deletions(-) diff --git a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/kafka_parser.py b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/kafka_parser.py index a8eb92fc114..fc0c59b44dc 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/kafka_parser.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/kafka_parser.py @@ -50,6 +50,26 @@ DLT_TABLE_NAME_FUNCTION = re.compile( re.DOTALL | re.IGNORECASE, ) +# Pattern to extract dlt.read_stream("table_name") calls +DLT_READ_STREAM_PATTERN = re.compile( + r'dlt\.read_stream\s*\(\s*["\']([^"\']+)["\']\s*\)', + re.IGNORECASE, +) + +# Pattern to extract dlt.read("table_name") calls (batch reads) +DLT_READ_PATTERN = re.compile( + r'dlt\.read\s*\(\s*["\']([^"\']+)["\']\s*\)', + re.IGNORECASE, +) + +# Pattern to extract S3 paths from spark.read operations +# Matches: spark.read.json("s3://..."), spark.read.format("parquet").load("s3a://...") +# Uses a simpler pattern that captures any spark.read followed by method calls ending with a path +S3_PATH_PATTERN = re.compile( + r'spark\.read.*?\.(?:load|json|parquet|csv|orc|avro)\s*\(\s*["\']([^"\']+)["\']\s*\)', + re.DOTALL | re.IGNORECASE, +) + @dataclass class KafkaSourceConfig: @@ -60,6 +80,17 @@ class KafkaSourceConfig: group_id_prefix: Optional[str] = None +@dataclass +class DLTTableDependency: + """Model for DLT table dependencies""" + + table_name: str + depends_on: List[str] = field(default_factory=list) + reads_from_kafka: bool = False + reads_from_s3: bool = False + s3_locations: List[str] = field(default_factory=list) + + def _extract_variables(source_code: str) -> dict: """ Extract variable assignments from source code @@ -316,6 +347,129 @@ def _infer_table_name_from_function( return None +def extract_dlt_table_dependencies(source_code: str) -> List[DLTTableDependency]: + """ + Extract DLT table dependencies by analyzing @dlt.table decorators and dlt.read_stream calls + + For each DLT table, identifies: + - Table name from @dlt.table(name="...") + - Dependencies from dlt.read_stream("other_table") or dlt.read("other_table") calls + - Whether it reads from Kafka (spark.readStream.format("kafka")) + - Whether it reads from S3 (spark.read.json("s3://...")) + - S3 locations if applicable + + Example: + @dlt.table(name="source_table") + def my_source(): + return spark.read.json("s3://bucket/path/")... 
+ + @dlt.table(name="target_table") + def my_target(): + return dlt.read("source_table") + + Returns: + [ + DLTTableDependency(table_name="source_table", depends_on=[], reads_from_s3=True, + s3_locations=["s3://bucket/path/"]), + DLTTableDependency(table_name="target_table", depends_on=["source_table"], + reads_from_s3=False) + ] + """ + dependencies = [] + + try: + if not source_code: + return dependencies + + # Split source code into function definitions + # Pattern: @dlt.table(...) or @dlt.view(...) followed by def function_name(): + # Handle multiline decorators with potentially nested parentheses + function_pattern = re.compile( + r"(@dlt\.(?:table|view)\s*\(.*?\)\s*def\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^)]*\)\s*:.*?)(?=@dlt\.|$)", + re.DOTALL | re.IGNORECASE, + ) + + for match in function_pattern.finditer(source_code): + try: + function_block = match.group(1) + + # Extract table name from @dlt.table decorator + table_name = None + name_match = DLT_TABLE_NAME_LITERAL.search(function_block) + if name_match and name_match.group(1): + table_name = name_match.group(1) + else: + # Try function name pattern + func_name_match = DLT_TABLE_NAME_FUNCTION.search(function_block) + if func_name_match and func_name_match.group(1): + table_name = _infer_table_name_from_function( + func_name_match.group(1), source_code + ) + + if not table_name: + # Try to extract from function definition itself + def_match = re.search( + r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", function_block + ) + if def_match: + table_name = def_match.group(1) + + if not table_name: + logger.debug( + f"Could not extract table name from block: {function_block[:100]}..." + ) + continue + + # Check if it reads from Kafka + reads_from_kafka = bool(KAFKA_STREAM_PATTERN.search(function_block)) + + # Check if it reads from S3 + s3_locations = [] + for s3_match in S3_PATH_PATTERN.finditer(function_block): + s3_path = s3_match.group(1) + if s3_path.startswith(("s3://", "s3a://", "s3n://")): + s3_locations.append(s3_path) + logger.debug(f"Table {table_name} reads from S3: {s3_path}") + + reads_from_s3 = len(s3_locations) > 0 + + # Extract dlt.read_stream dependencies (streaming) + depends_on = [] + for stream_match in DLT_READ_STREAM_PATTERN.finditer(function_block): + source_table = stream_match.group(1) + depends_on.append(source_table) + logger.debug(f"Table {table_name} streams from {source_table}") + + # Extract dlt.read dependencies (batch) + for read_match in DLT_READ_PATTERN.finditer(function_block): + source_table = read_match.group(1) + depends_on.append(source_table) + logger.debug(f"Table {table_name} reads from {source_table}") + + dependency = DLTTableDependency( + table_name=table_name, + depends_on=depends_on, + reads_from_kafka=reads_from_kafka, + reads_from_s3=reads_from_s3, + s3_locations=s3_locations, + ) + dependencies.append(dependency) + logger.debug( + f"Extracted dependency: {table_name} - depends_on={depends_on}, " + f"reads_from_kafka={reads_from_kafka}, reads_from_s3={reads_from_s3}, " + f"s3_locations={s3_locations}" + ) + + except Exception as exc: + logger.debug(f"Error parsing function block: {exc}") + continue + + except Exception as exc: + logger.warning(f"Error extracting DLT table dependencies: {exc}") + + return dependencies + + def get_pipeline_libraries(pipeline_config: dict, client=None) -> List[str]: """ Extract notebook and file paths from pipeline configuration diff --git a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py 
b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py index 0aaedf7d35d..5ff0f379223 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py @@ -58,7 +58,7 @@ from metadata.ingestion.lineage.sql_lineage import get_column_fqn from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import ( - extract_dlt_table_names, + extract_dlt_table_dependencies, extract_kafka_sources, ) from metadata.ingestion.source.pipeline.databrickspipeline.models import ( @@ -820,107 +820,293 @@ class DatabrickspipelineSource(PipelineServiceSource): else: logger.info(f"⊗ No Kafka sources found in notebook") - # Extract DLT table names - logger.info(f"⟳ Parsing DLT table names from notebook...") - dlt_table_names = extract_dlt_table_names(source_code) - if dlt_table_names: + # Extract DLT table dependencies + logger.info(f"⟳ Parsing DLT table dependencies from notebook...") + dlt_dependencies = extract_dlt_table_dependencies(source_code) + if dlt_dependencies: logger.info( - f"✓ Found {len(dlt_table_names)} DLT table(s): {dlt_table_names}" + f"✓ Found {len(dlt_dependencies)} DLT table(s) with dependencies" ) + for dep in dlt_dependencies: + s3_info = ( + f", reads_from_s3={dep.reads_from_s3}, s3_locations={dep.s3_locations}" + if dep.reads_from_s3 + else "" + ) + logger.info( + f" - {dep.table_name}: depends_on={dep.depends_on}, " + f"reads_from_kafka={dep.reads_from_kafka}{s3_info}" + ) else: - logger.info(f"⊗ No DLT tables found in notebook") + logger.info(f"⊗ No DLT table dependencies found in notebook") - if not dlt_table_names or not kafka_sources: + # Check if we have anything to process + has_kafka = kafka_sources and len(kafka_sources) > 0 + has_s3 = any(dep.reads_from_s3 for dep in dlt_dependencies) + has_tables = dlt_dependencies and len(dlt_dependencies) > 0 + + if not dlt_dependencies: logger.warning( - f"⊗ Skipping lineage for this notebook - need both Kafka sources AND DLT tables" - ) - logger.info( - f" Kafka sources: {len(kafka_sources) if kafka_sources else 0}" - ) - logger.info( - f" DLT tables: {len(dlt_table_names) if dlt_table_names else 0}" + f"⊗ Skipping lineage for this notebook - no DLT tables found" ) continue - logger.info( - f"✓ Notebook has both Kafka sources and DLT tables - creating lineage..." 
- ) + if not has_kafka and not has_s3: + logger.info( + f"⊗ No external sources (Kafka or S3) found in this notebook - only table-to-table lineage will be created" + ) - # Create lineage for each Kafka topic -> DLT table + logger.info(f"✓ Notebook has DLT tables - creating lineage...") + if has_kafka: + logger.info(f" Kafka sources: {len(kafka_sources)}") + if has_s3: + s3_count = sum( + len(dep.s3_locations) + for dep in dlt_dependencies + if dep.reads_from_s3 + ) + logger.info(f" S3 sources: {s3_count} location(s)") + + # Create lineage edges based on dependencies logger.info(f"\n⟳ Creating lineage edges...") lineage_created = 0 + # Step 1: Create Kafka topic -> DLT table lineage + # Build a map to identify which tables/views read from external sources (Kafka/S3) + external_sources_map = {} + for dep in dlt_dependencies: + if dep.reads_from_kafka or dep.reads_from_s3: + external_sources_map[dep.table_name] = True + for kafka_config in kafka_sources: for topic_name in kafka_config.topics: try: - logger.info(f"\n 🔍 Processing topic: {topic_name}") - - # Use smart discovery to find topic logger.info( - f" ⟳ Searching for topic in OpenMetadata..." + f"\n 🔍 Processing Kafka topic: {topic_name}" ) - kafka_topic = self._find_kafka_topic(topic_name) + kafka_topic = self._find_kafka_topic(topic_name) if not kafka_topic: logger.warning( f" ✗ Kafka topic '{topic_name}' not found in OpenMetadata" ) - logger.info( - f" 💡 Make sure the topic is ingested from a messaging service" - ) continue logger.info( f" ✓ Topic found: {kafka_topic.fullyQualifiedName.root if hasattr(kafka_topic.fullyQualifiedName, 'root') else kafka_topic.fullyQualifiedName}" ) - # Create lineage to each DLT table in this notebook - for table_name in dlt_table_names: - logger.info( - f" 🔍 Processing target table: {table_name}" - ) - logger.info( - f" ⟳ Searching in Databricks/Unity Catalog services..." 
+ # Find tables that read directly from Kafka OR from an external source view + for dep in dlt_dependencies: + is_kafka_consumer = dep.reads_from_kafka or any( + src in external_sources_map + for src in dep.depends_on ) - # Use cached Databricks service lookup - target_table_entity = self._find_dlt_table( - table_name=table_name, + if is_kafka_consumer: + logger.info( + f" 🔍 Processing table: {dep.table_name}" + ) + + target_table = self._find_dlt_table( + table_name=dep.table_name, + catalog=target_catalog, + schema=target_schema, + ) + + if target_table: + table_fqn = ( + target_table.fullyQualifiedName.root + if hasattr( + target_table.fullyQualifiedName, + "root", + ) + else target_table.fullyQualifiedName + ) + logger.info( + f" ✅ Creating lineage: {topic_name} -> {table_fqn}" + ) + + yield Either( + right=AddLineageRequest( + edge=EntitiesEdge( + fromEntity=EntityReference( + id=kafka_topic.id, + type="topic", + ), + toEntity=EntityReference( + id=target_table.id.root + if hasattr( + target_table.id, "root" + ) + else target_table.id, + type="table", + ), + lineageDetails=LineageDetails( + pipeline=EntityReference( + id=pipeline_entity.id.root, + type="pipeline", + ), + source=LineageSource.PipelineLineage, + ), + ) + ) + ) + lineage_created += 1 + else: + logger.warning( + f" ✗ Table '{dep.table_name}' not found" + ) + + except Exception as exc: + logger.error( + f" ✗ Failed to process topic {topic_name}: {exc}" + ) + logger.debug(traceback.format_exc()) + continue + + # Step 2: Create table-to-table lineage for downstream dependencies + for dep in dlt_dependencies: + if dep.depends_on: + for source_table_name in dep.depends_on: + try: + logger.info( + f"\n 🔍 Processing table dependency: {source_table_name} -> {dep.table_name}" + ) + + # Check if source is a view/table that reads from S3 + source_dep = next( + ( + d + for d in dlt_dependencies + if d.table_name == source_table_name + ), + None, + ) + + # If source reads from S3, create container → table lineage + if ( + source_dep + and source_dep.reads_from_s3 + and source_dep.s3_locations + ): + target_table = self._find_dlt_table( + table_name=dep.table_name, + catalog=target_catalog, + schema=target_schema, + ) + + if target_table: + for s3_location in source_dep.s3_locations: + logger.info( + f" 🔍 Looking for S3 container: {s3_location}" + ) + # Search for container by S3 path + storage_location = s3_location.rstrip( + "/" + ) + container_entity = self.metadata.es_search_container_by_path( + full_path=storage_location + ) + + if ( + container_entity + and container_entity[0] + ): + logger.info( + f" ✅ Creating lineage: {container_entity[0].fullyQualifiedName.root if hasattr(container_entity[0].fullyQualifiedName, 'root') else container_entity[0].fullyQualifiedName} -> {target_table.fullyQualifiedName.root if hasattr(target_table.fullyQualifiedName, 'root') else target_table.fullyQualifiedName}" + ) + + yield Either( + right=AddLineageRequest( + edge=EntitiesEdge( + fromEntity=EntityReference( + id=container_entity[ + 0 + ].id, + type="container", + ), + toEntity=EntityReference( + id=target_table.id.root + if hasattr( + target_table.id, + "root", + ) + else target_table.id, + type="table", + ), + lineageDetails=LineageDetails( + pipeline=EntityReference( + id=pipeline_entity.id.root, + type="pipeline", + ), + source=LineageSource.PipelineLineage, + ), + ) + ) + ) + lineage_created += 1 + else: + logger.warning( + f" ✗ S3 container not found for path: {storage_location}" + ) + logger.info( + f" Make sure the S3 container is 
ingested in OpenMetadata" + ) + else: + logger.warning( + f" ✗ Target table '{dep.table_name}' not found" + ) + continue + + # Otherwise, create table → table lineage + source_table = self._find_dlt_table( + table_name=source_table_name, + catalog=target_catalog, + schema=target_schema, + ) + target_table = self._find_dlt_table( + table_name=dep.table_name, catalog=target_catalog, schema=target_schema, ) - if target_table_entity: - table_fqn = ( - target_table_entity.fullyQualifiedName.root + if source_table and target_table: + source_fqn = ( + source_table.fullyQualifiedName.root if hasattr( - target_table_entity.fullyQualifiedName, - "root", + source_table.fullyQualifiedName, "root" ) - else target_table_entity.fullyQualifiedName + else source_table.fullyQualifiedName + ) + target_fqn = ( + target_table.fullyQualifiedName.root + if hasattr( + target_table.fullyQualifiedName, "root" + ) + else target_table.fullyQualifiedName ) logger.info( - f" ✓ Target table found: {table_fqn}" + f" ✅ Creating lineage: {source_fqn} -> {target_fqn}" ) - logger.info( - f" ✅ Creating lineage: {topic_name} -> {table_fqn}" - ) - logger.info(f" Pipeline: {pipeline_id}") yield Either( right=AddLineageRequest( edge=EntitiesEdge( fromEntity=EntityReference( - id=kafka_topic.id, - type="topic", + id=source_table.id.root + if hasattr( + source_table.id, "root" + ) + else source_table.id, + type="table", ), toEntity=EntityReference( - id=target_table_entity.id.root + id=target_table.id.root if hasattr( - target_table_entity.id, - "root", + target_table.id, "root" ) - else target_table_entity.id, + else target_table.id, type="table", ), lineageDetails=LineageDetails( @@ -935,22 +1121,21 @@ class DatabrickspipelineSource(PipelineServiceSource): ) lineage_created += 1 else: - logger.warning( - f" ✗ Target table '{table_name}' not found in OpenMetadata" - ) - logger.info( - f" 💡 Expected location: {target_catalog}.{target_schema}.{table_name}" - ) - logger.info( - f" 💡 Make sure the table is ingested from a Databricks/Unity Catalog database service" - ) + if not source_table: + logger.warning( + f" ✗ Source table '{source_table_name}' not found" + ) + if not target_table: + logger.warning( + f" ✗ Target table '{dep.table_name}' not found" + ) - except Exception as exc: - logger.error( - f" ✗ Failed to process topic {topic_name}: {exc}" - ) - logger.debug(traceback.format_exc()) - continue + except Exception as exc: + logger.error( + f" ✗ Failed to process dependency {source_table_name} -> {dep.table_name}: {exc}" + ) + logger.debug(traceback.format_exc()) + continue logger.info( f"\n✓ Lineage edges created for this notebook: {lineage_created}" diff --git a/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_lineage.py b/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_lineage.py index 09d32cd89bb..7b9f4f0ba4c 100644 --- a/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_lineage.py +++ b/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_lineage.py @@ -350,16 +350,16 @@ class TestKafkaLineageIntegration(unittest.TestCase): } self.mock_client.get_pipeline_details.return_value = pipeline_config - # Mock notebook source code + # Mock notebook source code - realistic DLT pattern with Kafka notebook_source = """ import dlt topic_name = "dev.ern.cashout.moneyRequest_v1" entity_name = "moneyRequest" - @dlt.table(name=materializer.generate_event_log_table_name()) + @dlt.table(name="moneyRequest") def event_log(): - return df + return spark.readStream.format("kafka").option("subscribe", 
topic_name).load() """ self.mock_client.export_notebook_source.return_value = notebook_source diff --git a/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_parser.py b/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_parser.py index 290b29b5428..4f3e6a30790 100644 --- a/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_parser.py +++ b/ingestion/tests/unit/topology/pipeline/test_databricks_kafka_parser.py @@ -16,6 +16,7 @@ Unit tests for Databricks Kafka parser import unittest from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import ( + extract_dlt_table_dependencies, extract_dlt_table_names, extract_kafka_sources, get_pipeline_libraries, @@ -587,5 +588,295 @@ class TestKafkaFallbackPatterns(unittest.TestCase): self.assertEqual(table_names[0], "moneyRequest") +class TestDLTTableDependencies(unittest.TestCase): + """Test cases for DLT table dependency extraction""" + + def test_bronze_silver_pattern(self): + """Test table dependency pattern with Kafka source and downstream table""" + source_code = """ + import dlt + from pyspark.sql.functions import * + + @dlt.table(name="orders_bronze") + def bronze(): + return spark.readStream.format("kafka").option("subscribe", "orders").load() + + @dlt.table(name="orders_silver") + def silver(): + return dlt.read_stream("orders_bronze").select("*") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 2) + + bronze = next(d for d in deps if d.table_name == "orders_bronze") + self.assertTrue(bronze.reads_from_kafka) + self.assertEqual(bronze.depends_on, []) + + silver = next(d for d in deps if d.table_name == "orders_silver") + self.assertFalse(silver.reads_from_kafka) + self.assertEqual(silver.depends_on, ["orders_bronze"]) + + def test_kafka_view_bronze_silver(self): + """Test Kafka view with multi-tier table dependencies""" + source_code = """ + import dlt + + @dlt.view() + def kafka_source(): + return spark.readStream.format("kafka").option("subscribe", "topic").load() + + @dlt.table(name="bronze_table") + def bronze(): + return dlt.read_stream("kafka_source") + + @dlt.table(name="silver_table") + def silver(): + return dlt.read_stream("bronze_table") + """ + deps = extract_dlt_table_dependencies(source_code) + + bronze = next((d for d in deps if d.table_name == "bronze_table"), None) + self.assertIsNotNone(bronze) + self.assertEqual(bronze.depends_on, ["kafka_source"]) + self.assertFalse(bronze.reads_from_kafka) + + silver = next((d for d in deps if d.table_name == "silver_table"), None) + self.assertIsNotNone(silver) + self.assertEqual(silver.depends_on, ["bronze_table"]) + self.assertFalse(silver.reads_from_kafka) + + def test_multiple_dependencies(self): + """Test table with multiple source dependencies""" + source_code = """ + @dlt.table(name="source1") + def s1(): + return spark.readStream.format("kafka").load() + + @dlt.table(name="source2") + def s2(): + return spark.readStream.format("kafka").load() + + @dlt.table(name="merged") + def merge(): + df1 = dlt.read_stream("source1") + df2 = dlt.read_stream("source2") + return df1.union(df2) + """ + deps = extract_dlt_table_dependencies(source_code) + + merged = next((d for d in deps if d.table_name == "merged"), None) + self.assertIsNotNone(merged) + self.assertEqual(sorted(merged.depends_on), ["source1", "source2"]) + self.assertFalse(merged.reads_from_kafka) + + def test_no_dependencies(self): + """Test table with no dependencies (reads from file)""" + source_code = """ + @dlt.table(name="static_data") + 
def static(): + return spark.read.parquet("/path/to/data") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0].table_name, "static_data") + self.assertEqual(deps[0].depends_on, []) + self.assertFalse(deps[0].reads_from_kafka) + + def test_function_name_as_table_name(self): + """Test using function name when no explicit name in decorator""" + source_code = """ + @dlt.table() + def my_bronze_table(): + return spark.readStream.format("kafka").load() + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0].table_name, "my_bronze_table") + self.assertTrue(deps[0].reads_from_kafka) + + def test_empty_source_code(self): + """Test empty source code returns empty list""" + deps = extract_dlt_table_dependencies("") + self.assertEqual(len(deps), 0) + + def test_real_world_pattern(self): + """Test the exact pattern from user's example""" + source_code = """ +import dlt +from pyspark.sql.functions import * +from pyspark.sql.types import * + +TOPIC = "orders" + +@dlt.view(comment="Kafka source") +def kafka_orders_source(): + return ( + spark.readStream + .format("kafka") + .option("subscribe", TOPIC) + .load() + ) + +@dlt.table(name="orders_bronze") +def orders_bronze(): + return dlt.read_stream("kafka_orders_source").select("*") + +@dlt.table(name="orders_silver") +def orders_silver(): + return dlt.read_stream("orders_bronze").select("*") + """ + deps = extract_dlt_table_dependencies(source_code) + + # Should find bronze and silver (kafka_orders_source is a view, not a table) + table_deps = [ + d for d in deps if d.table_name in ["orders_bronze", "orders_silver"] + ] + self.assertEqual(len(table_deps), 2) + + bronze = next((d for d in table_deps if d.table_name == "orders_bronze"), None) + self.assertIsNotNone(bronze) + self.assertEqual(bronze.depends_on, ["kafka_orders_source"]) + self.assertFalse(bronze.reads_from_kafka) + + silver = next((d for d in table_deps if d.table_name == "orders_silver"), None) + self.assertIsNotNone(silver) + self.assertEqual(silver.depends_on, ["orders_bronze"]) + self.assertFalse(silver.reads_from_kafka) + + +class TestS3SourceDetection(unittest.TestCase): + """Test cases for S3 source detection""" + + def test_s3_json_source(self): + """Test detecting S3 source with spark.read.json()""" + source_code = """ + @dlt.view() + def s3_source(): + return spark.read.json("s3://mybucket/data/") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertTrue(deps[0].reads_from_s3) + self.assertEqual(deps[0].s3_locations, ["s3://mybucket/data/"]) + + def test_s3_parquet_source(self): + """Test detecting S3 source with spark.read.parquet()""" + source_code = """ + @dlt.table(name="parquet_table") + def load_parquet(): + return spark.read.parquet("s3a://bucket/path/file.parquet") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertTrue(deps[0].reads_from_s3) + self.assertIn("s3a://bucket/path/file.parquet", deps[0].s3_locations) + + def test_s3_with_options(self): + """Test S3 source with options""" + source_code = """ + @dlt.view() + def external_source(): + return ( + spark.read + .option("multiline", "true") + .json("s3://test-firehose-con-bucket/firehose_data/") + ) + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertTrue(deps[0].reads_from_s3) + self.assertEqual( + deps[0].s3_locations, 
["s3://test-firehose-con-bucket/firehose_data/"] + ) + + def test_s3_format_load(self): + """Test S3 with format().load() pattern""" + source_code = """ + @dlt.table(name="csv_data") + def load_csv(): + return spark.read.format("csv").option("header", "true").load("s3://bucket/data.csv") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertTrue(deps[0].reads_from_s3) + self.assertIn("s3://bucket/data.csv", deps[0].s3_locations) + + def test_dlt_read_batch(self): + """Test dlt.read() for batch table dependencies""" + source_code = """ + @dlt.table(name="bronze") + def bronze_table(): + return dlt.read("source_view") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0].depends_on, ["source_view"]) + + def test_mixed_dlt_read_and_read_stream(self): + """Test both dlt.read() and dlt.read_stream() in same pipeline""" + source_code = """ + @dlt.table(name="batch_table") + def batch(): + return dlt.read("source1") + + @dlt.table(name="stream_table") + def stream(): + return dlt.read_stream("source2") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 2) + + batch = next(d for d in deps if d.table_name == "batch_table") + self.assertEqual(batch.depends_on, ["source1"]) + + stream = next(d for d in deps if d.table_name == "stream_table") + self.assertEqual(stream.depends_on, ["source2"]) + + def test_user_s3_example(self): + """Test the user's exact S3 example""" + source_code = """ + import dlt + from pyspark.sql.functions import col + + @dlt.view(comment="External source data from S3") + def external_source(): + return ( + spark.read + .option("multiline", "true") + .json("s3://test-firehose-con-bucket/firehose_data/") + ) + + @dlt.table(name="bronze_firehose_data") + def bronze_firehose_data(): + return dlt.read("external_source") + + @dlt.table(name="silver_firehose_data") + def silver_firehose_data(): + return dlt.read("bronze_firehose_data") + """ + deps = extract_dlt_table_dependencies(source_code) + self.assertEqual(len(deps), 3) + + # Verify external_source + external = next((d for d in deps if d.table_name == "external_source"), None) + self.assertIsNotNone(external) + self.assertTrue(external.reads_from_s3) + self.assertIn( + "s3://test-firehose-con-bucket/firehose_data/", external.s3_locations + ) + + # Verify bronze + bronze = next((d for d in deps if d.table_name == "bronze_firehose_data"), None) + self.assertIsNotNone(bronze) + self.assertEqual(bronze.depends_on, ["external_source"]) + self.assertFalse(bronze.reads_from_s3) + + # Verify silver + silver = next((d for d in deps if d.table_name == "silver_firehose_data"), None) + self.assertIsNotNone(silver) + self.assertEqual(silver.depends_on, ["bronze_firehose_data"]) + self.assertFalse(silver.reads_from_s3) + + if __name__ == "__main__": unittest.main()