Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-11-02 11:39:12 +00:00)

Feat: Add Kafka lineage support in Databricks pipelines (#23813)

* Add DLT pipeline support
* Fix code style
* Add variable parsing
* Fix Kafka lineage

Co-authored-by: Sriharsha Chintalapani <harsha@getcollate.io>

parent 509295ed39
commit 05f064787f
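
At a high level, the commit wires these pieces together: list DLT pipelines, fetch each pipeline's spec, export its notebooks, parse out Kafka sources and DLT table names, then emit topic -> table lineage through the pipeline entity. A rough sketch of that flow, assuming an already-configured DatabricksClient instance named client (the helper below is illustrative, not part of the commit):

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_names,
    extract_kafka_sources,
    get_pipeline_libraries,
)


def iter_kafka_lineage_candidates(client):
    """Yield (pipeline_id, topics, tables) triples for DLT pipelines that read from Kafka."""
    for status in client.list_pipelines():
        pipeline_id = status.get("pipeline_id")
        config = client.get_pipeline_details(pipeline_id) or {}
        # Library (notebook/file) paths live under the pipeline spec
        for path in get_pipeline_libraries(config.get("spec", {})):
            source = client.export_notebook_source(path)
            if not source:
                continue
            topics = [t for cfg in extract_kafka_sources(source) for t in cfg.topics]
            tables = extract_dlt_table_names(source)
            if topics and tables:
                yield pipeline_id, topics, tables
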
@@ -11,6 +11,7 @@
"""
Client to interact with databricks apis
"""
import base64
import json
import traceback
from datetime import timedelta
@@ -372,3 +373,122 @@ class DatabricksClient:
                    continue
        self._job_column_lineage_executed = True
        logger.debug("Table and column lineage caching completed.")

    def get_pipeline_details(self, pipeline_id: str) -> Optional[dict]:
        """
        Get DLT pipeline configuration including libraries and notebooks
        """
        try:
            url = f"{self.base_url}/pipelines/{pipeline_id}"
            response = self.client.get(
                url,
                headers=self.headers,
                timeout=self.api_timeout,
            )
            if response.status_code == 200:
                return response.json()
            logger.warning(
                f"Failed to get pipeline details for {pipeline_id}: {response.status_code}"
            )
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error getting pipeline details for {pipeline_id}: {exc}")
        return None

    def list_pipelines(self) -> Iterable[dict]:
        """
        List all DLT (Delta Live Tables) pipelines in the workspace
        Uses the Pipelines API (/api/2.0/pipelines)
        """
        try:
            url = f"{self.base_url}/pipelines"
            params = {"max_results": PAGE_SIZE}

            response = self.client.get(
                url,
                params=params,
                headers=self.headers,
                timeout=self.api_timeout,
            )

            if response.status_code == 200:
                data = response.json()
                pipelines = data.get("statuses", [])
                logger.info(f"Found {len(pipelines)} DLT pipelines")
                yield from pipelines

                # Handle pagination if there's a next_page_token
                while data.get("next_page_token"):
                    params["page_token"] = data["next_page_token"]
                    response = self.client.get(
                        url,
                        params=params,
                        headers=self.headers,
                        timeout=self.api_timeout,
                    )
                    if response.status_code == 200:
                        data = response.json()
                        yield from data.get("statuses", [])
                    else:
                        break
            else:
                logger.warning(
                    f"Failed to list pipelines: {response.status_code} - {response.text}"
                )
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error listing DLT pipelines: {exc}")

    def list_workspace_objects(self, path: str) -> List[dict]:
        """
        List objects in a Databricks workspace directory
        """
        try:
            url = f"{self.base_url}/workspace/list"
            params = {"path": path}

            response = self.client.get(
                url,
                params=params,
                headers=self.headers,
                timeout=self.api_timeout,
            )

            if response.status_code == 200:
                return response.json().get("objects", [])
            else:
                logger.warning(
                    f"Failed to list workspace directory {path}: {response.text}"
                )
                return []
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error listing workspace directory {path}: {exc}")
            return []

    def export_notebook_source(self, notebook_path: str) -> Optional[str]:
        """
        Export notebook source code from Databricks workspace
        """
        try:
            url = f"{self.base_url}/workspace/export"
            params = {"path": notebook_path, "format": "SOURCE"}

            response = self.client.get(
                url,
                params=params,
                headers=self.headers,
                timeout=self.api_timeout,
            )

            if response.status_code == 200:
                content = response.json().get("content")
                if content:
                    return base64.b64decode(content).decode("utf-8")
            logger.warning(
                f"Failed to export notebook {notebook_path}: {response.status_code}"
            )
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Error exporting notebook {notebook_path}: {exc}")
        return None

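The workspace-listing helper is what lets the source expand a directory-style library path (one ending in "/") into concrete notebooks before exporting them. A minimal sketch of that expansion, assuming a configured DatabricksClient named client:

def expand_notebook_paths(client, paths):
    """Expand directory paths ("/dir/") into concrete NOTEBOOK/FILE workspace paths."""
    expanded = []
    for path in paths:
        if not path.endswith("/"):
            expanded.append(path)
            continue
        for obj in client.list_workspace_objects(path):
            if obj.get("object_type") in ("NOTEBOOK", "FILE") and obj.get("path"):
                expanded.append(obj["path"])
    return expanded


# Usage sketch: fetch the source of every notebook behind a glob-style library
# sources = [client.export_notebook_source(p) for p in expand_notebook_paths(client, ["/dlt/pipelines/"])]
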
@@ -0,0 +1,243 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Kafka configuration parser for Databricks DLT pipelines
"""

import re
from dataclasses import dataclass, field
from typing import List, Optional

from metadata.utils.logger import ingestion_logger

logger = ingestion_logger()

# Compile regex patterns at module level for performance
KAFKA_STREAM_PATTERN = re.compile(
    r'\.format\s*\(\s*["\']kafka["\']\s*\)(.*?)\.load\s*\(\s*\)',
    re.DOTALL | re.IGNORECASE,
)

# Pattern to extract variable assignments like: TOPIC = "tracker-events"
VARIABLE_ASSIGNMENT_PATTERN = re.compile(
    r'^\s*([A-Z_][A-Z0-9_]*)\s*=\s*["\']([^"\']+)["\']\s*$',
    re.MULTILINE,
)

# Pattern to extract DLT table decorators: @dlt.table(name="table_name", ...)
DLT_TABLE_PATTERN = re.compile(
    r'@dlt\.table\s*\(\s*(?:.*?name\s*=\s*["\']([^"\']+)["\'])?',
    re.DOTALL | re.IGNORECASE,
)


@dataclass
class KafkaSourceConfig:
    """Model for Kafka source configuration extracted from DLT code"""

    bootstrap_servers: Optional[str] = None
    topics: List[str] = field(default_factory=list)
    group_id_prefix: Optional[str] = None


def _extract_variables(source_code: str) -> dict:
    """
    Extract variable assignments from source code

    Examples:
        TOPIC = "events"
        KAFKA_BROKER = "localhost:9092"

    Returns dict like: {"TOPIC": "events", "KAFKA_BROKER": "localhost:9092"}
    """
    variables = {}
    try:
        for match in VARIABLE_ASSIGNMENT_PATTERN.finditer(source_code):
            var_name = match.group(1)
            var_value = match.group(2)
            variables[var_name] = var_value
            logger.debug(f"Found variable: {var_name} = {var_value}")
    except Exception as exc:
        logger.debug(f"Error extracting variables: {exc}")
    return variables


def extract_kafka_sources(source_code: str) -> List[KafkaSourceConfig]:
    """
    Extract Kafka topic configurations from DLT source code

    Parses patterns like:
    - spark.readStream.format("kafka").option("subscribe", "topic1,topic2")
    - .option("kafka.bootstrap.servers", "broker:9092")
    - .option("groupIdPrefix", "dlt-pipeline")

    Also supports variable references:
    - TOPIC = "events"
    - .option("subscribe", TOPIC)

    Returns empty list if parsing fails or no sources found
    """
    kafka_configs = []

    try:
        if not source_code:
            logger.debug("Empty or None source code provided")
            return kafka_configs

        # Extract variable assignments for resolution
        variables = _extract_variables(source_code)

        for match in KAFKA_STREAM_PATTERN.finditer(source_code):
            try:
                config_block = match.group(1)

                bootstrap_servers = _extract_option(
                    config_block, r"kafka\.bootstrap\.servers", variables
                )
                subscribe_topics = _extract_option(
                    config_block, r"subscribe", variables
                )
                topics = _extract_option(config_block, r"topics", variables)
                group_id_prefix = _extract_option(
                    config_block, r"groupIdPrefix", variables
                )

                topic_list = []
                if subscribe_topics:
                    topic_list = [
                        t.strip() for t in subscribe_topics.split(",") if t.strip()
                    ]
                elif topics:
                    topic_list = [t.strip() for t in topics.split(",") if t.strip()]

                if bootstrap_servers or topic_list:
                    kafka_config = KafkaSourceConfig(
                        bootstrap_servers=bootstrap_servers,
                        topics=topic_list,
                        group_id_prefix=group_id_prefix,
                    )
                    kafka_configs.append(kafka_config)
                    logger.debug(
                        f"Extracted Kafka config: brokers={bootstrap_servers}, "
                        f"topics={topic_list}, group_prefix={group_id_prefix}"
                    )
            except Exception as exc:
                logger.warning(f"Failed to parse individual Kafka config block: {exc}")
                continue

    except Exception as exc:
        logger.warning(f"Error parsing Kafka sources from code: {exc}")

    return kafka_configs


def _extract_option(
    config_block: str, option_name: str, variables: dict = None
) -> Optional[str]:
    """
    Extract a single option value from Kafka configuration block
    Supports both string literals and variable references
    Safely handles any parsing errors
    """
    if variables is None:
        variables = {}

    try:
        # Try matching quoted string literal: .option("subscribe", "topic")
        pattern_literal = (
            rf'\.option\s*\(\s*["\']({option_name})["\']\s*,\s*["\']([^"\']+)["\']\s*\)'
        )
        match = re.search(pattern_literal, config_block, re.IGNORECASE)
        if match:
            return match.group(2)

        # Try matching variable reference: .option("subscribe", TOPIC)
        pattern_variable = (
            rf'\.option\s*\(\s*["\']({option_name})["\']\s*,\s*([A-Z_][A-Z0-9_]*)\s*\)'
        )
        match = re.search(pattern_variable, config_block, re.IGNORECASE)
        if match:
            var_name = match.group(2)
            # Resolve variable
            if var_name in variables:
                logger.debug(
                    f"Resolved variable {var_name} = {variables[var_name]} for option {option_name}"
                )
                return variables[var_name]
            else:
                logger.debug(
                    f"Variable {var_name} referenced but not found in source code"
                )

    except Exception as exc:
        logger.debug(f"Failed to extract option {option_name}: {exc}")
    return None


def extract_dlt_table_names(source_code: str) -> List[str]:
    """
    Extract DLT table names from @dlt.table decorators

    Parses patterns like:
    - @dlt.table(name="user_events_bronze_pl", ...)
    - @dlt.table(comment="...", name="my_table")

    Returns list of table names found in decorators
    """
    table_names = []

    try:
        if not source_code:
            logger.debug("Empty or None source code provided")
            return table_names

        for match in DLT_TABLE_PATTERN.finditer(source_code):
            table_name = match.group(1)
            if table_name:
                table_names.append(table_name)
                logger.debug(f"Found DLT table: {table_name}")

    except Exception as exc:
        logger.warning(f"Error parsing DLT table names from code: {exc}")

    return table_names


def get_pipeline_libraries(pipeline_config: dict, client=None) -> List[str]:
    """
    Extract notebook and file paths from pipeline configuration
    Safely handles missing or malformed configuration
    """
    libraries = []

    try:
        if not pipeline_config:
            return libraries

        for lib in pipeline_config.get("libraries", []):
            try:
                if "notebook" in lib:
                    notebook_path = lib["notebook"].get("path")
                    if notebook_path:
                        libraries.append(notebook_path)
                elif "file" in lib:
                    file_path = lib["file"].get("path")
                    if file_path:
                        libraries.append(file_path)
            except Exception as exc:
                logger.debug(f"Failed to process library entry {lib}: {exc}")
                continue

    except Exception as exc:
        logger.warning(f"Error extracting pipeline libraries: {exc}")

    return libraries
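
A minimal illustration of the new parser on a DLT-style notebook snippet (the topic, broker and table names below are made up):

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_names,
    extract_kafka_sources,
)

SAMPLE_NOTEBOOK = '''
import dlt

TOPIC = "orders-raw"

@dlt.table(name="orders_bronze")
def orders_bronze():
    return (spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "broker:9092")
        .option("subscribe", TOPIC)
        .load())
'''

sources = extract_kafka_sources(SAMPLE_NOTEBOOK)
print(sources[0].bootstrap_servers)              # broker:9092
print(sources[0].topics)                         # ['orders-raw']
print(extract_dlt_table_names(SAMPLE_NOTEBOOK))  # ['orders_bronze']
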
@@ -28,6 +28,7 @@ from metadata.generated.schema.entity.data.pipeline import (
    TaskStatus,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.data.topic import Topic
from metadata.generated.schema.entity.services.connections.pipeline.databricksPipelineConnection import (
    DatabricksPipelineConnection,
)
@@ -56,6 +57,10 @@ from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_names,
    extract_kafka_sources,
)
from metadata.ingestion.source.pipeline.databrickspipeline.models import (
    DataBrickPipelineDetails,
    DBRun,
@@ -105,14 +110,30 @@ class DatabrickspipelineSource(PipelineServiceSource):
                yield DataBrickPipelineDetails(**workflow)
        except Exception as exc:
            logger.debug(traceback.format_exc())
-           logger.error(f"Failed to get pipeline list due to : {exc}")
            logger.error(f"Failed to get jobs list due to : {exc}")

        # Fetch DLT pipelines directly (new)
        try:
            for pipeline in self.client.list_pipelines() or []:
                try:
                    yield DataBrickPipelineDetails(**pipeline)
                except Exception as exc:
                    logger.debug(f"Error creating DLT pipeline details: {exc}")
                    logger.debug(traceback.format_exc())
                    continue
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.warning(f"Failed to get DLT pipelines list due to : {exc}")

        return None

    def get_pipeline_name(
        self, pipeline_details: DataBrickPipelineDetails
    ) -> Optional[str]:
        try:
-           return pipeline_details.settings.name
            if pipeline_details.pipeline_id:
                return pipeline_details.name
            return pipeline_details.settings.name if pipeline_details.settings else None
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.error(f"Failed to get pipeline name due to : {exc}")
@@ -124,17 +145,35 @@ class DatabrickspipelineSource(PipelineServiceSource):
    ) -> Iterable[Either[CreatePipelineRequest]]:
        """Method to Get Pipeline Entity"""
        try:
-           description = pipeline_details.settings.description
            if pipeline_details.pipeline_id:
                description = None
                display_name = pipeline_details.name
                entity_name = str(pipeline_details.pipeline_id)
                schedule_interval = None
            else:
                description = (
                    pipeline_details.settings.description
                    if pipeline_details.settings
                    else None
                )
                display_name = (
                    pipeline_details.settings.name
                    if pipeline_details.settings
                    else None
                )
                entity_name = str(pipeline_details.job_id)
                schedule_interval = (
                    str(pipeline_details.settings.schedule.cron)
                    if pipeline_details.settings and pipeline_details.settings.schedule
                    else None
                )

            pipeline_request = CreatePipelineRequest(
-               name=EntityName(str(pipeline_details.job_id)),
-               displayName=pipeline_details.settings.name,
                name=EntityName(entity_name),
                displayName=display_name,
                description=Markdown(description) if description else None,
                tasks=self.get_tasks(pipeline_details),
-               scheduleInterval=(
-                   str(pipeline_details.settings.schedule.cron)
-                   if pipeline_details.settings.schedule
-                   else None
-               ),
                scheduleInterval=schedule_interval,
                service=FullyQualifiedEntityName(self.context.get().pipeline_service),
            )
            yield Either(right=pipeline_request)
@@ -170,6 +209,9 @@ class DatabrickspipelineSource(PipelineServiceSource):

    def get_tasks(self, pipeline_details: DataBrickPipelineDetails) -> List[Task]:
        try:
            if not pipeline_details.job_id:
                return []

            task_list = []
            for run in self.client.get_job_runs(job_id=pipeline_details.job_id) or []:
                run = DBRun(**run)
@@ -177,7 +219,11 @@ class DatabrickspipelineSource(PipelineServiceSource):
                [
                    Task(
                        name=str(task.name),
-                       taskType=pipeline_details.settings.task_type,
                        taskType=(
                            pipeline_details.settings.task_type
                            if pipeline_details.settings
                            else None
                        ),
                        sourceUrl=(
                            SourceUrl(run.run_page_url)
                            if run.run_page_url
@@ -204,6 +250,9 @@ class DatabrickspipelineSource(PipelineServiceSource):
        self, pipeline_details: DataBrickPipelineDetails
    ) -> Iterable[OMetaPipelineStatus]:
        try:
            if not pipeline_details.job_id:
                return

            for run in self.client.get_job_runs(job_id=pipeline_details.job_id) or []:
                run = DBRun(**run)
                task_status = [
@@ -241,7 +290,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
        except Exception as exc:
            yield Either(
                left=StackTraceError(
-                   name=pipeline_details.job_id,
                    name=pipeline_details.id,
                    error=f"Failed to yield pipeline status: {exc}",
                    stackTrace=traceback.format_exc(),
                )
@@ -300,6 +349,314 @@ class DatabrickspipelineSource(PipelineServiceSource):
            )
        return processed_column_lineage or []

    def _find_kafka_topic(self, topic_name: str) -> Optional[Topic]:
        """
        Find Kafka topic in OpenMetadata using smart discovery

        Strategy:
        1. If messagingServiceNames configured -> search only those (faster)
        2. Else -> search ALL messaging services using search API
        """
        # Strategy 1: Search configured services (fast path)
        try:
            topic_fqn = fqn.build(
                metadata=self.metadata,
                entity_type=Topic,
                service_name=None,
                topic_name=topic_name,
                skip_es_search=False,
            )
            topic = self.metadata.get_by_name(entity=Topic, fqn=topic_fqn)
            if topic:
                logger.debug(f"Found topic {topic_name}")
                return topic
        except Exception as exc:
            logger.debug(f"Could not find topic {topic_name}: {exc}")
        logger.debug(f"Topic {topic_name} not found")
        return None

    def _yield_kafka_lineage(
        self, pipeline_details: DataBrickPipelineDetails, pipeline_entity: Pipeline
    ) -> Iterable[Either[AddLineageRequest]]:
        """
        Extract and yield Kafka topic lineage from DLT pipeline source code
        Only processes DLT pipelines (with pipeline_id), not regular jobs
        Creates lineage: Kafka topic -> DLT table (with pipeline in lineageDetails)
        """
        try:
            # Only process DLT pipelines - check for pipeline_id
            # For pure DLT pipelines, pipeline_id is set directly
            pipeline_id = pipeline_details.pipeline_id

            # For jobs with DLT pipeline tasks, check settings
            if not pipeline_id and pipeline_details.settings:
                try:
                    tasks = pipeline_details.settings.tasks
                    logger.debug(
                        f"Checking for DLT pipeline in job {pipeline_details.job_id}: "
                        f"{len(tasks) if tasks else 0} tasks found"
                    )

                    if tasks:
                        for task in tasks:
                            logger.debug(
                                f"Task: {task.name}, has pipeline_task: {task.pipeline_task is not None}"
                            )
                            # Check for direct DLT pipeline task
                            if task.pipeline_task and task.pipeline_task.pipeline_id:
                                pipeline_id = task.pipeline_task.pipeline_id
                                logger.info(
                                    f"Found DLT pipeline_id from job task: {pipeline_id} for job {pipeline_details.job_id}"
                                )
                                break
                except Exception as exc:
                    logger.debug(f"Error checking for pipeline tasks: {exc}")
                    logger.debug(traceback.format_exc())

            # Only process if we have a DLT pipeline_id
            if not pipeline_id:
                logger.debug(
                    f"No DLT pipeline_id found for {pipeline_details.job_id or pipeline_details.pipeline_id}, skipping Kafka lineage"
                )
                return

            logger.info(f"Processing Kafka lineage for DLT pipeline: {pipeline_id}")

            # Get pipeline configuration and extract target catalog/schema
            target_catalog = None
            target_schema = None
            notebook_paths = []
            try:
                pipeline_config = self.client.get_pipeline_details(pipeline_id)
                if not pipeline_config:
                    logger.debug(f"Could not fetch pipeline config for {pipeline_id}")
                    return

                # Extract spec for detailed configuration
                spec = pipeline_config.get("spec", {})
                logger.info(
                    f"Pipeline spec keys: {list(spec.keys()) if spec else 'None'}"
                )

                # Extract target catalog and schema for DLT tables
                target_catalog = spec.get("catalog") if spec else None
                # Schema can be in 'target' or 'schema' field
                target_schema = (
                    spec.get("target") or spec.get("schema") if spec else None
                )
                logger.debug(
                    f"DLT pipeline target: catalog={target_catalog}, schema={target_schema}"
                )

                # Extract notebook/file paths from libraries in spec
                notebook_paths = []
                if spec and "libraries" in spec:
                    libraries = spec["libraries"]
                    logger.info(f"Found {len(libraries)} libraries in spec")
                    for lib in libraries:
                        # Library can be dict or have different structures
                        if isinstance(lib, dict):
                            # Check for notebook path
                            if "notebook" in lib and lib["notebook"]:
                                notebook = lib["notebook"]
                                if isinstance(notebook, dict):
                                    path = notebook.get("path")
                                else:
                                    path = notebook
                                if path:
                                    notebook_paths.append(path)
                                    logger.info(f"Found notebook in library: {path}")
                            # Check for glob pattern
                            elif "glob" in lib and lib["glob"]:
                                glob_pattern = lib["glob"]
                                if isinstance(glob_pattern, dict):
                                    include_pattern = glob_pattern.get("include")
                                    if include_pattern:
                                        # Convert glob pattern to directory path
                                        # e.g., "/path/**" -> "/path/"
                                        base_path = include_pattern.replace(
                                            "/**", "/"
                                        ).replace("**", "")
                                        notebook_paths.append(base_path)
                                        logger.info(
                                            f"Found glob pattern, using base path: {base_path}"
                                        )

                # Also check for source path in spec configuration
                if not notebook_paths and spec:
                    source_path = None

                    # Check spec.configuration for source path
                    if "configuration" in spec:
                        config = spec["configuration"]
                        source_path = config.get("source_path") or config.get("source")

                    # Check development settings
                    if not source_path and "development" in spec:
                        source_path = spec["development"].get("source_path")

                    if source_path:
                        logger.info(
                            f"Found source_path in pipeline spec: {source_path}"
                        )
                        notebook_paths.append(source_path)

                logger.debug(
                    f"Found {len(notebook_paths)} notebook paths for pipeline {pipeline_id}"
                )
            except Exception as exc:
                logger.warning(
                    f"Failed to fetch pipeline config for {pipeline_id}: {exc}"
                )
                return

            if not notebook_paths:
                logger.debug(f"No notebook paths found for pipeline {pipeline_id}")
                return

            # Expand directories to individual notebook files
            expanded_paths = []
            for path in notebook_paths:
                # If path ends with /, it's a directory - list all notebooks in it
                if path.endswith("/"):
                    try:
                        # List workspace directory to get all notebooks
                        objects = self.client.list_workspace_objects(path)
                        if objects:
                            for obj in objects:
                                obj_type = obj.get("object_type")
                                if obj_type in ("NOTEBOOK", "FILE"):
                                    notebook_path = obj.get("path")
                                    if notebook_path:
                                        expanded_paths.append(notebook_path)
                                        logger.info(
                                            f"Found {obj_type.lower()} in directory: {notebook_path}"
                                        )
                        if not expanded_paths:
                            logger.debug(f"No notebooks found in directory {path}")
                    except Exception as exc:
                        logger.debug(f"Could not list directory {path}: {exc}")
                else:
                    expanded_paths.append(path)

            logger.info(
                f"Processing {len(expanded_paths)} notebook(s) for pipeline {pipeline_id}"
            )

            # Process each notebook to extract Kafka sources and DLT tables
            for lib_path in expanded_paths:
                try:
                    source_code = self.client.export_notebook_source(lib_path)
                    if not source_code:
                        logger.debug(f"Could not export source for {lib_path}")
                        continue

                    # Extract Kafka topics
                    kafka_sources = extract_kafka_sources(source_code)
                    if kafka_sources:
                        topics_found = [t for ks in kafka_sources for t in ks.topics]
                        logger.info(
                            f"Found {len(kafka_sources)} Kafka sources with topics {topics_found} in {lib_path}"
                        )
                    else:
                        logger.debug(f"No Kafka sources found in {lib_path}")

                    # Extract DLT table names
                    dlt_table_names = extract_dlt_table_names(source_code)
                    if dlt_table_names:
                        logger.info(
                            f"Found {len(dlt_table_names)} DLT tables in {lib_path}: {dlt_table_names}"
                        )
                    else:
                        logger.debug(f"No DLT tables found in {lib_path}")

                    if not dlt_table_names or not kafka_sources:
                        logger.debug(
                            f"Skipping Kafka lineage for {lib_path} - need both Kafka sources and DLT tables"
                        )
                        continue

                    # Create lineage for each Kafka topic -> DLT table
                    for kafka_config in kafka_sources:
                        for topic_name in kafka_config.topics:
                            try:
                                # Use smart discovery to find topic
                                kafka_topic = self._find_kafka_topic(topic_name)

                                if not kafka_topic:
                                    logger.debug(
                                        f"Kafka topic {topic_name} not found in any messaging service"
                                    )
                                    continue

                                # Create lineage to each DLT table in this notebook
                                for table_name in dlt_table_names:
                                    # Build table FQN: catalog.schema.table
                                    for (
                                        dbservicename
                                    ) in self.get_db_service_names() or ["*"]:
                                        target_table_fqn = fqn.build(
                                            metadata=self.metadata,
                                            entity_type=Table,
                                            table_name=table_name,
                                            database_name=target_catalog,
                                            schema_name=target_schema,
                                            service_name=dbservicename,
                                        )

                                        target_table_entity = self.metadata.get_by_name(
                                            entity=Table, fqn=target_table_fqn
                                        )

                                        if target_table_entity:
                                            logger.info(
                                                f"Creating Kafka lineage: {topic_name} -> {target_catalog}.{target_schema}.{table_name} (via pipeline {pipeline_id})"
                                            )

                                            yield Either(
                                                right=AddLineageRequest(
                                                    edge=EntitiesEdge(
                                                        fromEntity=EntityReference(
                                                            id=kafka_topic.id,
                                                            type="topic",
                                                        ),
                                                        toEntity=EntityReference(
                                                            id=target_table_entity.id.root,
                                                            type="table",
                                                        ),
                                                        lineageDetails=LineageDetails(
                                                            pipeline=EntityReference(
                                                                id=pipeline_entity.id.root,
                                                                type="pipeline",
                                                            ),
                                                            source=LineageSource.PipelineLineage,
                                                        ),
                                                    )
                                                )
                                            )
                                            break
                                        else:
                                            logger.debug(
                                                f"Target table not found in OpenMetadata: {target_table_fqn}"
                                            )

                            except Exception as exc:
                                logger.warning(
                                    f"Failed to process topic {topic_name}: {exc}"
                                )
                                continue
                except Exception as exc:
                    logger.warning(
                        f"Failed to process library {lib_path}: {exc}. Continuing with next library."
                    )
                    continue

        except Exception as exc:
            logger.error(
                f"Unexpected error in Kafka lineage extraction for job {pipeline_details.job_id}: {exc}"
            )
            logger.debug(traceback.format_exc())

    def yield_pipeline_lineage_details(
        self, pipeline_details: DataBrickPipelineDetails
    ) -> Iterable[Either[AddLineageRequest]]:
@@ -315,6 +672,13 @@ class DatabrickspipelineSource(PipelineServiceSource):
                entity=Pipeline, fqn=pipeline_fqn
            )

            # Extract Kafka topic lineage from source code
            # Works automatically - no configuration required!
            yield from self._yield_kafka_lineage(pipeline_details, pipeline_entity)

            if not pipeline_details.job_id:
                return

            table_lineage_list = self.client.get_table_lineage(
                job_id=pipeline_details.job_id
            )
@@ -409,7 +773,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
        except Exception as exc:
            yield Either(
                left=StackTraceError(
-                   name=pipeline_details.job_id,
                    name=pipeline_details.id,
                    error=f"Wild error ingesting pipeline lineage {pipeline_details} - {exc}",
                    stackTrace=traceback.format_exc(),
                )
@@ -13,7 +13,7 @@
Databricks pipeline Source Model module
"""

-from typing import List, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

@@ -27,6 +27,20 @@ class DependentTask(BaseModel):
    name: Optional[str] = Field(None, alias="task_key")


class PipelineTask(BaseModel):
    pipeline_id: Optional[str] = None
    full_refresh: Optional[bool] = None


class DBJobTask(BaseModel):
    name: Optional[str] = Field(None, alias="task_key")
    description: Optional[str] = None
    depends_on: Optional[List[DependentTask]] = None
    pipeline_task: Optional[PipelineTask] = None
    notebook_task: Optional[Dict[str, Any]] = None
    spark_python_task: Optional[Dict[str, Any]] = None


class DBTasks(BaseModel):
    name: Optional[str] = Field(None, alias="task_key")
    description: Optional[str] = None
@@ -41,13 +55,21 @@ class DBSettings(BaseModel):
    description: Optional[str] = None
    schedule: Optional[DBRunSchedule] = None
    task_type: Optional[str] = Field(None, alias="format")
    tasks: Optional[List[DBJobTask]] = None


class DataBrickPipelineDetails(BaseModel):
-   job_id: int
    job_id: Optional[int] = None
    pipeline_id: Optional[str] = None
    creator_user_name: Optional[str] = None
    settings: Optional[DBSettings] = None
-   created_time: int
    created_time: Optional[int] = None
    name: Optional[str] = None
    pipeline_type: Optional[str] = None

    @property
    def id(self) -> str:
        return str(self.pipeline_id) if self.pipeline_id else str(self.job_id)


class DBRunState(BaseModel):
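
The loosened model means the same DataBrickPipelineDetails class now accepts both a Jobs API payload and a DLT Pipelines API payload, and the new id property returns whichever identifier is present. A small, hypothetical illustration:

from metadata.ingestion.source.pipeline.databrickspipeline.models import (
    DataBrickPipelineDetails,
)

# Jobs API style payload (trimmed)
job = DataBrickPipelineDetails(
    job_id=123, created_time=1700000000, settings={"name": "nightly-job"}
)
# DLT Pipelines API style payload (trimmed)
dlt_pipeline = DataBrickPipelineDetails(pipeline_id="abc-123", name="kafka_bronze")

print(job.id)           # "123"
print(dlt_pipeline.id)  # "abc-123"
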
@@ -0,0 +1,380 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for Databricks Kafka parser
"""

import unittest

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_kafka_sources,
    get_pipeline_libraries,
)


class TestKafkaParser(unittest.TestCase):
    """Test cases for Kafka configuration parsing"""

    def test_basic_kafka_readstream(self):
        """Test basic Kafka readStream pattern"""
        source_code = """
        df = spark.readStream \\
            .format("kafka") \\
            .option("kafka.bootstrap.servers", "broker1:9092") \\
            .option("subscribe", "events_topic") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].bootstrap_servers, "broker1:9092")
        self.assertEqual(configs[0].topics, ["events_topic"])

    def test_multiple_topics(self):
        """Test comma-separated topics"""
        source_code = """
        spark.readStream.format("kafka") \\
            .option("subscribe", "topic1,topic2,topic3") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])

    def test_topics_option(self):
        """Test 'topics' option instead of 'subscribe'"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("topics", "single_topic") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["single_topic"])

    def test_group_id_prefix(self):
        """Test groupIdPrefix extraction"""
        source_code = """
        spark.readStream.format("kafka") \\
            .option("kafka.bootstrap.servers", "localhost:9092") \\
            .option("subscribe", "test_topic") \\
            .option("groupIdPrefix", "dlt-pipeline-123") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].group_id_prefix, "dlt-pipeline-123")

    def test_multiple_kafka_sources(self):
        """Test multiple Kafka sources in same file"""
        source_code = """
        # First stream
        df1 = spark.readStream.format("kafka") \\
            .option("subscribe", "topic_a") \\
            .load()

        # Second stream
        df2 = spark.readStream.format("kafka") \\
            .option("subscribe", "topic_b") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 2)
        topics = [c.topics[0] for c in configs]
        self.assertIn("topic_a", topics)
        self.assertIn("topic_b", topics)

    def test_single_quotes(self):
        """Test single quotes in options"""
        source_code = """
        df = spark.readStream.format('kafka') \\
            .option('kafka.bootstrap.servers', 'broker:9092') \\
            .option('subscribe', 'my_topic') \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].bootstrap_servers, "broker:9092")
        self.assertEqual(configs[0].topics, ["my_topic"])

    def test_mixed_quotes(self):
        """Test mixed single and double quotes"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option('subscribe', "topic_mixed") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic_mixed"])

    def test_compact_format(self):
        """Test compact single-line format"""
        source_code = """
        df = spark.readStream.format("kafka").option("subscribe", "compact_topic").load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["compact_topic"])

    def test_no_kafka_sources(self):
        """Test code with no Kafka sources"""
        source_code = """
        df = spark.read.parquet("/data/path")
        df.write.format("delta").save("/output")
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 0)

    def test_partial_kafka_config(self):
        """Test Kafka source with only topics (no brokers)"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", "topic_only") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertIsNone(configs[0].bootstrap_servers)
        self.assertEqual(configs[0].topics, ["topic_only"])

    def test_malformed_kafka_incomplete(self):
        """Test incomplete Kafka configuration doesn't crash"""
        source_code = """
        df = spark.readStream.format("kafka")
        # No .load() - malformed
        """
        configs = extract_kafka_sources(source_code)
        # Should return empty list, not crash
        self.assertEqual(len(configs), 0)

    def test_special_characters_in_topic(self):
        """Test topics with special characters"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", "topic-with-dashes_and_underscores.dots") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic-with-dashes_and_underscores.dots"])

    def test_whitespace_variations(self):
        """Test various whitespace patterns"""
        source_code = """
        df=spark.readStream.format( "kafka" ).option( "subscribe" , "topic" ).load( )
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic"])

    def test_case_insensitive_format(self):
        """Test case insensitive Kafka format"""
        source_code = """
        df = spark.readStream.format("KAFKA") \\
            .option("subscribe", "topic_upper") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic_upper"])

    def test_dlt_decorator_pattern(self):
        """Test DLT table decorator pattern"""
        source_code = """
        import dlt

        @dlt.table
        def bronze_events():
            return spark.readStream \\
                .format("kafka") \\
                .option("kafka.bootstrap.servers", "kafka:9092") \\
                .option("subscribe", "raw_events") \\
                .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["raw_events"])

    def test_multiline_with_comments(self):
        """Test code with inline comments"""
        source_code = """
        df = (spark.readStream
            .format("kafka")  # Using Kafka source
            .option("kafka.bootstrap.servers", "broker:9092")  # Broker config
            .option("subscribe", "commented_topic")  # Topic name
            .load())  # Load the stream
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["commented_topic"])

    def test_empty_source_code(self):
        """Test empty source code"""
        configs = extract_kafka_sources("")
        self.assertEqual(len(configs), 0)

    def test_null_source_code(self):
        """Test None source code doesn't crash"""
        configs = extract_kafka_sources(None)
        self.assertEqual(len(configs), 0)

    def test_topics_with_whitespace(self):
        """Test topics with surrounding whitespace are trimmed"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", " topic1 , topic2 , topic3 ") \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])

    def test_variable_topic_reference(self):
        """Test Kafka config with variable reference for topic"""
        source_code = """
        TOPIC = "events_topic"
        df = spark.readStream.format("kafka") \\
            .option("subscribe", TOPIC) \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["events_topic"])

    def test_real_world_dlt_pattern(self):
        """Test real-world DLT pattern with variables"""
        source_code = """
        import dlt
        from pyspark.sql.functions import *

        TOPIC = "tracker-events"
        KAFKA_BROKER = spark.conf.get("KAFKA_SERVER")

        raw_kafka_events = (spark.readStream
            .format("kafka")
            .option("subscribe", TOPIC)
            .option("kafka.bootstrap.servers", KAFKA_BROKER)
            .option("startingOffsets", "earliest")
            .load()
        )

        @dlt.table(table_properties={"pipelines.reset.allowed":"false"})
        def kafka_bronze():
            return raw_kafka_events
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 1)
        self.assertEqual(configs[0].topics, ["tracker-events"])

    def test_multiple_variable_topics(self):
        """Test multiple topics defined as variables"""
        source_code = """
        TOPIC_A = "orders"
        TOPIC_B = "payments"

        df = spark.readStream.format("kafka") \\
            .option("subscribe", TOPIC_A) \\
            .load()

        df2 = spark.readStream.format("kafka") \\
            .option("topics", TOPIC_B) \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        self.assertEqual(len(configs), 2)
        topics = [c.topics[0] for c in configs]
        self.assertIn("orders", topics)
        self.assertIn("payments", topics)

    def test_variable_not_defined(self):
        """Test variable reference without definition"""
        source_code = """
        df = spark.readStream.format("kafka") \\
            .option("subscribe", UNDEFINED_TOPIC) \\
            .load()
        """
        configs = extract_kafka_sources(source_code)
        # The variable cannot be resolved and no brokers are set, so no config is returned
        self.assertEqual(len(configs), 0)


class TestPipelineLibraries(unittest.TestCase):
    """Test cases for pipeline library extraction"""

    def test_notebook_library(self):
        """Test notebook library extraction"""
        pipeline_config = {
            "libraries": [{"notebook": {"path": "/Workspace/dlt/bronze_pipeline"}}]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 1)
        self.assertEqual(libraries[0], "/Workspace/dlt/bronze_pipeline")

    def test_file_library(self):
        """Test file library extraction"""
        pipeline_config = {
            "libraries": [{"file": {"path": "/Workspace/scripts/etl.py"}}]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 1)
        self.assertEqual(libraries[0], "/Workspace/scripts/etl.py")

    def test_mixed_libraries(self):
        """Test mixed notebook and file libraries"""
        pipeline_config = {
            "libraries": [
                {"notebook": {"path": "/nb1"}},
                {"file": {"path": "/file1.py"}},
                {"notebook": {"path": "/nb2"}},
            ]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 3)
        self.assertIn("/nb1", libraries)
        self.assertIn("/file1.py", libraries)
        self.assertIn("/nb2", libraries)

    def test_empty_libraries(self):
        """Test empty libraries list"""
        pipeline_config = {"libraries": []}
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 0)

    def test_missing_libraries_key(self):
        """Test missing libraries key"""
        pipeline_config = {}
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 0)

    def test_library_with_no_path(self):
        """Test library entry with no path"""
        pipeline_config = {"libraries": [{"notebook": {}}, {"file": {}}]}
        libraries = get_pipeline_libraries(pipeline_config)
        # Should skip entries without paths
        self.assertEqual(len(libraries), 0)

    def test_unsupported_library_type(self):
        """Test unsupported library types are skipped"""
        pipeline_config = {
            "libraries": [
                {"jar": {"path": "/lib.jar"}},  # Not supported
                {"notebook": {"path": "/nb"}},  # Supported
                {"whl": {"path": "/wheel.whl"}},  # Not supported
            ]
        }
        libraries = get_pipeline_libraries(pipeline_config)
        self.assertEqual(len(libraries), 1)
        self.assertEqual(libraries[0], "/nb")


if __name__ == "__main__":
    unittest.main()
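
The parser tests are self-contained string-parsing checks (no Databricks workspace or OpenMetadata server needed), so they can be run directly; the dotted module path below is an assumption about where the file sits in the ingestion test tree:

import unittest

# Assumed test module path; adjust to the actual location under ingestion/tests.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.unit.topology.pipeline.test_databricks_kafka_parser"
)
unittest.TextTestRunner(verbosity=2).run(suite)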