Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-12-23 05:28:25 +00:00)

MINOR: Fix Databricks DLT Pipeline Lineage to Track Table (#23888)

* MINOR: Fix Databricks DLT Pipeline Lineage to Track Table
* fix tests
* add support for s3 pipeline lineage as well

parent 14f5d0610d
commit 3c527ca83b
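For orientation, this is the kind of DLT notebook the change targets (an illustrative sample modeled on the unit tests added at the end of this commit; in a real Databricks notebook `dlt` and `spark` are provided by the runtime): an external view reading from S3, a bronze table built with dlt.read, and a silver table built with dlt.read_stream.

import dlt

@dlt.view(comment="External source data from S3")
def external_source():
    # S3 path is illustrative; only the s3:// / s3a:// / s3n:// prefix matters to the parser
    return spark.read.option("multiline", "true").json("s3://example-bucket/raw/")

@dlt.table(name="orders_bronze")
def orders_bronze():
    # batch read of the upstream view
    return dlt.read("external_source")

@dlt.table(name="orders_silver")
def orders_silver():
    # streaming read of the bronze table
    return dlt.read_stream("orders_bronze")

The commit teaches the Databricks pipeline connector to turn exactly these relationships (Kafka topic or S3 location into bronze, bronze into silver) into lineage edges.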
@@ -50,6 +50,26 @@ DLT_TABLE_NAME_FUNCTION = re.compile(
    re.DOTALL | re.IGNORECASE,
)

+# Pattern to extract dlt.read_stream("table_name") calls
+DLT_READ_STREAM_PATTERN = re.compile(
+    r'dlt\.read_stream\s*\(\s*["\']([^"\']+)["\']\s*\)',
+    re.IGNORECASE,
+)
+
+# Pattern to extract dlt.read("table_name") calls (batch reads)
+DLT_READ_PATTERN = re.compile(
+    r'dlt\.read\s*\(\s*["\']([^"\']+)["\']\s*\)',
+    re.IGNORECASE,
+)
+
+# Pattern to extract S3 paths from spark.read operations
+# Matches: spark.read.json("s3://..."), spark.read.format("parquet").load("s3a://...")
+# Uses a simpler pattern that captures any spark.read followed by method calls ending with a path
+S3_PATH_PATTERN = re.compile(
+    r'spark\.read.*?\.(?:load|json|parquet|csv|orc|avro)\s*\(\s*["\']([^"\']+)["\']\s*\)',
+    re.DOTALL | re.IGNORECASE,
+)
+

@dataclass
class KafkaSourceConfig:
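A quick standalone sanity check of what the three new patterns capture. The compiled expressions are copied verbatim from the hunk above; the snippet is illustrative.

import re

# Copied from the hunk above, for a standalone check of what each pattern captures.
DLT_READ_STREAM_PATTERN = re.compile(
    r'dlt\.read_stream\s*\(\s*["\']([^"\']+)["\']\s*\)', re.IGNORECASE
)
DLT_READ_PATTERN = re.compile(
    r'dlt\.read\s*\(\s*["\']([^"\']+)["\']\s*\)', re.IGNORECASE
)
S3_PATH_PATTERN = re.compile(
    r'spark\.read.*?\.(?:load|json|parquet|csv|orc|avro)\s*\(\s*["\']([^"\']+)["\']\s*\)',
    re.DOTALL | re.IGNORECASE,
)

snippet = (
    'df1 = dlt.read_stream("orders_bronze")\n'
    'df2 = dlt.read("dim_users")\n'
    'df3 = spark.read.format("csv").option("header", "true").load("s3a://bucket/data.csv")\n'
)

assert DLT_READ_STREAM_PATTERN.findall(snippet) == ["orders_bronze"]
assert DLT_READ_PATTERN.findall(snippet) == ["dim_users"]
assert S3_PATH_PATTERN.findall(snippet) == ["s3a://bucket/data.csv"]

Note that DLT_READ_PATTERN does not match dlt.read_stream(...), because "_stream" cannot satisfy the whitespace-then-parenthesis requirement after "dlt.read".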
@@ -60,6 +80,17 @@ class KafkaSourceConfig:
    group_id_prefix: Optional[str] = None


+@dataclass
+class DLTTableDependency:
+    """Model for DLT table dependencies"""
+
+    table_name: str
+    depends_on: List[str] = field(default_factory=list)
+    reads_from_kafka: bool = False
+    reads_from_s3: bool = False
+    s3_locations: List[str] = field(default_factory=list)
+
+
def _extract_variables(source_code: str) -> dict:
    """
    Extract variable assignments from source code
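For illustration, this is how a simple bronze/silver notebook ends up represented with this dataclass (a sketch; in practice the instances are produced by extract_dlt_table_dependencies, added in the next hunk):

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    DLTTableDependency,
)

# bronze consumes a Kafka topic directly, silver depends on bronze
bronze = DLTTableDependency(table_name="orders_bronze", reads_from_kafka=True)
silver = DLTTableDependency(table_name="orders_silver", depends_on=["orders_bronze"])

assert not silver.reads_from_s3 and silver.s3_locations == []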
@@ -316,6 +347,129 @@ def _infer_table_name_from_function(
    return None


+def extract_dlt_table_dependencies(source_code: str) -> List[DLTTableDependency]:
+    """
+    Extract DLT table dependencies by analyzing @dlt.table decorators and dlt.read_stream calls
+
+    For each DLT table, identifies:
+    - Table name from @dlt.table(name="...")
+    - Dependencies from dlt.read_stream("other_table") or dlt.read("other_table") calls
+    - Whether it reads from Kafka (spark.readStream.format("kafka"))
+    - Whether it reads from S3 (spark.read.json("s3://..."))
+    - S3 locations if applicable
+
+    Example:
+        @dlt.table(name="source_table")
+        def my_source():
+            return spark.read.json("s3://bucket/path/")...
+
+        @dlt.table(name="target_table")
+        def my_target():
+            return dlt.read("source_table")
+
+    Returns:
+        [
+            DLTTableDependency(table_name="source_table", depends_on=[], reads_from_s3=True,
+                               s3_locations=["s3://bucket/path/"]),
+            DLTTableDependency(table_name="target_table", depends_on=["source_table"],
+                               reads_from_s3=False)
+        ]
+    """
+    dependencies = []
+
+    try:
+        if not source_code:
+            return dependencies
+
+        # Split source code into function definitions
+        # Pattern: @dlt.table(...) or @dlt.view(...) followed by def function_name():
+        # Handle multiline decorators with potentially nested parentheses
+        function_pattern = re.compile(
+            r"(@dlt\.(?:table|view)\s*\(.*?\)\s*def\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^)]*\)\s*:.*?)(?=@dlt\.|$)",
+            re.DOTALL | re.IGNORECASE,
+        )
+
+        for match in function_pattern.finditer(source_code):
+            try:
+                function_block = match.group(1)
+
+                # Extract table name from @dlt.table decorator
+                table_name = None
+                name_match = DLT_TABLE_NAME_LITERAL.search(function_block)
+                if name_match and name_match.group(1):
+                    table_name = name_match.group(1)
+                else:
+                    # Try function name pattern
+                    func_name_match = DLT_TABLE_NAME_FUNCTION.search(function_block)
+                    if func_name_match and func_name_match.group(1):
+                        table_name = _infer_table_name_from_function(
+                            func_name_match.group(1), source_code
+                        )
+
+                if not table_name:
+                    # Try to extract from function definition itself
+                    def_match = re.search(
+                        r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", function_block
+                    )
+                    if def_match:
+                        table_name = def_match.group(1)
+
+                if not table_name:
+                    logger.debug(
+                        f"Could not extract table name from block: {function_block[:100]}..."
+                    )
+                    continue
+
+                # Check if it reads from Kafka
+                reads_from_kafka = bool(KAFKA_STREAM_PATTERN.search(function_block))
+
+                # Check if it reads from S3
+                s3_locations = []
+                for s3_match in S3_PATH_PATTERN.finditer(function_block):
+                    s3_path = s3_match.group(1)
+                    if s3_path.startswith(("s3://", "s3a://", "s3n://")):
+                        s3_locations.append(s3_path)
+                        logger.debug(f"Table {table_name} reads from S3: {s3_path}")
+
+                reads_from_s3 = len(s3_locations) > 0
+
+                # Extract dlt.read_stream dependencies (streaming)
+                depends_on = []
+                for stream_match in DLT_READ_STREAM_PATTERN.finditer(function_block):
+                    source_table = stream_match.group(1)
+                    depends_on.append(source_table)
+                    logger.debug(f"Table {table_name} streams from {source_table}")
+
+                # Extract dlt.read dependencies (batch)
+                for read_match in DLT_READ_PATTERN.finditer(function_block):
+                    source_table = read_match.group(1)
+                    depends_on.append(source_table)
+                    logger.debug(f"Table {table_name} reads from {source_table}")
+
+                dependency = DLTTableDependency(
+                    table_name=table_name,
+                    depends_on=depends_on,
+                    reads_from_kafka=reads_from_kafka,
+                    reads_from_s3=reads_from_s3,
+                    s3_locations=s3_locations,
+                )
+                dependencies.append(dependency)
+                logger.debug(
+                    f"Extracted dependency: {table_name} - depends_on={depends_on}, "
+                    f"reads_from_kafka={reads_from_kafka}, reads_from_s3={reads_from_s3}, "
+                    f"s3_locations={s3_locations}"
+                )
+
+            except Exception as exc:
+                logger.debug(f"Error parsing function block: {exc}")
+                continue
+
+    except Exception as exc:
+        logger.warning(f"Error extracting DLT table dependencies: {exc}")
+
+    return dependencies
+
+
def get_pipeline_libraries(pipeline_config: dict, client=None) -> List[str]:
    """
    Extract notebook and file paths from pipeline configuration
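Running the new function against the docstring example above should yield the documented result. This is a standalone check that mirrors the unit tests added at the end of this commit; it assumes the openmetadata-ingestion package with this change installed.

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_dependencies,
)

notebook = '''
import dlt

@dlt.table(name="source_table")
def my_source():
    return spark.read.json("s3://bucket/path/")

@dlt.table(name="target_table")
def my_target():
    return dlt.read("source_table")
'''

deps = {d.table_name: d for d in extract_dlt_table_dependencies(notebook)}
assert deps["source_table"].reads_from_s3
assert deps["source_table"].s3_locations == ["s3://bucket/path/"]
assert deps["target_table"].depends_on == ["source_table"]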
@@ -58,7 +58,7 @@ from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
-    extract_dlt_table_names,
+    extract_dlt_table_dependencies,
    extract_kafka_sources,
)
from metadata.ingestion.source.pipeline.databrickspipeline.models import (
@@ -820,107 +820,293 @@ class DatabrickspipelineSource(PipelineServiceSource):
            else:
                logger.info(f"⊗ No Kafka sources found in notebook")

-            # Extract DLT table names
-            logger.info(f"⟳ Parsing DLT table names from notebook...")
-            dlt_table_names = extract_dlt_table_names(source_code)
-            if dlt_table_names:
+            # Extract DLT table dependencies
+            logger.info(f"⟳ Parsing DLT table dependencies from notebook...")
+            dlt_dependencies = extract_dlt_table_dependencies(source_code)
+            if dlt_dependencies:
                logger.info(
-                    f"✓ Found {len(dlt_table_names)} DLT table(s): {dlt_table_names}"
+                    f"✓ Found {len(dlt_dependencies)} DLT table(s) with dependencies"
                )
+                for dep in dlt_dependencies:
+                    s3_info = (
+                        f", reads_from_s3={dep.reads_from_s3}, s3_locations={dep.s3_locations}"
+                        if dep.reads_from_s3
+                        else ""
+                    )
+                    logger.info(
+                        f"  - {dep.table_name}: depends_on={dep.depends_on}, "
+                        f"reads_from_kafka={dep.reads_from_kafka}{s3_info}"
+                    )
            else:
-                logger.info(f"⊗ No DLT tables found in notebook")
+                logger.info(f"⊗ No DLT table dependencies found in notebook")

-            if not dlt_table_names or not kafka_sources:
+            # Check if we have anything to process
+            has_kafka = kafka_sources and len(kafka_sources) > 0
+            has_s3 = any(dep.reads_from_s3 for dep in dlt_dependencies)
+            has_tables = dlt_dependencies and len(dlt_dependencies) > 0
+
+            if not dlt_dependencies:
                logger.warning(
-                    f"⊗ Skipping lineage for this notebook - need both Kafka sources AND DLT tables"
-                )
-                logger.info(
-                    f"   Kafka sources: {len(kafka_sources) if kafka_sources else 0}"
-                )
-                logger.info(
-                    f"   DLT tables: {len(dlt_table_names) if dlt_table_names else 0}"
+                    f"⊗ Skipping lineage for this notebook - no DLT tables found"
                )
                continue

-            logger.info(
-                f"✓ Notebook has both Kafka sources and DLT tables - creating lineage..."
-            )
+            if not has_kafka and not has_s3:
+                logger.info(
+                    f"⊗ No external sources (Kafka or S3) found in this notebook - only table-to-table lineage will be created"
+                )

-            # Create lineage for each Kafka topic -> DLT table
+            logger.info(f"✓ Notebook has DLT tables - creating lineage...")
+            if has_kafka:
+                logger.info(f"   Kafka sources: {len(kafka_sources)}")
+            if has_s3:
+                s3_count = sum(
+                    len(dep.s3_locations)
+                    for dep in dlt_dependencies
+                    if dep.reads_from_s3
+                )
+                logger.info(f"   S3 sources: {s3_count} location(s)")
+
+            # Create lineage edges based on dependencies
            logger.info(f"\n⟳ Creating lineage edges...")
            lineage_created = 0

+            # Step 1: Create Kafka topic -> DLT table lineage
+            # Build a map to identify which tables/views read from external sources (Kafka/S3)
+            external_sources_map = {}
+            for dep in dlt_dependencies:
+                if dep.reads_from_kafka or dep.reads_from_s3:
+                    external_sources_map[dep.table_name] = True
+
            for kafka_config in kafka_sources:
                for topic_name in kafka_config.topics:
                    try:
-                        logger.info(f"\n  🔍 Processing topic: {topic_name}")
-
-                        # Use smart discovery to find topic
                        logger.info(
-                            f"   ⟳ Searching for topic in OpenMetadata..."
+                            f"\n  🔍 Processing Kafka topic: {topic_name}"
                        )
-                        kafka_topic = self._find_kafka_topic(topic_name)

+                        kafka_topic = self._find_kafka_topic(topic_name)
                        if not kafka_topic:
                            logger.warning(
                                f"   ✗ Kafka topic '{topic_name}' not found in OpenMetadata"
                            )
                            logger.info(
                                f"   💡 Make sure the topic is ingested from a messaging service"
                            )
                            continue

                        logger.info(
                            f"   ✓ Topic found: {kafka_topic.fullyQualifiedName.root if hasattr(kafka_topic.fullyQualifiedName, 'root') else kafka_topic.fullyQualifiedName}"
                        )

-                        # Create lineage to each DLT table in this notebook
-                        for table_name in dlt_table_names:
-                            logger.info(
-                                f"   🔍 Processing target table: {table_name}"
-                            )
-                            logger.info(
-                                f"   ⟳ Searching in Databricks/Unity Catalog services..."
+                        # Find tables that read directly from Kafka OR from an external source view
+                        for dep in dlt_dependencies:
+                            is_kafka_consumer = dep.reads_from_kafka or any(
+                                src in external_sources_map
+                                for src in dep.depends_on
                            )

-                            # Use cached Databricks service lookup
-                            target_table_entity = self._find_dlt_table(
-                                table_name=table_name,
+                            if is_kafka_consumer:
+                                logger.info(
+                                    f"   🔍 Processing table: {dep.table_name}"
+                                )
+
+                                target_table = self._find_dlt_table(
+                                    table_name=dep.table_name,
                                    catalog=target_catalog,
                                    schema=target_schema,
                                )

+                                if target_table:
+                                    table_fqn = (
+                                        target_table.fullyQualifiedName.root
+                                        if hasattr(
+                                            target_table.fullyQualifiedName,
+                                            "root",
+                                        )
+                                        else target_table.fullyQualifiedName
+                                    )
+                                    logger.info(
+                                        f"   ✅ Creating lineage: {topic_name} -> {table_fqn}"
+                                    )
+
+                                    yield Either(
+                                        right=AddLineageRequest(
+                                            edge=EntitiesEdge(
+                                                fromEntity=EntityReference(
+                                                    id=kafka_topic.id,
+                                                    type="topic",
+                                                ),
+                                                toEntity=EntityReference(
+                                                    id=target_table.id.root
+                                                    if hasattr(
+                                                        target_table.id, "root"
+                                                    )
+                                                    else target_table.id,
+                                                    type="table",
+                                                ),
+                                                lineageDetails=LineageDetails(
+                                                    pipeline=EntityReference(
+                                                        id=pipeline_entity.id.root,
+                                                        type="pipeline",
+                                                    ),
+                                                    source=LineageSource.PipelineLineage,
+                                                ),
+                                            )
+                                        )
+                                    )
+                                    lineage_created += 1
+                                else:
+                                    logger.warning(
+                                        f"   ✗ Table '{dep.table_name}' not found"
+                                    )
+
+                    except Exception as exc:
+                        logger.error(
+                            f"   ✗ Failed to process topic {topic_name}: {exc}"
+                        )
+                        logger.debug(traceback.format_exc())
+                        continue
+
+            # Step 2: Create table-to-table lineage for downstream dependencies
+            for dep in dlt_dependencies:
+                if dep.depends_on:
+                    for source_table_name in dep.depends_on:
+                        try:
+                            logger.info(
+                                f"\n  🔍 Processing table dependency: {source_table_name} -> {dep.table_name}"
+                            )
+
+                            # Check if source is a view/table that reads from S3
+                            source_dep = next(
+                                (
+                                    d
+                                    for d in dlt_dependencies
+                                    if d.table_name == source_table_name
+                                ),
+                                None,
+                            )
+
+                            # If source reads from S3, create container → table lineage
+                            if (
+                                source_dep
+                                and source_dep.reads_from_s3
+                                and source_dep.s3_locations
+                            ):
+                                target_table = self._find_dlt_table(
+                                    table_name=dep.table_name,
+                                    catalog=target_catalog,
+                                    schema=target_schema,
+                                )
+
+                                if target_table:
+                                    for s3_location in source_dep.s3_locations:
+                                        logger.info(
+                                            f"   🔍 Looking for S3 container: {s3_location}"
+                                        )
+                                        # Search for container by S3 path
+                                        storage_location = s3_location.rstrip(
+                                            "/"
+                                        )
+                                        container_entity = self.metadata.es_search_container_by_path(
+                                            full_path=storage_location
+                                        )
+
+                                        if (
+                                            container_entity
+                                            and container_entity[0]
+                                        ):
+                                            logger.info(
+                                                f"   ✅ Creating lineage: {container_entity[0].fullyQualifiedName.root if hasattr(container_entity[0].fullyQualifiedName, 'root') else container_entity[0].fullyQualifiedName} -> {target_table.fullyQualifiedName.root if hasattr(target_table.fullyQualifiedName, 'root') else target_table.fullyQualifiedName}"
+                                            )
+
+                                            yield Either(
+                                                right=AddLineageRequest(
+                                                    edge=EntitiesEdge(
+                                                        fromEntity=EntityReference(
+                                                            id=container_entity[
+                                                                0
+                                                            ].id,
+                                                            type="container",
+                                                        ),
+                                                        toEntity=EntityReference(
+                                                            id=target_table.id.root
+                                                            if hasattr(
+                                                                target_table.id,
+                                                                "root",
+                                                            )
+                                                            else target_table.id,
+                                                            type="table",
+                                                        ),
+                                                        lineageDetails=LineageDetails(
+                                                            pipeline=EntityReference(
+                                                                id=pipeline_entity.id.root,
+                                                                type="pipeline",
+                                                            ),
+                                                            source=LineageSource.PipelineLineage,
+                                                        ),
+                                                    )
+                                                )
+                                            )
+                                            lineage_created += 1
+                                        else:
+                                            logger.warning(
+                                                f"   ✗ S3 container not found for path: {storage_location}"
+                                            )
+                                            logger.info(
+                                                f"   Make sure the S3 container is ingested in OpenMetadata"
+                                            )
+                                else:
+                                    logger.warning(
+                                        f"   ✗ Target table '{dep.table_name}' not found"
+                                    )
+                                continue
+
+                            # Otherwise, create table → table lineage
+                            source_table = self._find_dlt_table(
+                                table_name=source_table_name,
+                                catalog=target_catalog,
+                                schema=target_schema,
+                            )
+                            target_table = self._find_dlt_table(
+                                table_name=dep.table_name,
+                                catalog=target_catalog,
+                                schema=target_schema,
+                            )

-                            if target_table_entity:
-                                table_fqn = (
-                                    target_table_entity.fullyQualifiedName.root
+                            if source_table and target_table:
+                                source_fqn = (
+                                    source_table.fullyQualifiedName.root
                                    if hasattr(
-                                        target_table_entity.fullyQualifiedName,
-                                        "root",
+                                        source_table.fullyQualifiedName, "root"
                                    )
-                                    else target_table_entity.fullyQualifiedName
+                                    else source_table.fullyQualifiedName
                                )
+                                target_fqn = (
+                                    target_table.fullyQualifiedName.root
+                                    if hasattr(
+                                        target_table.fullyQualifiedName, "root"
+                                    )
+                                    else target_table.fullyQualifiedName
+                                )
                                logger.info(
-                                    f"   ✓ Target table found: {table_fqn}"
+                                    f"   ✅ Creating lineage: {source_fqn} -> {target_fqn}"
                                )
-                                logger.info(
-                                    f"   ✅ Creating lineage: {topic_name} -> {table_fqn}"
-                                )
-                                logger.info(f"      Pipeline: {pipeline_id}")

                                yield Either(
                                    right=AddLineageRequest(
                                        edge=EntitiesEdge(
                                            fromEntity=EntityReference(
-                                                id=kafka_topic.id,
-                                                type="topic",
+                                                id=source_table.id.root
+                                                if hasattr(
+                                                    source_table.id, "root"
+                                                )
+                                                else source_table.id,
+                                                type="table",
                                            ),
                                            toEntity=EntityReference(
-                                                id=target_table_entity.id.root
+                                                id=target_table.id.root
                                                if hasattr(
-                                                    target_table_entity.id,
-                                                    "root",
+                                                    target_table.id, "root"
                                                )
-                                                else target_table_entity.id,
+                                                else target_table.id,
                                                type="table",
                                            ),
                                            lineageDetails=LineageDetails(
@@ -935,22 +1121,21 @@ class DatabrickspipelineSource(PipelineServiceSource):
                                )
                                lineage_created += 1
                            else:
-                                logger.warning(
-                                    f"   ✗ Target table '{table_name}' not found in OpenMetadata"
-                                )
-                                logger.info(
-                                    f"   💡 Expected location: {target_catalog}.{target_schema}.{table_name}"
-                                )
-                                logger.info(
-                                    f"   💡 Make sure the table is ingested from a Databricks/Unity Catalog database service"
-                                )
+                                if not source_table:
+                                    logger.warning(
+                                        f"   ✗ Source table '{source_table_name}' not found"
+                                    )
+                                if not target_table:
+                                    logger.warning(
+                                        f"   ✗ Target table '{dep.table_name}' not found"
+                                    )

-                    except Exception as exc:
-                        logger.error(
-                            f"   ✗ Failed to process topic {topic_name}: {exc}"
-                        )
-                        logger.debug(traceback.format_exc())
-                        continue
+                        except Exception as exc:
+                            logger.error(
+                                f"   ✗ Failed to process dependency {source_table_name} -> {dep.table_name}: {exc}"
+                            )
+                            logger.debug(traceback.format_exc())
+                            continue

            logger.info(
                f"\n✓ Lineage edges created for this notebook: {lineage_created}"
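Reduced to plain name pairs, the edge-building logic above amounts to the sketch below. This is a simplification for illustration only: the real connector resolves topics, tables, and S3 containers as OpenMetadata entities and emits AddLineageRequest edges, and sketch_edges is a hypothetical helper that only reproduces the pairing rules.

from typing import List, Tuple

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    DLTTableDependency,
)


def sketch_edges(
    deps: List[DLTTableDependency], kafka_topics: List[str]
) -> List[Tuple[str, str]]:
    """Approximate the three edge kinds created above, as (source, target) name pairs."""
    edges: List[Tuple[str, str]] = []
    external = {d.table_name for d in deps if d.reads_from_kafka or d.reads_from_s3}
    by_name = {d.table_name: d for d in deps}

    # Step 1: Kafka topic -> table, for direct consumers or tables fed by an external view
    for dep in deps:
        if dep.reads_from_kafka or any(src in external for src in dep.depends_on):
            edges += [(topic, dep.table_name) for topic in kafka_topics]

    # Step 2: upstream dependency -> table
    for dep in deps:
        for src in dep.depends_on:
            src_dep = by_name.get(src)
            if src_dep and src_dep.reads_from_s3 and src_dep.s3_locations:
                # S3 container -> table when the upstream view reads from S3
                edges += [(loc, dep.table_name) for loc in src_dep.s3_locations]
            else:
                # otherwise plain table -> table lineage
                edges.append((src, dep.table_name))
    return edges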
@@ -350,16 +350,16 @@ class TestKafkaLineageIntegration(unittest.TestCase):
        }
        self.mock_client.get_pipeline_details.return_value = pipeline_config

-        # Mock notebook source code
+        # Mock notebook source code - realistic DLT pattern with Kafka
        notebook_source = """
import dlt

topic_name = "dev.ern.cashout.moneyRequest_v1"
entity_name = "moneyRequest"

-@dlt.table(name=materializer.generate_event_log_table_name())
+@dlt.table(name="moneyRequest")
def event_log():
-    return df
+    return spark.readStream.format("kafka").option("subscribe", topic_name).load()
"""
        self.mock_client.export_notebook_source.return_value = notebook_source

@@ -16,6 +16,7 @@ Unit tests for Databricks Kafka parser
import unittest

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
+    extract_dlt_table_dependencies,
    extract_dlt_table_names,
    extract_kafka_sources,
    get_pipeline_libraries,
@@ -587,5 +588,295 @@ class TestKafkaFallbackPatterns(unittest.TestCase):
        self.assertEqual(table_names[0], "moneyRequest")


+class TestDLTTableDependencies(unittest.TestCase):
+    """Test cases for DLT table dependency extraction"""
+
+    def test_bronze_silver_pattern(self):
+        """Test table dependency pattern with Kafka source and downstream table"""
+        source_code = """
+import dlt
+from pyspark.sql.functions import *
+
+@dlt.table(name="orders_bronze")
+def bronze():
+    return spark.readStream.format("kafka").option("subscribe", "orders").load()
+
+@dlt.table(name="orders_silver")
+def silver():
+    return dlt.read_stream("orders_bronze").select("*")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 2)
+
+        bronze = next(d for d in deps if d.table_name == "orders_bronze")
+        self.assertTrue(bronze.reads_from_kafka)
+        self.assertEqual(bronze.depends_on, [])
+
+        silver = next(d for d in deps if d.table_name == "orders_silver")
+        self.assertFalse(silver.reads_from_kafka)
+        self.assertEqual(silver.depends_on, ["orders_bronze"])
+
+    def test_kafka_view_bronze_silver(self):
+        """Test Kafka view with multi-tier table dependencies"""
+        source_code = """
+import dlt
+
+@dlt.view()
+def kafka_source():
+    return spark.readStream.format("kafka").option("subscribe", "topic").load()
+
+@dlt.table(name="bronze_table")
+def bronze():
+    return dlt.read_stream("kafka_source")
+
+@dlt.table(name="silver_table")
+def silver():
+    return dlt.read_stream("bronze_table")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+
+        bronze = next((d for d in deps if d.table_name == "bronze_table"), None)
+        self.assertIsNotNone(bronze)
+        self.assertEqual(bronze.depends_on, ["kafka_source"])
+        self.assertFalse(bronze.reads_from_kafka)
+
+        silver = next((d for d in deps if d.table_name == "silver_table"), None)
+        self.assertIsNotNone(silver)
+        self.assertEqual(silver.depends_on, ["bronze_table"])
+        self.assertFalse(silver.reads_from_kafka)
+
+    def test_multiple_dependencies(self):
+        """Test table with multiple source dependencies"""
+        source_code = """
+@dlt.table(name="source1")
+def s1():
+    return spark.readStream.format("kafka").load()
+
+@dlt.table(name="source2")
+def s2():
+    return spark.readStream.format("kafka").load()
+
+@dlt.table(name="merged")
+def merge():
+    df1 = dlt.read_stream("source1")
+    df2 = dlt.read_stream("source2")
+    return df1.union(df2)
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+
+        merged = next((d for d in deps if d.table_name == "merged"), None)
+        self.assertIsNotNone(merged)
+        self.assertEqual(sorted(merged.depends_on), ["source1", "source2"])
+        self.assertFalse(merged.reads_from_kafka)
+
+    def test_no_dependencies(self):
+        """Test table with no dependencies (reads from file)"""
+        source_code = """
+@dlt.table(name="static_data")
+def static():
+    return spark.read.parquet("/path/to/data")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertEqual(deps[0].table_name, "static_data")
+        self.assertEqual(deps[0].depends_on, [])
+        self.assertFalse(deps[0].reads_from_kafka)
+
+    def test_function_name_as_table_name(self):
+        """Test using function name when no explicit name in decorator"""
+        source_code = """
+@dlt.table()
+def my_bronze_table():
+    return spark.readStream.format("kafka").load()
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertEqual(deps[0].table_name, "my_bronze_table")
+        self.assertTrue(deps[0].reads_from_kafka)
+
+    def test_empty_source_code(self):
+        """Test empty source code returns empty list"""
+        deps = extract_dlt_table_dependencies("")
+        self.assertEqual(len(deps), 0)
+
+    def test_real_world_pattern(self):
+        """Test the exact pattern from user's example"""
+        source_code = """
+import dlt
+from pyspark.sql.functions import *
+from pyspark.sql.types import *
+
+TOPIC = "orders"
+
+@dlt.view(comment="Kafka source")
+def kafka_orders_source():
+    return (
+        spark.readStream
+        .format("kafka")
+        .option("subscribe", TOPIC)
+        .load()
+    )
+
+@dlt.table(name="orders_bronze")
+def orders_bronze():
+    return dlt.read_stream("kafka_orders_source").select("*")
+
+@dlt.table(name="orders_silver")
+def orders_silver():
+    return dlt.read_stream("orders_bronze").select("*")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+
+        # Should find bronze and silver (kafka_orders_source is a view, not a table)
+        table_deps = [
+            d for d in deps if d.table_name in ["orders_bronze", "orders_silver"]
+        ]
+        self.assertEqual(len(table_deps), 2)
+
+        bronze = next((d for d in table_deps if d.table_name == "orders_bronze"), None)
+        self.assertIsNotNone(bronze)
+        self.assertEqual(bronze.depends_on, ["kafka_orders_source"])
+        self.assertFalse(bronze.reads_from_kafka)
+
+        silver = next((d for d in table_deps if d.table_name == "orders_silver"), None)
+        self.assertIsNotNone(silver)
+        self.assertEqual(silver.depends_on, ["orders_bronze"])
+        self.assertFalse(silver.reads_from_kafka)
+
+
+class TestS3SourceDetection(unittest.TestCase):
+    """Test cases for S3 source detection"""
+
+    def test_s3_json_source(self):
+        """Test detecting S3 source with spark.read.json()"""
+        source_code = """
+@dlt.view()
+def s3_source():
+    return spark.read.json("s3://mybucket/data/")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertTrue(deps[0].reads_from_s3)
+        self.assertEqual(deps[0].s3_locations, ["s3://mybucket/data/"])
+
+    def test_s3_parquet_source(self):
+        """Test detecting S3 source with spark.read.parquet()"""
+        source_code = """
+@dlt.table(name="parquet_table")
+def load_parquet():
+    return spark.read.parquet("s3a://bucket/path/file.parquet")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertTrue(deps[0].reads_from_s3)
+        self.assertIn("s3a://bucket/path/file.parquet", deps[0].s3_locations)
+
+    def test_s3_with_options(self):
+        """Test S3 source with options"""
+        source_code = """
+@dlt.view()
+def external_source():
+    return (
+        spark.read
+        .option("multiline", "true")
+        .json("s3://test-firehose-con-bucket/firehose_data/")
+    )
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertTrue(deps[0].reads_from_s3)
+        self.assertEqual(
+            deps[0].s3_locations, ["s3://test-firehose-con-bucket/firehose_data/"]
+        )
+
+    def test_s3_format_load(self):
+        """Test S3 with format().load() pattern"""
+        source_code = """
+@dlt.table(name="csv_data")
+def load_csv():
+    return spark.read.format("csv").option("header", "true").load("s3://bucket/data.csv")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertTrue(deps[0].reads_from_s3)
+        self.assertIn("s3://bucket/data.csv", deps[0].s3_locations)
+
+    def test_dlt_read_batch(self):
+        """Test dlt.read() for batch table dependencies"""
+        source_code = """
+@dlt.table(name="bronze")
+def bronze_table():
+    return dlt.read("source_view")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 1)
+        self.assertEqual(deps[0].depends_on, ["source_view"])
+
+    def test_mixed_dlt_read_and_read_stream(self):
+        """Test both dlt.read() and dlt.read_stream() in same pipeline"""
+        source_code = """
+@dlt.table(name="batch_table")
+def batch():
+    return dlt.read("source1")
+
+@dlt.table(name="stream_table")
+def stream():
+    return dlt.read_stream("source2")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 2)
+
+        batch = next(d for d in deps if d.table_name == "batch_table")
+        self.assertEqual(batch.depends_on, ["source1"])
+
+        stream = next(d for d in deps if d.table_name == "stream_table")
+        self.assertEqual(stream.depends_on, ["source2"])
+
+    def test_user_s3_example(self):
+        """Test the user's exact S3 example"""
+        source_code = """
+import dlt
+from pyspark.sql.functions import col
+
+@dlt.view(comment="External source data from S3")
+def external_source():
+    return (
+        spark.read
+        .option("multiline", "true")
+        .json("s3://test-firehose-con-bucket/firehose_data/")
+    )
+
+@dlt.table(name="bronze_firehose_data")
+def bronze_firehose_data():
+    return dlt.read("external_source")
+
+@dlt.table(name="silver_firehose_data")
+def silver_firehose_data():
+    return dlt.read("bronze_firehose_data")
+"""
+        deps = extract_dlt_table_dependencies(source_code)
+        self.assertEqual(len(deps), 3)
+
+        # Verify external_source
+        external = next((d for d in deps if d.table_name == "external_source"), None)
+        self.assertIsNotNone(external)
+        self.assertTrue(external.reads_from_s3)
+        self.assertIn(
+            "s3://test-firehose-con-bucket/firehose_data/", external.s3_locations
+        )
+
+        # Verify bronze
+        bronze = next((d for d in deps if d.table_name == "bronze_firehose_data"), None)
+        self.assertIsNotNone(bronze)
+        self.assertEqual(bronze.depends_on, ["external_source"])
+        self.assertFalse(bronze.reads_from_s3)
+
+        # Verify silver
+        silver = next((d for d in deps if d.table_name == "silver_firehose_data"), None)
+        self.assertIsNotNone(silver)
+        self.assertEqual(silver.depends_on, ["bronze_firehose_data"])
+        self.assertFalse(silver.reads_from_s3)
+
+
if __name__ == "__main__":
    unittest.main()