MINOR: Improve threading for lineage (#20668)

Mayur Singal 2025-04-07 18:31:52 +05:30 committed by GitHub
parent f7c4cc54f4
commit b7d43e7ee2
4 changed files with 129 additions and 56 deletions

View File

@@ -447,8 +447,14 @@ class LineageParser:
         if lr_sqlfluff:
             return lr_sqlfluff
 
-        lr_sqlparser = LineageRunner(query)
+        @timeout(seconds=timeout_seconds)
+        def get_sqlparser_lineage_runner(qry: str) -> LineageRunner:
+            lr_sqlparser = LineageRunner(qry)
+            lr_sqlparser.get_column_lineage()
+            return lr_sqlparser
+
         try:
+            lr_sqlparser = get_sqlparser_lineage_runner(query)
             _ = len(lr_sqlparser.get_column_lineage()) + len(
                 set(lr_sqlparser.source_tables).union(
                     set(lr_sqlparser.target_tables).union(
@@ -456,6 +462,13 @@ class LineageParser:
                     )
                 )
             )
+        except TimeoutError:
+            self.query_parsing_success = False
+            self.query_parsing_failure_reason = (
+                f"Lineage with SqlParser failed for the [{dialect.value}]. "
+                f"Parser has been running for more than {timeout_seconds} seconds."
+            )
+            return None
         except Exception:
             # if both runners have failed we return the usual one
             logger.debug(f"Failed to parse query with sqlparse & sqlfluff: {query}")
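Note: the timeout decorator applied above (imported elsewhere in the parser module) is what turns a hung sqlparse run into the TimeoutError handled in the second hunk, instead of blocking the ingestion indefinitely. A minimal, thread-based sketch of a decorator with the same calling convention, for illustration only, since the actual implementation in the metadata utilities may differ:

import functools
import threading


def timeout(seconds: int):
    """Illustrative decorator: abandon the wrapped call after `seconds` seconds."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result, error = [], []

            def target():
                try:
                    result.append(func(*args, **kwargs))
                except Exception as exc:  # surface worker failures to the caller
                    error.append(exc)

            worker = threading.Thread(target=target, daemon=True)
            worker.start()
            worker.join(seconds)
            if worker.is_alive():
                # The call is still running: report a timeout to the caller
                raise TimeoutError(f"{func.__name__} exceeded {seconds}s")
            if error:
                raise error[0]
            return result[0]

        return wrapper

    return decorator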

View File

@@ -14,6 +14,7 @@ Mixin class containing Lineage specific methods
 To be used by OpenMetadata class
 """
 import functools
+import json
 import traceback
 from copy import deepcopy
 from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
@@ -407,22 +408,26 @@ class OMetaLineageMixin(Generic[T]):
                 f"Error while adding lineage: {lineage_request.left.error}"
             )
 
+    @functools.lru_cache(maxsize=LRU_CACHE_SIZE)
     def patch_lineage_processed_flag(
         self,
         entity: Type[T],
         fqn: str,
     ) -> None:
         """
         Patch the processed lineage flag for an entity
         """
         try:
-            original_entity = self.get_by_name(entity=entity, fqn=fqn)
-            if not original_entity:
-                return
-            updated_entity = original_entity.model_copy(deep=True)
-            updated_entity.processedLineage = True
-            self.patch(
-                entity=entity, source=original_entity, destination=updated_entity
+            patch = [
+                {
+                    "op": "add",
+                    "path": "/processedLineage",
+                    "value": True,
+                }
+            ]
+            self.client.patch(
+                path=f"{self.get_suffix(entity)}/name/{fqn}",
+                data=json.dumps(patch),
             )
         except Exception as exc:
             logger.debug(f"Error while patching lineage processed flag: {exc}")
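With this change the helper no longer fetches the entity, deep-copies it, and diffs it; it sends a single RFC 6902 "add" operation to the entity's name endpoint, and the functools.lru_cache decorator (keyed on self, entity, and fqn) means repeated calls for the same entity during a run never reach the API. A standalone sketch of that caching effect, with illustrative names and a print in place of the real HTTP call:

import functools


@functools.lru_cache(maxsize=4096)
def patch_processed_flag(entity_type: str, fqn: str) -> None:
    """Illustrative stand-in for the real method: only the first call per key runs."""
    print(f"PATCH {entity_type}/name/{fqn} -> processedLineage=true")


patch_processed_flag("tables", "svc.db.schema.orders")  # issues the request
patch_processed_flag("tables", "svc.db.schema.orders")  # cache hit, no request sent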

View File

@@ -13,10 +13,10 @@ Lineage Source Module
 """
 
 import csv
 import os
+import threading
 import time
 import traceback
 from abc import ABC
-from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Any, Callable, Iterable, Iterator, List, Optional, Union
@@ -54,9 +54,9 @@ from metadata.utils.logger import ingestion_logger
 
 logger = ingestion_logger()
 
-CHUNK_SIZE = 200
+CHUNK_SIZE = 100
 
-THREAD_TIMEOUT = 600
+THREAD_TIMEOUT = 3 * 60 * 10  # 30 minutes in seconds
 
 
 class LineageSource(QueryParserSource, ABC):
@@ -119,61 +119,95 @@ class LineageSource(QueryParserSource, ABC):
             )
             yield from self.yield_table_query()
 
+    @staticmethod
     def generate_lineage_in_thread(
-        self,
         producer_fn: Callable[[], Iterable[Any]],
-        processor_fn: Callable[[Any], Iterable[Any]],
+        processor_fn: Callable[[Any, Queue], None],
         chunk_size: int = CHUNK_SIZE,
+        thread_timeout: int = THREAD_TIMEOUT,
+        max_threads: int = 10,  # Default maximum number of concurrent threads
     ):
         """
-        Optimized multithreaded lineage generation with improved error handling and performance.
+        Process data in separate daemon threads with timeout control.
 
         Args:
-            producer_fn: Function that yields input items
-            processor_fn: Function to process each input item
-            chunk_size: Optional batching to reduce thread creation overhead
+            producer_fn: Function that yields data chunks
+            processor_fn: Function that processes data and adds results to the queue
+            chunk_size: Size of chunks to process
+            thread_timeout: Maximum time in seconds to wait for a processor thread
+            max_threads: Maximum number of concurrent threads to run
         """
-
-        def chunk_generator():
-            temp_chunk = []
-            for chunk in producer_fn():
-                temp_chunk.append(chunk)
-                if len(temp_chunk) >= chunk_size:
-                    yield temp_chunk
-                    temp_chunk = []
-
-            if temp_chunk:
-                yield temp_chunk
-
-        thread_pool = ThreadPoolExecutor(max_workers=self.source_config.threads)
         queue = Queue()
+        active_threads = []
+
+        def process_chunk(chunk):
+            """Process a chunk of data in a thread."""
+            try:
+                processor_fn(chunk, queue)
+            except Exception as e:
+                logger.error(f"Error processing chunk: {e}")
+                logger.debug(traceback.format_exc())
 
-        futures = [
-            thread_pool.submit(
-                processor_fn,
-                chunk,
-                queue,
-            )
-            for chunk in chunk_generator()
-        ]
+        # Create an iterator for the chunks but don't consume it all at once
+        chunk_iterator = iter(chunk_generator(producer_fn, chunk_size))
+
+        # Process results from the queue and check for timed-out threads
+        chunk_processed = False  # Flag to track if all chunks have been processed
+        ignored_threads = 0
 
         while True:
+            # Start new threads until we reach the max_threads limit
+            while (
+                len(active_threads) + ignored_threads
+            ) < max_threads and not chunk_processed:
+                try:
+                    # Only fetch a new chunk when we're ready to create a thread
+                    chunk = next(chunk_iterator)
+                    thread = threading.Thread(target=process_chunk, args=(chunk,))
+                    thread.start_time = time.time()  # Track when the thread started
+                    thread.daemon = True
+                    active_threads.append(thread)
+                    thread.start()
+                except StopIteration:
+                    # No more chunks to process
+                    chunk_processed = True
+                    break
+
+            if ignored_threads == max_threads:
+                logger.warning(f"Max threads reached, skipping remaining threads")
+                break
+
+            # Process any available results
             if queue.has_tasks():
                 yield from queue.process()
-            else:
-                if not futures:
-                    break
 
-            for i, future in enumerate(futures):
-                if future.done():
-                    try:
-                        future.result(timeout=THREAD_TIMEOUT)
-                    except Exception as e:
-                        logger.debug(f"Error in future: {e}")
-                        logger.debug(traceback.format_exc())
-                    futures.pop(i)
+            # Check for completed or timed-out threads
+            still_active = []
+            for thread in active_threads:
+                if thread.is_alive():
+                    # Check if the thread has timed out
+                    if time.time() - thread.start_time > thread_timeout:
+                        logger.warning(
+                            f"Thread {thread.name} timed out after {thread_timeout}s"
+                        )
+                        ignored_threads += 1
+                    else:
+                        still_active.append(thread)
+                # If thread is not alive, it has completed normally
 
-            time.sleep(0.01)
+            active_threads = still_active
+
+            # Exit conditions: no more active threads and no more chunks to process
+            if not active_threads and chunk_processed:
+                break
+
+            # Small pause to prevent CPU spinning
+            if active_threads:
+                time.sleep(0.1)
+
+        # Final check for any remaining results
+        while queue.has_tasks():
+            yield from queue.process()
 
     def yield_table_query(self) -> Iterator[TableQuery]:
         """
@@ -269,7 +303,9 @@ class LineageSource(QueryParserSource, ABC):
         producer_fn = self.get_table_query
         processor_fn = self.query_lineage_generator
         yield from self.generate_lineage_in_thread(
-            producer_fn, processor_fn, CHUNK_SIZE
+            producer_fn,
+            processor_fn,
+            max_threads=self.source_config.threads,
         )
 
     def view_lineage_generator(
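These call sites rely on the new contract of generate_lineage_in_thread: producer_fn yields individual items, the module-level chunk_generator batches them, and processor_fn acts as a sink that runs in a worker thread and pushes results onto a shared queue drained by the main thread. A toy sketch of that contract; the Queue class below is only a stand-in exposing the has_tasks()/process() surface used above, not the actual helper from the ingestion utilities:

from queue import SimpleQueue
from typing import Any, Iterable, Iterator, List


class Queue:
    """Minimal stand-in: worker threads put results, the main thread drains them."""

    def __init__(self) -> None:
        self._items: SimpleQueue = SimpleQueue()

    def put(self, item: Any) -> None:
        self._items.put(item)

    def has_tasks(self) -> bool:
        return not self._items.empty()

    def process(self) -> Iterator[Any]:
        while not self._items.empty():
            yield self._items.get()


def produce_items() -> Iterable[int]:
    """Stand-in producer: yields individual work items, later batched into chunks."""
    yield from range(5)


def square_chunk(chunk: List[int], queue: Queue) -> None:
    """Stand-in processor: runs in a worker thread and pushes results to the queue."""
    for item in chunk:
        queue.put(item * item)

Wired through generate_lineage_in_thread(produce_items, square_chunk, chunk_size=2), the squared values would be yielded back on the main thread as the daemon threads finish, while a chunk running past thread_timeout is logged and abandoned rather than blocking the loop.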
@@ -318,7 +354,9 @@ class LineageSource(QueryParserSource, ABC):
             self.source_config.incrementalLineageProcessing,
         )
         processor_fn = self.view_lineage_generator
-        yield from self.generate_lineage_in_thread(producer_fn, processor_fn)
+        yield from self.generate_lineage_in_thread(
+            producer_fn, processor_fn, max_threads=self.source_config.threads
+        )
 
     def yield_procedure_lineage(
         self,
@@ -412,3 +450,18 @@ class LineageSource(QueryParserSource, ABC):
             and self.source_config.crossDatabaseServiceNames
         ):
             yield from self.yield_cross_database_lineage() or []
+
+
+def chunk_generator(producer_fn, chunk_size):
+    """
+    Group items from producer into chunks of specified size.
+    This is a separate function to allow for better lazy evaluation.
+    """
+    temp_chunk = []
+    for item in producer_fn():
+        temp_chunk.append(item)
+        if len(temp_chunk) >= chunk_size:
+            yield temp_chunk
+            temp_chunk = []
+    if temp_chunk:
+        yield temp_chunk
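For intuition on the batching done by the new module-level chunk_generator, a producer yielding five items with a chunk size of 2 is grouped into three chunks, each handed to one worker thread (illustrative values):

def produce_queries():
    yield from ("q1", "q2", "q3", "q4", "q5")

# Using the chunk_generator defined above:
# list(chunk_generator(produce_queries, 2))
# -> [["q1", "q2"], ["q3", "q4"], ["q5"]]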

View File

@@ -304,4 +304,6 @@ class StoredProcedureLineageMixin(ABC):
         logger.info("Processing Lineage for Stored Procedures")
         producer_fn = self.procedure_lineage_generator
         processor_fn = self.procedure_lineage_processor
-        yield from self.generate_lineage_in_thread(producer_fn, processor_fn)
+        yield from self.generate_lineage_in_thread(
+            producer_fn, processor_fn, max_threads=self.source_config.threads
+        )