Feat: Add Kafka lineage support in Databricks pipelines (#23813)

* Add DLT pipeline support

* Fix code style

* Add variable parsing

* Fix Kafka lineage

---------

Co-authored-by: Sriharsha Chintalapani <harsha@getcollate.io>
Mayur Singal 2025-10-09 20:12:08 +05:30 committed by GitHub
parent 509295ed39
commit 05f064787f
5 changed files with 1145 additions and 16 deletions
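
For orientation, the feature targets DLT notebooks of the shape sketched below (a hypothetical example; table name, topic, and broker address are illustrative and mirror the unit tests added in this commit). The new kafka_parser module pulls the topic and broker out of the readStream options (resolving simple variable references), the @dlt.table decorator supplies the target table name, and the pipeline source then emits Kafka topic -> DLT table lineage with the DLT pipeline attached to the edge.

import dlt

TOPIC = "tracker-events"  # resolved by the parser's variable lookup

@dlt.table(name="kafka_bronze")
def kafka_bronze():
    return (
        spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "broker:9092")
        .option("subscribe", TOPIC)
        .load()
    )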

View File

@ -11,6 +11,7 @@
"""
Client to interact with databricks apis
"""
import base64
import json
import traceback
from datetime import timedelta
@ -372,3 +373,122 @@ class DatabricksClient:
continue
self._job_column_lineage_executed = True
logger.debug("Table and column lineage caching completed.")
def get_pipeline_details(self, pipeline_id: str) -> Optional[dict]:
"""
Get DLT pipeline configuration including libraries and notebooks
"""
try:
url = f"{self.base_url}/pipelines/{pipeline_id}"
response = self.client.get(
url,
headers=self.headers,
timeout=self.api_timeout,
)
if response.status_code == 200:
return response.json()
logger.warning(
f"Failed to get pipeline details for {pipeline_id}: {response.status_code}"
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Error getting pipeline details for {pipeline_id}: {exc}")
return None
def list_pipelines(self) -> Iterable[dict]:
"""
List all DLT (Delta Live Tables) pipelines in the workspace
Uses the Pipelines API (/api/2.0/pipelines)
"""
try:
url = f"{self.base_url}/pipelines"
params = {"max_results": PAGE_SIZE}
response = self.client.get(
url,
params=params,
headers=self.headers,
timeout=self.api_timeout,
)
if response.status_code == 200:
data = response.json()
pipelines = data.get("statuses", [])
logger.info(f"Found {len(pipelines)} DLT pipelines")
yield from pipelines
# Handle pagination if there's a next_page_token
while data.get("next_page_token"):
params["page_token"] = data["next_page_token"]
response = self.client.get(
url,
params=params,
headers=self.headers,
timeout=self.api_timeout,
)
if response.status_code == 200:
data = response.json()
yield from data.get("statuses", [])
else:
break
else:
logger.warning(
f"Failed to list pipelines: {response.status_code} - {response.text}"
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Error listing DLT pipelines: {exc}")
def list_workspace_objects(self, path: str) -> List[dict]:
"""
List objects in a Databricks workspace directory
"""
try:
url = f"{self.base_url}/workspace/list"
params = {"path": path}
response = self.client.get(
url,
params=params,
headers=self.headers,
timeout=self.api_timeout,
)
if response.status_code == 200:
return response.json().get("objects", [])
else:
logger.warning(
f"Failed to list workspace directory {path}: {response.text}"
)
return []
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Error listing workspace directory {path}: {exc}")
return []
def export_notebook_source(self, notebook_path: str) -> Optional[str]:
"""
Export notebook source code from Databricks workspace
"""
try:
url = f"{self.base_url}/workspace/export"
params = {"path": notebook_path, "format": "SOURCE"}
response = self.client.get(
url,
params=params,
headers=self.headers,
timeout=self.api_timeout,
)
if response.status_code == 200:
content = response.json().get("content")
if content:
return base64.b64decode(content).decode("utf-8")
logger.warning(
f"Failed to export notebook {notebook_path}: {response.status_code}"
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Error exporting notebook {notebook_path}: {exc}")
return None
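
Taken together, the new client methods support the discovery flow below: a minimal sketch assuming an already-configured DatabricksClient instance and dict-shaped library entries (the lineage code in the pipeline source drives this same sequence).

# assumes `client` is a configured DatabricksClient
for status in client.list_pipelines():
    pipeline_id = status.get("pipeline_id")
    config = client.get_pipeline_details(pipeline_id)
    if not config:
        continue
    for lib in (config.get("spec") or {}).get("libraries", []):
        notebook_path = (lib.get("notebook") or {}).get("path")
        if notebook_path:
            source = client.export_notebook_source(notebook_path)
            # `source` is the decoded notebook text, or None if the export failed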

View File

@ -0,0 +1,243 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Kafka configuration parser for Databricks DLT pipelines
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
# Compile regex patterns at module level for performance
KAFKA_STREAM_PATTERN = re.compile(
r'\.format\s*\(\s*["\']kafka["\']\s*\)(.*?)\.load\s*\(\s*\)',
re.DOTALL | re.IGNORECASE,
)
# Pattern to extract variable assignments like: TOPIC = "tracker-events"
VARIABLE_ASSIGNMENT_PATTERN = re.compile(
r'^\s*([A-Z_][A-Z0-9_]*)\s*=\s*["\']([^"\']+)["\']\s*$',
re.MULTILINE,
)
# Pattern to extract DLT table decorators: @dlt.table(name="table_name", ...)
DLT_TABLE_PATTERN = re.compile(
r'@dlt\.table\s*\(\s*(?:.*?name\s*=\s*["\']([^"\']+)["\'])?',
re.DOTALL | re.IGNORECASE,
)
@dataclass
class KafkaSourceConfig:
"""Model for Kafka source configuration extracted from DLT code"""
bootstrap_servers: Optional[str] = None
topics: List[str] = field(default_factory=list)
group_id_prefix: Optional[str] = None
def _extract_variables(source_code: str) -> dict:
"""
Extract variable assignments from source code
Examples:
TOPIC = "events"
KAFKA_BROKER = "localhost:9092"
Returns dict like: {"TOPIC": "events", "KAFKA_BROKER": "localhost:9092"}
"""
variables = {}
try:
for match in VARIABLE_ASSIGNMENT_PATTERN.finditer(source_code):
var_name = match.group(1)
var_value = match.group(2)
variables[var_name] = var_value
logger.debug(f"Found variable: {var_name} = {var_value}")
except Exception as exc:
logger.debug(f"Error extracting variables: {exc}")
return variables
def extract_kafka_sources(source_code: str) -> List[KafkaSourceConfig]:
"""
Extract Kafka topic configurations from DLT source code
Parses patterns like:
- spark.readStream.format("kafka").option("subscribe", "topic1,topic2")
- .option("kafka.bootstrap.servers", "broker:9092")
- .option("groupIdPrefix", "dlt-pipeline")
Also supports variable references:
- TOPIC = "events"
- .option("subscribe", TOPIC)
Returns empty list if parsing fails or no sources found
"""
kafka_configs = []
try:
if not source_code:
logger.debug("Empty or None source code provided")
return kafka_configs
# Extract variable assignments for resolution
variables = _extract_variables(source_code)
for match in KAFKA_STREAM_PATTERN.finditer(source_code):
try:
config_block = match.group(1)
bootstrap_servers = _extract_option(
config_block, r"kafka\.bootstrap\.servers", variables
)
subscribe_topics = _extract_option(
config_block, r"subscribe", variables
)
topics = _extract_option(config_block, r"topics", variables)
group_id_prefix = _extract_option(
config_block, r"groupIdPrefix", variables
)
topic_list = []
if subscribe_topics:
topic_list = [
t.strip() for t in subscribe_topics.split(",") if t.strip()
]
elif topics:
topic_list = [t.strip() for t in topics.split(",") if t.strip()]
if bootstrap_servers or topic_list:
kafka_config = KafkaSourceConfig(
bootstrap_servers=bootstrap_servers,
topics=topic_list,
group_id_prefix=group_id_prefix,
)
kafka_configs.append(kafka_config)
logger.debug(
f"Extracted Kafka config: brokers={bootstrap_servers}, "
f"topics={topic_list}, group_prefix={group_id_prefix}"
)
except Exception as exc:
logger.warning(f"Failed to parse individual Kafka config block: {exc}")
continue
except Exception as exc:
logger.warning(f"Error parsing Kafka sources from code: {exc}")
return kafka_configs
def _extract_option(
config_block: str, option_name: str, variables: Optional[dict] = None
) -> Optional[str]:
"""
Extract a single option value from Kafka configuration block
Supports both string literals and variable references
Safely handles any parsing errors
"""
if variables is None:
variables = {}
try:
# Try matching quoted string literal: .option("subscribe", "topic")
pattern_literal = (
rf'\.option\s*\(\s*["\']({option_name})["\']\s*,\s*["\']([^"\']+)["\']\s*\)'
)
match = re.search(pattern_literal, config_block, re.IGNORECASE)
if match:
return match.group(2)
# Try matching variable reference: .option("subscribe", TOPIC)
pattern_variable = (
rf'\.option\s*\(\s*["\']({option_name})["\']\s*,\s*([A-Z_][A-Z0-9_]*)\s*\)'
)
match = re.search(pattern_variable, config_block, re.IGNORECASE)
if match:
var_name = match.group(2)
# Resolve variable
if var_name in variables:
logger.debug(
f"Resolved variable {var_name} = {variables[var_name]} for option {option_name}"
)
return variables[var_name]
else:
logger.debug(
f"Variable {var_name} referenced but not found in source code"
)
except Exception as exc:
logger.debug(f"Failed to extract option {option_name}: {exc}")
return None
def extract_dlt_table_names(source_code: str) -> List[str]:
"""
Extract DLT table names from @dlt.table decorators
Parses patterns like:
- @dlt.table(name="user_events_bronze_pl", ...)
- @dlt.table(comment="...", name="my_table")
Returns list of table names found in decorators
"""
table_names = []
try:
if not source_code:
logger.debug("Empty or None source code provided")
return table_names
for match in DLT_TABLE_PATTERN.finditer(source_code):
table_name = match.group(1)
if table_name:
table_names.append(table_name)
logger.debug(f"Found DLT table: {table_name}")
except Exception as exc:
logger.warning(f"Error parsing DLT table names from code: {exc}")
return table_names
def get_pipeline_libraries(pipeline_config: dict, client=None) -> List[str]:
"""
Extract notebook and file paths from pipeline configuration
Safely handles missing or malformed configuration
"""
libraries = []
try:
if not pipeline_config:
return libraries
for lib in pipeline_config.get("libraries", []):
try:
if "notebook" in lib:
notebook_path = lib["notebook"].get("path")
if notebook_path:
libraries.append(notebook_path)
elif "file" in lib:
file_path = lib["file"].get("path")
if file_path:
libraries.append(file_path)
except Exception as exc:
logger.debug(f"Failed to process library entry {lib}: {exc}")
continue
except Exception as exc:
logger.warning(f"Error extracting pipeline libraries: {exc}")
return libraries
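
As a quick illustration of the parser API (a sketch exercising the same kind of input covered by the unit tests further down):

from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
    extract_dlt_table_names,
    extract_kafka_sources,
    get_pipeline_libraries,
)

source = '''
TOPIC = "orders"

@dlt.table(name="orders_bronze")
def orders_bronze():
    return spark.readStream.format("kafka") \\
        .option("kafka.bootstrap.servers", "broker:9092") \\
        .option("subscribe", TOPIC) \\
        .load()
'''

configs = extract_kafka_sources(source)   # one KafkaSourceConfig: brokers "broker:9092", topics ["orders"]
tables = extract_dlt_table_names(source)  # ["orders_bronze"]
libs = get_pipeline_libraries({"libraries": [{"notebook": {"path": "/Workspace/dlt/orders"}}]})  # ["/Workspace/dlt/orders"]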

View File

@ -28,6 +28,7 @@ from metadata.generated.schema.entity.data.pipeline import (
TaskStatus,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.data.topic import Topic
from metadata.generated.schema.entity.services.connections.pipeline.databricksPipelineConnection import (
DatabricksPipelineConnection,
)
@ -56,6 +57,10 @@ from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.models.pipeline_status import OMetaPipelineStatus
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
extract_dlt_table_names,
extract_kafka_sources,
)
from metadata.ingestion.source.pipeline.databrickspipeline.models import (
DataBrickPipelineDetails,
DBRun,
@ -105,14 +110,30 @@ class DatabrickspipelineSource(PipelineServiceSource):
yield DataBrickPipelineDetails(**workflow)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(f"Failed to get pipeline list due to : {exc}")
logger.error(f"Failed to get jobs list due to : {exc}")
# Fetch DLT pipelines directly (new)
try:
for pipeline in self.client.list_pipelines() or []:
try:
yield DataBrickPipelineDetails(**pipeline)
except Exception as exc:
logger.debug(f"Error creating DLT pipeline details: {exc}")
logger.debug(traceback.format_exc())
continue
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Failed to get DLT pipelines list due to : {exc}")
return None
def get_pipeline_name(
self, pipeline_details: DataBrickPipelineDetails
) -> Optional[str]:
try:
return pipeline_details.settings.name
if pipeline_details.pipeline_id:
return pipeline_details.name
return pipeline_details.settings.name if pipeline_details.settings else None
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(f"Failed to get pipeline name due to : {exc}")
@ -124,17 +145,35 @@ class DatabrickspipelineSource(PipelineServiceSource):
) -> Iterable[Either[CreatePipelineRequest]]:
"""Method to Get Pipeline Entity"""
try:
description = pipeline_details.settings.description
if pipeline_details.pipeline_id:
description = None
display_name = pipeline_details.name
entity_name = str(pipeline_details.pipeline_id)
schedule_interval = None
else:
description = (
pipeline_details.settings.description
if pipeline_details.settings
else None
)
display_name = (
pipeline_details.settings.name
if pipeline_details.settings
else None
)
entity_name = str(pipeline_details.job_id)
schedule_interval = (
str(pipeline_details.settings.schedule.cron)
if pipeline_details.settings and pipeline_details.settings.schedule
else None
)
pipeline_request = CreatePipelineRequest(
name=EntityName(str(pipeline_details.job_id)),
displayName=pipeline_details.settings.name,
name=EntityName(entity_name),
displayName=display_name,
description=Markdown(description) if description else None,
tasks=self.get_tasks(pipeline_details),
scheduleInterval=(
str(pipeline_details.settings.schedule.cron)
if pipeline_details.settings.schedule
else None
),
scheduleInterval=schedule_interval,
service=FullyQualifiedEntityName(self.context.get().pipeline_service),
)
yield Either(right=pipeline_request)
@ -170,6 +209,9 @@ class DatabrickspipelineSource(PipelineServiceSource):
def get_tasks(self, pipeline_details: DataBrickPipelineDetails) -> List[Task]:
try:
if not pipeline_details.job_id:
return []
task_list = []
for run in self.client.get_job_runs(job_id=pipeline_details.job_id) or []:
run = DBRun(**run)
@ -177,7 +219,11 @@ class DatabrickspipelineSource(PipelineServiceSource):
[
Task(
name=str(task.name),
taskType=pipeline_details.settings.task_type,
taskType=(
pipeline_details.settings.task_type
if pipeline_details.settings
else None
),
sourceUrl=(
SourceUrl(run.run_page_url)
if run.run_page_url
@ -204,6 +250,9 @@ class DatabrickspipelineSource(PipelineServiceSource):
self, pipeline_details: DataBrickPipelineDetails
) -> Iterable[OMetaPipelineStatus]:
try:
if not pipeline_details.job_id:
return
for run in self.client.get_job_runs(job_id=pipeline_details.job_id) or []:
run = DBRun(**run)
task_status = [
@ -241,7 +290,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
except Exception as exc:
yield Either(
left=StackTraceError(
name=pipeline_details.job_id,
name=pipeline_details.id,
error=f"Failed to yield pipeline status: {exc}",
stackTrace=traceback.format_exc(),
)
@ -300,6 +349,314 @@ class DatabrickspipelineSource(PipelineServiceSource):
)
return processed_column_lineage or []
def _find_kafka_topic(self, topic_name: str) -> Optional[Topic]:
"""
Find Kafka topic in OpenMetadata using smart discovery
Strategy:
1. If messagingServiceNames configured -> search only those (faster)
2. Else -> search ALL messaging services using search API
"""
# Strategy 1: Search configured services (fast path)
try:
topic_fqn = fqn.build(
metadata=self.metadata,
entity_type=Topic,
service_name=None,
topic_name=topic_name,
skip_es_search=False,
)
topic = self.metadata.get_by_name(entity=Topic, fqn=topic_fqn)
if topic:
logger.debug(f"Found topic {topic_name}")
return topic
except Exception as exc:
logger.debug(f"Could not find topic {topic_name}: {exc}")
logger.debug(f"Topic {topic_name} not found")
return None
def _yield_kafka_lineage(
self, pipeline_details: DataBrickPipelineDetails, pipeline_entity: Pipeline
) -> Iterable[Either[AddLineageRequest]]:
"""
Extract and yield Kafka topic lineage from DLT pipeline source code
Only processes DLT pipelines (with pipeline_id), not regular jobs
Creates lineage: Kafka topic -> DLT table (with pipeline in lineageDetails)
"""
try:
# Only process DLT pipelines - check for pipeline_id
# For pure DLT pipelines, pipeline_id is set directly
pipeline_id = pipeline_details.pipeline_id
# For jobs with DLT pipeline tasks, check settings
if not pipeline_id and pipeline_details.settings:
try:
tasks = pipeline_details.settings.tasks
logger.debug(
f"Checking for DLT pipeline in job {pipeline_details.job_id}: "
f"{len(tasks) if tasks else 0} tasks found"
)
if tasks:
for task in tasks:
logger.debug(
f"Task: {task.name}, has pipeline_task: {task.pipeline_task is not None}"
)
# Check for direct DLT pipeline task
if task.pipeline_task and task.pipeline_task.pipeline_id:
pipeline_id = task.pipeline_task.pipeline_id
logger.info(
f"Found DLT pipeline_id from job task: {pipeline_id} for job {pipeline_details.job_id}"
)
break
except Exception as exc:
logger.debug(f"Error checking for pipeline tasks: {exc}")
logger.debug(traceback.format_exc())
# Only process if we have a DLT pipeline_id
if not pipeline_id:
logger.debug(
f"No DLT pipeline_id found for {pipeline_details.job_id or pipeline_details.pipeline_id}, skipping Kafka lineage"
)
return
logger.info(f"Processing Kafka lineage for DLT pipeline: {pipeline_id}")
# Get pipeline configuration and extract target catalog/schema
target_catalog = None
target_schema = None
notebook_paths = []
try:
pipeline_config = self.client.get_pipeline_details(pipeline_id)
if not pipeline_config:
logger.debug(f"Could not fetch pipeline config for {pipeline_id}")
return
# Extract spec for detailed configuration
spec = pipeline_config.get("spec", {})
logger.info(
f"Pipeline spec keys: {list(spec.keys()) if spec else 'None'}"
)
# Extract target catalog and schema for DLT tables
target_catalog = spec.get("catalog") if spec else None
# Schema can be in 'target' or 'schema' field
target_schema = (
spec.get("target") or spec.get("schema") if spec else None
)
logger.debug(
f"DLT pipeline target: catalog={target_catalog}, schema={target_schema}"
)
# Extract notebook/file paths from libraries in spec
notebook_paths = []
if spec and "libraries" in spec:
libraries = spec["libraries"]
logger.info(f"Found {len(libraries)} libraries in spec")
for lib in libraries:
# Library entries can have different shapes (notebook, glob, ...) and may not be dicts
if isinstance(lib, dict):
# Check for notebook path
if "notebook" in lib and lib["notebook"]:
notebook = lib["notebook"]
if isinstance(notebook, dict):
path = notebook.get("path")
else:
path = notebook
if path:
notebook_paths.append(path)
logger.info(f"Found notebook in library: {path}")
# Check for glob pattern
elif "glob" in lib and lib["glob"]:
glob_pattern = lib["glob"]
if isinstance(glob_pattern, dict):
include_pattern = glob_pattern.get("include")
if include_pattern:
# Convert glob pattern to directory path
# e.g., "/path/**" -> "/path/"
base_path = include_pattern.replace(
"/**", "/"
).replace("**", "")
notebook_paths.append(base_path)
logger.info(
f"Found glob pattern, using base path: {base_path}"
)
# Also check for source path in spec configuration
if not notebook_paths and spec:
source_path = None
# Check spec.configuration for source path
if "configuration" in spec:
config = spec["configuration"]
source_path = config.get("source_path") or config.get("source")
# Check development settings
if not source_path and "development" in spec:
source_path = spec["development"].get("source_path")
if source_path:
logger.info(
f"Found source_path in pipeline spec: {source_path}"
)
notebook_paths.append(source_path)
logger.debug(
f"Found {len(notebook_paths)} notebook paths for pipeline {pipeline_id}"
)
except Exception as exc:
logger.warning(
f"Failed to fetch pipeline config for {pipeline_id}: {exc}"
)
return
if not notebook_paths:
logger.debug(f"No notebook paths found for pipeline {pipeline_id}")
return
# Expand directories to individual notebook files
expanded_paths = []
for path in notebook_paths:
# If path ends with /, it's a directory - list all notebooks in it
if path.endswith("/"):
try:
# List workspace directory to get all notebooks
objects = self.client.list_workspace_objects(path)
if objects:
for obj in objects:
obj_type = obj.get("object_type")
if obj_type in ("NOTEBOOK", "FILE"):
notebook_path = obj.get("path")
if notebook_path:
expanded_paths.append(notebook_path)
logger.info(
f"Found {obj_type.lower()} in directory: {notebook_path}"
)
if not expanded_paths:
logger.debug(f"No notebooks found in directory {path}")
except Exception as exc:
logger.debug(f"Could not list directory {path}: {exc}")
else:
expanded_paths.append(path)
logger.info(
f"Processing {len(expanded_paths)} notebook(s) for pipeline {pipeline_id}"
)
# Process each notebook to extract Kafka sources and DLT tables
for lib_path in expanded_paths:
try:
source_code = self.client.export_notebook_source(lib_path)
if not source_code:
logger.debug(f"Could not export source for {lib_path}")
continue
# Extract Kafka topics
kafka_sources = extract_kafka_sources(source_code)
if kafka_sources:
topics_found = [t for ks in kafka_sources for t in ks.topics]
logger.info(
f"Found {len(kafka_sources)} Kafka sources with topics {topics_found} in {lib_path}"
)
else:
logger.debug(f"No Kafka sources found in {lib_path}")
# Extract DLT table names
dlt_table_names = extract_dlt_table_names(source_code)
if dlt_table_names:
logger.info(
f"Found {len(dlt_table_names)} DLT tables in {lib_path}: {dlt_table_names}"
)
else:
logger.debug(f"No DLT tables found in {lib_path}")
if not dlt_table_names or not kafka_sources:
logger.debug(
f"Skipping Kafka lineage for {lib_path} - need both Kafka sources and DLT tables"
)
continue
# Create lineage for each Kafka topic -> DLT table
for kafka_config in kafka_sources:
for topic_name in kafka_config.topics:
try:
# Use smart discovery to find topic
kafka_topic = self._find_kafka_topic(topic_name)
if not kafka_topic:
logger.debug(
f"Kafka topic {topic_name} not found in any messaging service"
)
continue
# Create lineage to each DLT table in this notebook
for table_name in dlt_table_names:
# Build table FQN: catalog.schema.table
for (
dbservicename
) in self.get_db_service_names() or ["*"]:
target_table_fqn = fqn.build(
metadata=self.metadata,
entity_type=Table,
table_name=table_name,
database_name=target_catalog,
schema_name=target_schema,
service_name=dbservicename,
)
target_table_entity = self.metadata.get_by_name(
entity=Table, fqn=target_table_fqn
)
if target_table_entity:
logger.info(
f"Creating Kafka lineage: {topic_name} -> {target_catalog}.{target_schema}.{table_name} (via pipeline {pipeline_id})"
)
yield Either(
right=AddLineageRequest(
edge=EntitiesEdge(
fromEntity=EntityReference(
id=kafka_topic.id,
type="topic",
),
toEntity=EntityReference(
id=target_table_entity.id.root,
type="table",
),
lineageDetails=LineageDetails(
pipeline=EntityReference(
id=pipeline_entity.id.root,
type="pipeline",
),
source=LineageSource.PipelineLineage,
),
)
)
)
break
else:
logger.debug(
f"Target table not found in OpenMetadata: {target_table_fqn}"
)
except Exception as exc:
logger.warning(
f"Failed to process topic {topic_name}: {exc}"
)
continue
except Exception as exc:
logger.warning(
f"Failed to process library {lib_path}: {exc}. Continuing with next library."
)
continue
except Exception as exc:
logger.error(
f"Unexpected error in Kafka lineage extraction for job {pipeline_details.job_id}: {exc}"
)
logger.debug(traceback.format_exc())
def yield_pipeline_lineage_details(
self, pipeline_details: DataBrickPipelineDetails
) -> Iterable[Either[AddLineageRequest]]:
@ -315,6 +672,13 @@ class DatabrickspipelineSource(PipelineServiceSource):
entity=Pipeline, fqn=pipeline_fqn
)
# Extract Kafka topic lineage from DLT notebook source code;
# no additional configuration is required
yield from self._yield_kafka_lineage(pipeline_details, pipeline_entity)
if not pipeline_details.job_id:
return
table_lineage_list = self.client.get_table_lineage(
job_id=pipeline_details.job_id
)
@ -409,7 +773,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
except Exception as exc:
yield Either(
left=StackTraceError(
name=pipeline_details.job_id,
name=pipeline_details.id,
error=f"Wild error ingesting pipeline lineage {pipeline_details} - {exc}",
stackTrace=traceback.format_exc(),
)
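
For reference, the spec navigation in _yield_kafka_lineage expects a pipeline configuration of roughly this shape (a hypothetical illustration: the keys are the ones the code reads, the values are made up):

pipeline_config = {
    "spec": {
        "catalog": "main",    # becomes target_catalog
        "target": "bronze",   # becomes target_schema ("schema" is also checked)
        "libraries": [
            {"notebook": {"path": "/Workspace/dlt/kafka_bronze"}},
            {"glob": {"include": "/Workspace/dlt/src/**"}},  # directories are expanded via list_workspace_objects
        ],
    },
}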

View File

@ -13,7 +13,7 @@
Databricks pipeline Source Model module
"""
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@ -27,6 +27,20 @@ class DependentTask(BaseModel):
name: Optional[str] = Field(None, alias="task_key")
class PipelineTask(BaseModel):
pipeline_id: Optional[str] = None
full_refresh: Optional[bool] = None
class DBJobTask(BaseModel):
name: Optional[str] = Field(None, alias="task_key")
description: Optional[str] = None
depends_on: Optional[List[DependentTask]] = None
pipeline_task: Optional[PipelineTask] = None
notebook_task: Optional[Dict[str, Any]] = None
spark_python_task: Optional[Dict[str, Any]] = None
class DBTasks(BaseModel):
name: Optional[str] = Field(None, alias="task_key")
description: Optional[str] = None
@ -41,13 +55,21 @@ class DBSettings(BaseModel):
description: Optional[str] = None
schedule: Optional[DBRunSchedule] = None
task_type: Optional[str] = Field(None, alias="format")
tasks: Optional[List[DBJobTask]] = None
class DataBrickPipelineDetails(BaseModel):
job_id: int
job_id: Optional[int] = None
pipeline_id: Optional[str] = None
creator_user_name: Optional[str] = None
settings: Optional[DBSettings] = None
created_time: int
created_time: Optional[int] = None
name: Optional[str] = None
pipeline_type: Optional[str] = None
@property
def id(self) -> str:
return str(self.pipeline_id) if self.pipeline_id else str(self.job_id)
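
The new id property gives a single string identifier for both entity kinds, for example (a small sketch of the model behavior):

from metadata.ingestion.source.pipeline.databrickspipeline.models import (
    DataBrickPipelineDetails,
)

dlt_pipeline = DataBrickPipelineDetails(pipeline_id="abc-123", name="bronze_pipeline")
job = DataBrickPipelineDetails(job_id=42, created_time=1700000000000)

assert dlt_pipeline.id == "abc-123"  # DLT pipelines are keyed by pipeline_id
assert job.id == "42"                # classic jobs fall back to job_id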
class DBRunState(BaseModel):

View File

@ -0,0 +1,380 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for Databricks Kafka parser
"""
import unittest
from metadata.ingestion.source.pipeline.databrickspipeline.kafka_parser import (
extract_kafka_sources,
get_pipeline_libraries,
)
class TestKafkaParser(unittest.TestCase):
"""Test cases for Kafka configuration parsing"""
def test_basic_kafka_readstream(self):
"""Test basic Kafka readStream pattern"""
source_code = """
df = spark.readStream \\
.format("kafka") \\
.option("kafka.bootstrap.servers", "broker1:9092") \\
.option("subscribe", "events_topic") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].bootstrap_servers, "broker1:9092")
self.assertEqual(configs[0].topics, ["events_topic"])
def test_multiple_topics(self):
"""Test comma-separated topics"""
source_code = """
spark.readStream.format("kafka") \\
.option("subscribe", "topic1,topic2,topic3") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])
def test_topics_option(self):
"""Test 'topics' option instead of 'subscribe'"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("topics", "single_topic") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["single_topic"])
def test_group_id_prefix(self):
"""Test groupIdPrefix extraction"""
source_code = """
spark.readStream.format("kafka") \\
.option("kafka.bootstrap.servers", "localhost:9092") \\
.option("subscribe", "test_topic") \\
.option("groupIdPrefix", "dlt-pipeline-123") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].group_id_prefix, "dlt-pipeline-123")
def test_multiple_kafka_sources(self):
"""Test multiple Kafka sources in same file"""
source_code = """
# First stream
df1 = spark.readStream.format("kafka") \\
.option("subscribe", "topic_a") \\
.load()
# Second stream
df2 = spark.readStream.format("kafka") \\
.option("subscribe", "topic_b") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 2)
topics = [c.topics[0] for c in configs]
self.assertIn("topic_a", topics)
self.assertIn("topic_b", topics)
def test_single_quotes(self):
"""Test single quotes in options"""
source_code = """
df = spark.readStream.format('kafka') \\
.option('kafka.bootstrap.servers', 'broker:9092') \\
.option('subscribe', 'my_topic') \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].bootstrap_servers, "broker:9092")
self.assertEqual(configs[0].topics, ["my_topic"])
def test_mixed_quotes(self):
"""Test mixed single and double quotes"""
source_code = """
df = spark.readStream.format("kafka") \\
.option('subscribe', "topic_mixed") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic_mixed"])
def test_compact_format(self):
"""Test compact single-line format"""
source_code = """
df = spark.readStream.format("kafka").option("subscribe", "compact_topic").load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["compact_topic"])
def test_no_kafka_sources(self):
"""Test code with no Kafka sources"""
source_code = """
df = spark.read.parquet("/data/path")
df.write.format("delta").save("/output")
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 0)
def test_partial_kafka_config(self):
"""Test Kafka source with only topics (no brokers)"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", "topic_only") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertIsNone(configs[0].bootstrap_servers)
self.assertEqual(configs[0].topics, ["topic_only"])
def test_malformed_kafka_incomplete(self):
"""Test incomplete Kafka configuration doesn't crash"""
source_code = """
df = spark.readStream.format("kafka")
# No .load() - malformed
"""
configs = extract_kafka_sources(source_code)
# Should return empty list, not crash
self.assertEqual(len(configs), 0)
def test_special_characters_in_topic(self):
"""Test topics with special characters"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", "topic-with-dashes_and_underscores.dots") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic-with-dashes_and_underscores.dots"])
def test_whitespace_variations(self):
"""Test various whitespace patterns"""
source_code = """
df=spark.readStream.format( "kafka" ).option( "subscribe" , "topic" ).load( )
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic"])
def test_case_insensitive_format(self):
"""Test case insensitive Kafka format"""
source_code = """
df = spark.readStream.format("KAFKA") \\
.option("subscribe", "topic_upper") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic_upper"])
def test_dlt_decorator_pattern(self):
"""Test DLT table decorator pattern"""
source_code = """
import dlt
@dlt.table
def bronze_events():
return spark.readStream \\
.format("kafka") \\
.option("kafka.bootstrap.servers", "kafka:9092") \\
.option("subscribe", "raw_events") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["raw_events"])
def test_multiline_with_comments(self):
"""Test code with inline comments"""
source_code = """
df = (spark.readStream
.format("kafka") # Using Kafka source
.option("kafka.bootstrap.servers", "broker:9092") # Broker config
.option("subscribe", "commented_topic") # Topic name
.load()) # Load the stream
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["commented_topic"])
def test_empty_source_code(self):
"""Test empty source code"""
configs = extract_kafka_sources("")
self.assertEqual(len(configs), 0)
def test_null_source_code(self):
"""Test None source code doesn't crash"""
configs = extract_kafka_sources(None)
self.assertEqual(len(configs), 0)
def test_topics_with_whitespace(self):
"""Test topics with surrounding whitespace are trimmed"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", " topic1 , topic2 , topic3 ") \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["topic1", "topic2", "topic3"])
def test_variable_topic_reference(self):
"""Test Kafka config with variable reference for topic"""
source_code = """
TOPIC = "events_topic"
df = spark.readStream.format("kafka") \\
.option("subscribe", TOPIC) \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["events_topic"])
def test_real_world_dlt_pattern(self):
"""Test real-world DLT pattern with variables"""
source_code = """
import dlt
from pyspark.sql.functions import *
TOPIC = "tracker-events"
KAFKA_BROKER = spark.conf.get("KAFKA_SERVER")
raw_kafka_events = (spark.readStream
.format("kafka")
.option("subscribe", TOPIC)
.option("kafka.bootstrap.servers", KAFKA_BROKER)
.option("startingOffsets", "earliest")
.load()
)
@dlt.table(table_properties={"pipelines.reset.allowed":"false"})
def kafka_bronze():
return raw_kafka_events
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 1)
self.assertEqual(configs[0].topics, ["tracker-events"])
def test_multiple_variable_topics(self):
"""Test multiple topics defined as variables"""
source_code = """
TOPIC_A = "orders"
TOPIC_B = "payments"
df = spark.readStream.format("kafka") \\
.option("subscribe", TOPIC_A) \\
.load()
df2 = spark.readStream.format("kafka") \\
.option("topics", TOPIC_B) \\
.load()
"""
configs = extract_kafka_sources(source_code)
self.assertEqual(len(configs), 2)
topics = [c.topics[0] for c in configs]
self.assertIn("orders", topics)
self.assertIn("payments", topics)
def test_variable_not_defined(self):
"""Test variable reference without definition"""
source_code = """
df = spark.readStream.format("kafka") \\
.option("subscribe", UNDEFINED_TOPIC) \\
.load()
"""
configs = extract_kafka_sources(source_code)
# No topics or brokers can be resolved, so no Kafka config is extracted
self.assertEqual(len(configs), 0)
class TestPipelineLibraries(unittest.TestCase):
"""Test cases for pipeline library extraction"""
def test_notebook_library(self):
"""Test notebook library extraction"""
pipeline_config = {
"libraries": [{"notebook": {"path": "/Workspace/dlt/bronze_pipeline"}}]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 1)
self.assertEqual(libraries[0], "/Workspace/dlt/bronze_pipeline")
def test_file_library(self):
"""Test file library extraction"""
pipeline_config = {
"libraries": [{"file": {"path": "/Workspace/scripts/etl.py"}}]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 1)
self.assertEqual(libraries[0], "/Workspace/scripts/etl.py")
def test_mixed_libraries(self):
"""Test mixed notebook and file libraries"""
pipeline_config = {
"libraries": [
{"notebook": {"path": "/nb1"}},
{"file": {"path": "/file1.py"}},
{"notebook": {"path": "/nb2"}},
]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 3)
self.assertIn("/nb1", libraries)
self.assertIn("/file1.py", libraries)
self.assertIn("/nb2", libraries)
def test_empty_libraries(self):
"""Test empty libraries list"""
pipeline_config = {"libraries": []}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 0)
def test_missing_libraries_key(self):
"""Test missing libraries key"""
pipeline_config = {}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 0)
def test_library_with_no_path(self):
"""Test library entry with no path"""
pipeline_config = {"libraries": [{"notebook": {}}, {"file": {}}]}
libraries = get_pipeline_libraries(pipeline_config)
# Should skip entries without paths
self.assertEqual(len(libraries), 0)
def test_unsupported_library_type(self):
"""Test unsupported library types are skipped"""
pipeline_config = {
"libraries": [
{"jar": {"path": "/lib.jar"}}, # Not supported
{"notebook": {"path": "/nb"}}, # Supported
{"whl": {"path": "/wheel.whl"}}, # Not supported
]
}
libraries = get_pipeline_libraries(pipeline_config)
self.assertEqual(len(libraries), 1)
self.assertEqual(libraries[0], "/nb")
if __name__ == "__main__":
unittest.main()