MINOR: Fix Databricks DLT Pipeline Lineage to Track Table (#23888)

* MINOR: Fix Databricks DLT Pipeline Lineage to Track Table

* fix tests

* add support for s3 pipeline lineage as well
Mayur Singal 2025-10-15 14:24:01 +05:30 committed by GitHub
parent 14f5d0610d
commit 3c527ca83b
4 changed files with 702 additions and 72 deletions

View File

@@ -50,6 +50,26 @@ DLT_TABLE_NAME_FUNCTION = re.compile(
re.DOTALL | re.IGNORECASE,
)
# Pattern to extract dlt.read_stream("table_name") calls
DLT_READ_STREAM_PATTERN = re.compile(
r'dlt\.read_stream\s*\(\s*["\']([^"\']+)["\']\s*\)',
re.IGNORECASE,
)
# Pattern to extract dlt.read("table_name") calls (batch reads)
DLT_READ_PATTERN = re.compile(
r'dlt\.read\s*\(\s*["\']([^"\']+)["\']\s*\)',
re.IGNORECASE,
)
# Pattern to extract S3 paths from spark.read operations
# Matches: spark.read.json("s3://..."), spark.read.format("parquet").load("s3a://...")
# Uses a simpler pattern that captures any spark.read followed by method calls ending with a path
S3_PATH_PATTERN = re.compile(
r'spark\.read.*?\.(?:load|json|parquet|csv|orc|avro)\s*\(\s*["\']([^"\']+)["\']\s*\)',
re.DOTALL | re.IGNORECASE,
)
@dataclass
class KafkaSourceConfig:
@@ -60,6 +80,17 @@ class KafkaSourceConfig:
group_id_prefix: Optional[str] = None
@dataclass
class DLTTableDependency:
"""Model for DLT table dependencies"""
table_name: str
depends_on: List[str] = field(default_factory=list)
reads_from_kafka: bool = False
reads_from_s3: bool = False
s3_locations: List[str] = field(default_factory=list)
def _extract_variables(source_code: str) -> dict:
"""
Extract variable assignments from source code
@@ -316,6 +347,129 @@ def _infer_table_name_from_function(
return None
def extract_dlt_table_dependencies(source_code: str) -> List[DLTTableDependency]:
"""
Extract DLT table dependencies by analyzing @dlt.table decorators and dlt.read_stream calls
For each DLT table, identifies:
- Table name from @dlt.table(name="...")
- Dependencies from dlt.read_stream("other_table") or dlt.read("other_table") calls
- Whether it reads from Kafka (spark.readStream.format("kafka"))
- Whether it reads from S3 (spark.read.json("s3://..."))
- S3 locations if applicable
Example:
@dlt.table(name="source_table")
def my_source():
return spark.read.json("s3://bucket/path/")...
@dlt.table(name="target_table")
def my_target():
return dlt.read("source_table")
Returns:
[
DLTTableDependency(table_name="source_table", depends_on=[], reads_from_s3=True,
s3_locations=["s3://bucket/path/"]),
DLTTableDependency(table_name="target_table", depends_on=["source_table"],
reads_from_s3=False)
]
"""
dependencies = []
try:
if not source_code:
return dependencies
# Split source code into function definitions
# Pattern: @dlt.table(...) or @dlt.view(...) followed by def function_name():
# Handle multiline decorators with potentially nested parentheses
function_pattern = re.compile(
r"(@dlt\.(?:table|view)\s*\(.*?\)\s*def\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^)]*\)\s*:.*?)(?=@dlt\.|$)",
re.DOTALL | re.IGNORECASE,
)
for match in function_pattern.finditer(source_code):
try:
function_block = match.group(1)
# Extract table name from @dlt.table decorator
table_name = None
name_match = DLT_TABLE_NAME_LITERAL.search(function_block)
if name_match and name_match.group(1):
table_name = name_match.group(1)
else:
# Try function name pattern
func_name_match = DLT_TABLE_NAME_FUNCTION.search(function_block)
if func_name_match and func_name_match.group(1):
table_name = _infer_table_name_from_function(
func_name_match.group(1), source_code
)
if not table_name:
# Try to extract from function definition itself
def_match = re.search(
r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", function_block
)
if def_match:
table_name = def_match.group(1)
if not table_name:
logger.debug(
f"Could not extract table name from block: {function_block[:100]}..."
)
continue
# Check if it reads from Kafka
reads_from_kafka = bool(KAFKA_STREAM_PATTERN.search(function_block))
# Check if it reads from S3
s3_locations = []
for s3_match in S3_PATH_PATTERN.finditer(function_block):
s3_path = s3_match.group(1)
if s3_path.startswith(("s3://", "s3a://", "s3n://")):
s3_locations.append(s3_path)
logger.debug(f"Table {table_name} reads from S3: {s3_path}")
reads_from_s3 = len(s3_locations) > 0
# Extract dlt.read_stream dependencies (streaming)
depends_on = []
for stream_match in DLT_READ_STREAM_PATTERN.finditer(function_block):
source_table = stream_match.group(1)
depends_on.append(source_table)
logger.debug(f"Table {table_name} streams from {source_table}")
# Extract dlt.read dependencies (batch)
for read_match in DLT_READ_PATTERN.finditer(function_block):
source_table = read_match.group(1)
depends_on.append(source_table)
logger.debug(f"Table {table_name} reads from {source_table}")
dependency = DLTTableDependency(
table_name=table_name,
depends_on=depends_on,
reads_from_kafka=reads_from_kafka,
reads_from_s3=reads_from_s3,
s3_locations=s3_locations,
)
dependencies.append(dependency)
logger.debug(
f"Extracted dependency: {table_name} - depends_on={depends_on}, "
f"reads_from_kafka={reads_from_kafka}, reads_from_s3={reads_from_s3}, "
f"s3_locations={s3_locations}"
)
except Exception as exc:
logger.debug(f"Error parsing function block: {exc}")
continue
except Exception as exc:
logger.warning(f"Error extracting DLT table dependencies: {exc}")
return dependencies
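A minimal usage sketch of the new helper; the import path matches the one used in the unit tests below, while the notebook snippet and the expected values in the comments are illustrative, modelled on test_user_s3_example:

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_dependencies,
)

notebook_source = '''
import dlt

@dlt.view(comment="External source data from S3")
def external_source():
    return spark.read.json("s3://my-bucket/landing/")

@dlt.table(name="bronze_events")
def bronze_events():
    return dlt.read("external_source")

@dlt.table(name="silver_events")
def silver_events():
    return dlt.read_stream("bronze_events")
'''

for dep in extract_dlt_table_dependencies(notebook_source):
    print(dep.table_name, dep.depends_on, dep.reads_from_kafka, dep.reads_from_s3, dep.s3_locations)
# Expected, per the dependency model above:
#   external_source [] False True ['s3://my-bucket/landing/']
#   bronze_events ['external_source'] False False []
#   silver_events ['bronze_events'] False False []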
def get_pipeline_libraries(pipeline_config: dict, client=None) -> List[str]:
"""
Extract notebook and file paths from pipeline configuration

View File

@@ -58,7 +58,7 @@ from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
extract_dlt_table_names,
extract_dlt_table_dependencies,
extract_kafka_sources,
)
from metadata.ingestion.source.pipeline.databrickspipeline.models import (
@@ -820,107 +820,293 @@ class DatabrickspipelineSource(PipelineServiceSource):
else:
logger.info(f"⊗ No Kafka sources found in notebook")
# Extract DLT table names
logger.info(f"⟳ Parsing DLT table names from notebook...")
dlt_table_names = extract_dlt_table_names(source_code)
if dlt_table_names:
# Extract DLT table dependencies
logger.info(f"⟳ Parsing DLT table dependencies from notebook...")
dlt_dependencies = extract_dlt_table_dependencies(source_code)
if dlt_dependencies:
logger.info(
f"✓ Found {len(dlt_table_names)} DLT table(s): {dlt_table_names}"
f"✓ Found {len(dlt_dependencies)} DLT table(s) with dependencies"
)
for dep in dlt_dependencies:
s3_info = (
f", reads_from_s3={dep.reads_from_s3}, s3_locations={dep.s3_locations}"
if dep.reads_from_s3
else ""
)
logger.info(
f" - {dep.table_name}: depends_on={dep.depends_on}, "
f"reads_from_kafka={dep.reads_from_kafka}{s3_info}"
)
else:
logger.info(f"⊗ No DLT tables found in notebook")
logger.info(f"⊗ No DLT table dependencies found in notebook")
if not dlt_table_names or not kafka_sources:
# Check if we have anything to process
has_kafka = kafka_sources and len(kafka_sources) > 0
has_s3 = any(dep.reads_from_s3 for dep in dlt_dependencies)
has_tables = dlt_dependencies and len(dlt_dependencies) > 0
if not dlt_dependencies:
logger.warning(
f"⊗ Skipping lineage for this notebook - need both Kafka sources AND DLT tables"
)
logger.info(
f" Kafka sources: {len(kafka_sources) if kafka_sources else 0}"
)
logger.info(
f" DLT tables: {len(dlt_table_names) if dlt_table_names else 0}"
f"⊗ Skipping lineage for this notebook - no DLT tables found"
)
continue
logger.info(
f"✓ Notebook has both Kafka sources and DLT tables - creating lineage..."
)
if not has_kafka and not has_s3:
logger.info(
f"⊗ No external sources (Kafka or S3) found in this notebook - only table-to-table lineage will be created"
)
# Create lineage for each Kafka topic -> DLT table
logger.info(f"✓ Notebook has DLT tables - creating lineage...")
if has_kafka:
logger.info(f" Kafka sources: {len(kafka_sources)}")
if has_s3:
s3_count = sum(
len(dep.s3_locations)
for dep in dlt_dependencies
if dep.reads_from_s3
)
logger.info(f" S3 sources: {s3_count} location(s)")
# Create lineage edges based on dependencies
logger.info(f"\n⟳ Creating lineage edges...")
lineage_created = 0
# Step 1: Create Kafka topic -> DLT table lineage
# Build a map to identify which tables/views read from external sources (Kafka/S3)
external_sources_map = {}
for dep in dlt_dependencies:
if dep.reads_from_kafka or dep.reads_from_s3:
external_sources_map[dep.table_name] = True
for kafka_config in kafka_sources:
for topic_name in kafka_config.topics:
try:
logger.info(f"\n 🔍 Processing topic: {topic_name}")
# Use smart discovery to find topic
logger.info(
f" ⟳ Searching for topic in OpenMetadata..."
f"\n 🔍 Processing Kafka topic: {topic_name}"
)
kafka_topic = self._find_kafka_topic(topic_name)
kafka_topic = self._find_kafka_topic(topic_name)
if not kafka_topic:
logger.warning(
f" ✗ Kafka topic '{topic_name}' not found in OpenMetadata"
)
logger.info(
f" 💡 Make sure the topic is ingested from a messaging service"
)
continue
logger.info(
f" ✓ Topic found: {kafka_topic.fullyQualifiedName.root if hasattr(kafka_topic.fullyQualifiedName, 'root') else kafka_topic.fullyQualifiedName}"
)
# Create lineage to each DLT table in this notebook
for table_name in dlt_table_names:
logger.info(
f" 🔍 Processing target table: {table_name}"
)
logger.info(
f" ⟳ Searching in Databricks/Unity Catalog services..."
# Find tables that read directly from Kafka OR from an external source view
for dep in dlt_dependencies:
is_kafka_consumer = dep.reads_from_kafka or any(
src in external_sources_map
for src in dep.depends_on
)
# Use cached Databricks service lookup
target_table_entity = self._find_dlt_table(
table_name=table_name,
if is_kafka_consumer:
logger.info(
f" 🔍 Processing table: {dep.table_name}"
)
target_table = self._find_dlt_table(
table_name=dep.table_name,
catalog=target_catalog,
schema=target_schema,
)
if target_table:
table_fqn = (
target_table.fullyQualifiedName.root
if hasattr(
target_table.fullyQualifiedName,
"root",
)
else target_table.fullyQualifiedName
)
logger.info(
f" ✅ Creating lineage: {topic_name} -> {table_fqn}"
)
yield Either(
right=AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(
id=kafka_topic.id,
type="topic",
),
toEntity=EntityReference(
id=target_table.id.root
if hasattr(
target_table.id, "root"
)
else target_table.id,
type="table",
),
lineageDetails=LineageDetails(
pipeline=EntityReference(
id=pipeline_entity.id.root,
type="pipeline",
),
source=LineageSource.PipelineLineage,
),
)
)
)
lineage_created += 1
else:
logger.warning(
f" ✗ Table '{dep.table_name}' not found"
)
except Exception as exc:
logger.error(
f" ✗ Failed to process topic {topic_name}: {exc}"
)
logger.debug(traceback.format_exc())
continue
# Step 2: Create table-to-table lineage for downstream dependencies
for dep in dlt_dependencies:
if dep.depends_on:
for source_table_name in dep.depends_on:
try:
logger.info(
f"\n 🔍 Processing table dependency: {source_table_name} -> {dep.table_name}"
)
# Check if source is a view/table that reads from S3
source_dep = next(
(
d
for d in dlt_dependencies
if d.table_name == source_table_name
),
None,
)
# If source reads from S3, create container → table lineage
if (
source_dep
and source_dep.reads_from_s3
and source_dep.s3_locations
):
target_table = self._find_dlt_table(
table_name=dep.table_name,
catalog=target_catalog,
schema=target_schema,
)
if target_table:
for s3_location in source_dep.s3_locations:
logger.info(
f" 🔍 Looking for S3 container: {s3_location}"
)
# Search for container by S3 path
storage_location = s3_location.rstrip(
"/"
)
container_entity = self.metadata.es_search_container_by_path(
full_path=storage_location
)
if (
container_entity
and container_entity[0]
):
logger.info(
f" ✅ Creating lineage: {container_entity[0].fullyQualifiedName.root if hasattr(container_entity[0].fullyQualifiedName, 'root') else container_entity[0].fullyQualifiedName} -> {target_table.fullyQualifiedName.root if hasattr(target_table.fullyQualifiedName, 'root') else target_table.fullyQualifiedName}"
)
yield Either(
right=AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(
id=container_entity[
0
].id,
type="container",
),
toEntity=EntityReference(
id=target_table.id.root
if hasattr(
target_table.id,
"root",
)
else target_table.id,
type="table",
),
lineageDetails=LineageDetails(
pipeline=EntityReference(
id=pipeline_entity.id.root,
type="pipeline",
),
source=LineageSource.PipelineLineage,
),
)
)
)
lineage_created += 1
else:
logger.warning(
f" ✗ S3 container not found for path: {storage_location}"
)
logger.info(
f" Make sure the S3 container is ingested in OpenMetadata"
)
else:
logger.warning(
f" ✗ Target table '{dep.table_name}' not found"
)
continue
# Otherwise, create table → table lineage
source_table = self._find_dlt_table(
table_name=source_table_name,
catalog=target_catalog,
schema=target_schema,
)
target_table = self._find_dlt_table(
table_name=dep.table_name,
catalog=target_catalog,
schema=target_schema,
)
if target_table_entity:
table_fqn = (
target_table_entity.fullyQualifiedName.root
if source_table and target_table:
source_fqn = (
source_table.fullyQualifiedName.root
if hasattr(
target_table_entity.fullyQualifiedName,
"root",
source_table.fullyQualifiedName, "root"
)
else target_table_entity.fullyQualifiedName
else source_table.fullyQualifiedName
)
target_fqn = (
target_table.fullyQualifiedName.root
if hasattr(
target_table.fullyQualifiedName, "root"
)
else target_table.fullyQualifiedName
)
logger.info(
f" ✓ Target table found: {table_fqn}"
f" ✅ Creating lineage: {source_fqn} -> {target_fqn}"
)
logger.info(
f" ✅ Creating lineage: {topic_name} -> {table_fqn}"
)
logger.info(f" Pipeline: {pipeline_id}")
yield Either(
right=AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(
id=kafka_topic.id,
type="topic",
id=source_table.id.root
if hasattr(
source_table.id, "root"
)
else source_table.id,
type="table",
),
toEntity=EntityReference(
id=target_table_entity.id.root
id=target_table.id.root
if hasattr(
target_table_entity.id,
"root",
target_table.id, "root"
)
else target_table_entity.id,
else target_table.id,
type="table",
),
lineageDetails=LineageDetails(
@@ -935,22 +1121,21 @@ class DatabrickspipelineSource(PipelineServiceSource):
)
lineage_created += 1
else:
logger.warning(
f" ✗ Target table '{table_name}' not found in OpenMetadata"
)
logger.info(
f" 💡 Expected location: {target_catalog}.{target_schema}.{table_name}"
)
logger.info(
f" 💡 Make sure the table is ingested from a Databricks/Unity Catalog database service"
)
if not source_table:
logger.warning(
f" ✗ Source table '{source_table_name}' not found"
)
if not target_table:
logger.warning(
f" ✗ Target table '{dep.table_name}' not found"
)
except Exception as exc:
logger.error(
f" ✗ Failed to process topic {topic_name}: {exc}"
)
logger.debug(traceback.format_exc())
continue
except Exception as exc:
logger.error(
f" ✗ Failed to process dependency {source_table_name} -> {dep.table_name}: {exc}"
)
logger.debug(traceback.format_exc())
continue
logger.info(
f"\n✓ Lineage edges created for this notebook: {lineage_created}"

View File

@@ -350,16 +350,16 @@ class TestKafkaLineageIntegration(unittest.TestCase):
}
self.mock_client.get_pipeline_details.return_value = pipeline_config
# Mock notebook source code
# Mock notebook source code - realistic DLT pattern with Kafka
notebook_source = """
import dlt
topic_name = "dev.ern.cashout.moneyRequest_v1"
entity_name = "moneyRequest"
@dlt.table(name=materializer.generate_event_log_table_name())
@dlt.table(name="moneyRequest")
def event_log():
return df
return spark.readStream.format("kafka").option("subscribe", topic_name).load()
"""
self.mock_client.export_notebook_source.return_value = notebook_source

View File

@@ -16,6 +16,7 @@ Unit tests for Databricks Kafka parser
import unittest
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
extract_dlt_table_dependencies,
extract_dlt_table_names,
extract_kafka_sources,
get_pipeline_libraries,
@@ -587,5 +588,295 @@ class TestKafkaFallbackPatterns(unittest.TestCase):
self.assertEqual(table_names[0], "moneyRequest")
class TestDLTTableDependencies(unittest.TestCase):
"""Test cases for DLT table dependency extraction"""
def test_bronze_silver_pattern(self):
"""Test table dependency pattern with Kafka source and downstream table"""
source_code = """
import dlt
from pyspark.sql.functions import *
@dlt.table(name="orders_bronze")
def bronze():
return spark.readStream.format("kafka").option("subscribe", "orders").load()
@dlt.table(name="orders_silver")
def silver():
return dlt.read_stream("orders_bronze").select("*")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 2)
bronze = next(d for d in deps if d.table_name == "orders_bronze")
self.assertTrue(bronze.reads_from_kafka)
self.assertEqual(bronze.depends_on, [])
silver = next(d for d in deps if d.table_name == "orders_silver")
self.assertFalse(silver.reads_from_kafka)
self.assertEqual(silver.depends_on, ["orders_bronze"])
def test_kafka_view_bronze_silver(self):
"""Test Kafka view with multi-tier table dependencies"""
source_code = """
import dlt
@dlt.view()
def kafka_source():
return spark.readStream.format("kafka").option("subscribe", "topic").load()
@dlt.table(name="bronze_table")
def bronze():
return dlt.read_stream("kafka_source")
@dlt.table(name="silver_table")
def silver():
return dlt.read_stream("bronze_table")
"""
deps = extract_dlt_table_dependencies(source_code)
bronze = next((d for d in deps if d.table_name == "bronze_table"), None)
self.assertIsNotNone(bronze)
self.assertEqual(bronze.depends_on, ["kafka_source"])
self.assertFalse(bronze.reads_from_kafka)
silver = next((d for d in deps if d.table_name == "silver_table"), None)
self.assertIsNotNone(silver)
self.assertEqual(silver.depends_on, ["bronze_table"])
self.assertFalse(silver.reads_from_kafka)
def test_multiple_dependencies(self):
"""Test table with multiple source dependencies"""
source_code = """
@dlt.table(name="source1")
def s1():
return spark.readStream.format("kafka").load()
@dlt.table(name="source2")
def s2():
return spark.readStream.format("kafka").load()
@dlt.table(name="merged")
def merge():
df1 = dlt.read_stream("source1")
df2 = dlt.read_stream("source2")
return df1.union(df2)
"""
deps = extract_dlt_table_dependencies(source_code)
merged = next((d for d in deps if d.table_name == "merged"), None)
self.assertIsNotNone(merged)
self.assertEqual(sorted(merged.depends_on), ["source1", "source2"])
self.assertFalse(merged.reads_from_kafka)
def test_no_dependencies(self):
"""Test table with no dependencies (reads from file)"""
source_code = """
@dlt.table(name="static_data")
def static():
return spark.read.parquet("/path/to/data")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertEqual(deps[0].table_name, "static_data")
self.assertEqual(deps[0].depends_on, [])
self.assertFalse(deps[0].reads_from_kafka)
def test_function_name_as_table_name(self):
"""Test using function name when no explicit name in decorator"""
source_code = """
@dlt.table()
def my_bronze_table():
return spark.readStream.format("kafka").load()
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertEqual(deps[0].table_name, "my_bronze_table")
self.assertTrue(deps[0].reads_from_kafka)
def test_empty_source_code(self):
"""Test empty source code returns empty list"""
deps = extract_dlt_table_dependencies("")
self.assertEqual(len(deps), 0)
def test_real_world_pattern(self):
"""Test the exact pattern from user's example"""
source_code = """
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *
TOPIC = "orders"
@dlt.view(comment="Kafka source")
def kafka_orders_source():
return (
spark.readStream
.format("kafka")
.option("subscribe", TOPIC)
.load()
)
@dlt.table(name="orders_bronze")
def orders_bronze():
return dlt.read_stream("kafka_orders_source").select("*")
@dlt.table(name="orders_silver")
def orders_silver():
return dlt.read_stream("orders_bronze").select("*")
"""
deps = extract_dlt_table_dependencies(source_code)
# Should find bronze and silver (kafka_orders_source is a view, not a table)
table_deps = [
d for d in deps if d.table_name in ["orders_bronze", "orders_silver"]
]
self.assertEqual(len(table_deps), 2)
bronze = next((d for d in table_deps if d.table_name == "orders_bronze"), None)
self.assertIsNotNone(bronze)
self.assertEqual(bronze.depends_on, ["kafka_orders_source"])
self.assertFalse(bronze.reads_from_kafka)
silver = next((d for d in table_deps if d.table_name == "orders_silver"), None)
self.assertIsNotNone(silver)
self.assertEqual(silver.depends_on, ["orders_bronze"])
self.assertFalse(silver.reads_from_kafka)
class TestS3SourceDetection(unittest.TestCase):
"""Test cases for S3 source detection"""
def test_s3_json_source(self):
"""Test detecting S3 source with spark.read.json()"""
source_code = """
@dlt.view()
def s3_source():
return spark.read.json("s3://mybucket/data/")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertTrue(deps[0].reads_from_s3)
self.assertEqual(deps[0].s3_locations, ["s3://mybucket/data/"])
def test_s3_parquet_source(self):
"""Test detecting S3 source with spark.read.parquet()"""
source_code = """
@dlt.table(name="parquet_table")
def load_parquet():
return spark.read.parquet("s3a://bucket/path/file.parquet")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertTrue(deps[0].reads_from_s3)
self.assertIn("s3a://bucket/path/file.parquet", deps[0].s3_locations)
def test_s3_with_options(self):
"""Test S3 source with options"""
source_code = """
@dlt.view()
def external_source():
return (
spark.read
.option("multiline", "true")
.json("s3://test-firehose-con-bucket/firehose_data/")
)
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertTrue(deps[0].reads_from_s3)
self.assertEqual(
deps[0].s3_locations, ["s3://test-firehose-con-bucket/firehose_data/"]
)
def test_s3_format_load(self):
"""Test S3 with format().load() pattern"""
source_code = """
@dlt.table(name="csv_data")
def load_csv():
return spark.read.format("csv").option("header", "true").load("s3://bucket/data.csv")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertTrue(deps[0].reads_from_s3)
self.assertIn("s3://bucket/data.csv", deps[0].s3_locations)
def test_dlt_read_batch(self):
"""Test dlt.read() for batch table dependencies"""
source_code = """
@dlt.table(name="bronze")
def bronze_table():
return dlt.read("source_view")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 1)
self.assertEqual(deps[0].depends_on, ["source_view"])
def test_mixed_dlt_read_and_read_stream(self):
"""Test both dlt.read() and dlt.read_stream() in same pipeline"""
source_code = """
@dlt.table(name="batch_table")
def batch():
return dlt.read("source1")
@dlt.table(name="stream_table")
def stream():
return dlt.read_stream("source2")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 2)
batch = next(d for d in deps if d.table_name == "batch_table")
self.assertEqual(batch.depends_on, ["source1"])
stream = next(d for d in deps if d.table_name == "stream_table")
self.assertEqual(stream.depends_on, ["source2"])
def test_user_s3_example(self):
"""Test the user's exact S3 example"""
source_code = """
import dlt
from pyspark.sql.functions import col
@dlt.view(comment="External source data from S3")
def external_source():
return (
spark.read
.option("multiline", "true")
.json("s3://test-firehose-con-bucket/firehose_data/")
)
@dlt.table(name="bronze_firehose_data")
def bronze_firehose_data():
return dlt.read("external_source")
@dlt.table(name="silver_firehose_data")
def silver_firehose_data():
return dlt.read("bronze_firehose_data")
"""
deps = extract_dlt_table_dependencies(source_code)
self.assertEqual(len(deps), 3)
# Verify external_source
external = next((d for d in deps if d.table_name == "external_source"), None)
self.assertIsNotNone(external)
self.assertTrue(external.reads_from_s3)
self.assertIn(
"s3://test-firehose-con-bucket/firehose_data/", external.s3_locations
)
# Verify bronze
bronze = next((d for d in deps if d.table_name == "bronze_firehose_data"), None)
self.assertIsNotNone(bronze)
self.assertEqual(bronze.depends_on, ["external_source"])
self.assertFalse(bronze.reads_from_s3)
# Verify silver
silver = next((d for d in deps if d.table_name == "silver_firehose_data"), None)
self.assertIsNotNone(silver)
self.assertEqual(silver.depends_on, ["bronze_firehose_data"])
self.assertFalse(silver.reads_from_s3)
if __name__ == "__main__":
unittest.main()