Fix #19489: Optimise multithreading for lineage (#19524)

This commit is contained in:
Mayur Singal 2025-01-27 18:15:58 +05:30 committed by GitHub
parent 8117586c57
commit d2dc7bd038
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 229 additions and 172 deletions

View File

@ -14,8 +14,7 @@ Query masking utilities
import traceback
import sqlparse
from sqlfluff.core import Linter
from collate_sqllineage.runner import SQLPARSE_DIALECT, LineageRunner
from sqlparse.sql import Comparison
from sqlparse.tokens import Literal, Number, String
@ -24,6 +23,7 @@ from metadata.ingestion.lineage.models import Dialect
MASK_TOKEN = "?"
# pylint: disable=protected-access
def get_logger():
# pylint: disable=import-outside-toplevel
from metadata.utils.logger import utils_logger
@ -31,18 +31,14 @@ def get_logger():
return utils_logger()
def mask_literals_with_sqlparse(query: str):
def mask_literals_with_sqlparse(query: str, parser: LineageRunner):
"""
Mask literals in a query using sqlparse.
"""
logger = get_logger()
try:
parsed = sqlparse.parse(query) # Parse the query
if not parsed:
return query
parsed = parsed[0]
parsed = parser._parsed_result
def mask_token(token):
# Mask all literals: strings, numbers, or other literal values
@ -79,17 +75,16 @@ def mask_literals_with_sqlparse(query: str):
return query
def mask_literals_with_sqlfluff(query: str, dialect: str = Dialect.ANSI.value) -> str:
def mask_literals_with_sqlfluff(query: str, parser: LineageRunner) -> str:
"""
Mask literals in a query using SQLFluff.
"""
logger = get_logger()
try:
# Initialize SQLFluff linter
linter = Linter(dialect=dialect)
if not parser._evaluated:
parser._eval()
# Parse the query
parsed = linter.parse_string(query)
parsed = parser._parsed_result
def replace_literals(segment):
"""Recursively replace literals with placeholders."""
@ -114,17 +109,21 @@ def mask_literals_with_sqlfluff(query: str, dialect: str = Dialect.ANSI.value) -
return query
def mask_query(query: str, dialect: str = Dialect.ANSI.value) -> str:
def mask_query(
query: str, dialect: str = Dialect.ANSI.value, parser: LineageRunner = None
) -> str:
logger = get_logger()
try:
sqlfluff_masked_query = mask_literals_with_sqlfluff(query, dialect)
sqlparse_masked_query = mask_literals_with_sqlparse(query)
# compare both masked queries and return the one with more masked tokens
if sqlfluff_masked_query.count(MASK_TOKEN) >= sqlparse_masked_query.count(
MASK_TOKEN
):
return sqlfluff_masked_query
return sqlparse_masked_query
if not parser:
try:
parser = LineageRunner(query, dialect=dialect)
len(parser.source_tables)
except Exception:
parser = LineageRunner(query)
len(parser.source_tables)
if parser._dialect == SQLPARSE_DIALECT:
return mask_literals_with_sqlparse(query, parser)
return mask_literals_with_sqlfluff(query, parser)
except Exception as exc:
logger.debug(f"Failed to mask query with sqlfluff: {exc}")
logger.debug(traceback.format_exc())

View File

@ -71,12 +71,13 @@ class LineageParser:
self.query_parsing_success = True
self.query_parsing_failure_reason = None
self.dialect = dialect
self._masked_query = mask_query(self.query, dialect.value)
self.masked_query = None
self._clean_query = self.clean_raw_query(query)
self._masked_clean_query = mask_query(self._clean_query, dialect.value)
self.parser = self._evaluate_best_parser(
self._clean_query, dialect=dialect, timeout_seconds=timeout_seconds
)
if self.masked_query is None:
self.masked_query = mask_query(self._clean_query, parser=self.parser)
@cached_property
def involved_tables(self) -> Optional[List[Table]]:
@ -95,7 +96,7 @@ class LineageParser:
except SQLLineageException as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Cannot extract source table information from query [{self._masked_query}]: {exc}"
f"Cannot extract source table information from query [{self.masked_query or self.query}]: {exc}"
)
return None
@ -334,12 +335,10 @@ class LineageParser:
)
if not table_left or not table_right:
logger.warning(
logger.debug(
f"Can't extract table names when parsing JOIN information from {comparison}"
)
logger.debug(
f"Query: {mask_query(sql_statement, self.dialect.value)}"
)
logger.debug(f"Query: {self.masked_query}")
continue
left_table_column = TableColumn(table=table_left, column=column_left)
@ -422,10 +421,9 @@ class LineageParser:
lr_dialect.get_column_lineage()
return lr_dialect
sqlfluff_count = 0
try:
lr_sqlfluff = get_sqlfluff_lineage_runner(query, dialect.value)
sqlfluff_count = len(lr_sqlfluff.get_column_lineage()) + len(
_ = len(lr_sqlfluff.get_column_lineage()) + len(
set(lr_sqlfluff.source_tables).union(
set(lr_sqlfluff.target_tables).union(
set(lr_sqlfluff.intermediate_tables)
@ -438,23 +436,20 @@ class LineageParser:
f"Lineage with SqlFluff failed for the [{dialect.value}]. "
f"Parser has been running for more than {timeout_seconds} seconds."
)
logger.debug(
f"{self.query_parsing_failure_reason}] query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None
except Exception:
self.query_parsing_success = False
self.query_parsing_failure_reason = (
f"Lineage with SqlFluff failed for the [{dialect.value}]"
)
logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
lr_sqlfluff = None
if lr_sqlfluff:
return lr_sqlfluff
lr_sqlparser = LineageRunner(query)
try:
sqlparser_count = len(lr_sqlparser.get_column_lineage()) + len(
_ = len(lr_sqlparser.get_column_lineage()) + len(
set(lr_sqlparser.source_tables).union(
set(lr_sqlparser.target_tables).union(
set(lr_sqlparser.intermediate_tables)
@ -463,21 +458,13 @@ class LineageParser:
)
except Exception:
# if both runners have failed we return the usual one
logger.debug(f"Failed to parse query with sqlparse & sqlfluff: {query}")
return lr_sqlfluff if lr_sqlfluff else lr_sqlparser
if lr_sqlfluff:
# if sqlparser retrieves more lineage info than sqlfluff
if sqlparser_count > sqlfluff_count:
self.query_parsing_success = False
self.query_parsing_failure_reason = (
"Lineage computed with SqlFluff did not perform as expected "
f"for the [{dialect.value}]"
)
logger.debug(
f"{self.query_parsing_failure_reason} query: [{self._masked_clean_query}]"
)
return lr_sqlparser
return lr_sqlfluff
self.masked_query = mask_query(self._clean_query, parser=lr_sqlparser)
logger.debug(
f"Using sqlparse for lineage parsing for query: {self.masked_query}"
)
return lr_sqlparser
@staticmethod

View File

@ -37,7 +37,6 @@ from metadata.generated.schema.type.entityLineage import (
from metadata.generated.schema.type.entityLineage import Source as LineageSource
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.lineage.models import (
Dialect,
QueryParsingError,
@ -614,11 +613,11 @@ def get_lineage_by_query(
"""
column_lineage = {}
query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)
try:
logger.debug(f"Running lineage with query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
masked_query = lineage_parser.masked_query or query
logger.debug(f"Running lineage with query: {masked_query}")
raw_column_lineage = lineage_parser.column_lineage
column_lineage.update(populate_column_lineage_map(raw_column_lineage))
@ -715,11 +714,11 @@ def get_lineage_via_table_entity(
"""Get lineage from table entity"""
column_lineage = {}
query_parsing_failures = QueryParsingFailures()
masked_query = mask_query(query, dialect.value)
try:
logger.debug(f"Getting lineage via table entity using query: {masked_query}")
lineage_parser = LineageParser(query, dialect, timeout_seconds=timeout_seconds)
masked_query = lineage_parser.masked_query or query
logger.debug(f"Getting lineage via table entity using query: {masked_query}")
to_table_name = table_entity.name.root
for from_table_name in lineage_parser.source_tables:

View File

@ -344,7 +344,7 @@ class ESMixin(Generic[T]):
# Get next page
last_hit = response.hits.hits[-1] if response.hits.hits else None
if not last_hit or not last_hit.sort:
logger.info("No more pages to fetch")
logger.debug("No more pages to fetch")
break
after = ",".join(last_hit.sort)
@ -429,10 +429,11 @@ class ESMixin(Generic[T]):
_, database_name, schema_name, table_name = fqn.split(
hit.source["fullyQualifiedName"]
)
yield TableView(
view_definition=hit.source["schemaDefinition"],
service_name=service_name,
db_name=database_name,
schema_name=schema_name,
table_name=table_name,
)
if hit.source.get("schemaDefinition"):
yield TableView(
view_definition=hit.source["schemaDefinition"],
service_name=service_name,
db_name=database_name,
schema_name=schema_name,
table_name=table_name,
)

View File

@ -63,7 +63,7 @@ class OMetaLineageMixin(Generic[T]):
)
for column in updated or []:
if not isinstance(column, dict):
data = column.dict()
data = column.model_dump()
else:
data = column
if data.get("toColumn") and data.get("fromColumns"):

View File

@ -13,11 +13,12 @@ Lineage Source Module
"""
import csv
import os
import time
import traceback
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Callable, Iterable, Iterator, List, Optional, Union
from typing import Any, Callable, Iterable, Iterator, List, Optional, Union
from metadata.generated.schema.api.data.createQuery import CreateQueryRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
@ -39,6 +40,7 @@ from metadata.ingestion.api.models import Either
from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper, Dialect
from metadata.ingestion.lineage.sql_lineage import get_column_fqn, get_lineage_by_query
from metadata.ingestion.models.ometa_lineage import OMetaLineageRequest
from metadata.ingestion.models.topology import Queue
from metadata.ingestion.source.database.query_parser_source import QueryParserSource
from metadata.ingestion.source.models import TableView
from metadata.utils import fqn
@ -48,6 +50,9 @@ from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
CHUNK_SIZE = 200
class LineageSource(QueryParserSource, ABC):
"""
This is the base source to handle Lineage-only ingestion.
@ -108,27 +113,57 @@ class LineageSource(QueryParserSource, ABC):
)
yield from self.yield_table_query()
def generate_lineage_in_thread(self, producer_fn: Callable, processor_fn: Callable):
with ThreadPoolExecutor(max_workers=self.source_config.threads) as executor:
futures = []
def generate_lineage_in_thread(
self,
producer_fn: Callable[[], Iterable[Any]],
processor_fn: Callable[[Any], Iterable[Any]],
chunk_size: int = CHUNK_SIZE,
):
"""
Optimized multithreaded lineage generation with improved error handling and performance.
for produced_input in producer_fn():
futures.append(executor.submit(processor_fn, produced_input))
Args:
producer_fn: Function that yields input items
processor_fn: Function to process each input item
chunk_size: Optional batching to reduce thread creation overhead
"""
# Handle remaining futures after the loop
for future in as_completed(
futures, timeout=self.source_config.parsingTimeoutLimit
):
try:
results = future.result(
timeout=self.source_config.parsingTimeoutLimit
)
yield from results
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Error processing result for {produced_input}: {exc}"
)
def chunk_generator():
temp_chunk = []
for chunk in producer_fn():
temp_chunk.append(chunk)
if len(temp_chunk) >= chunk_size:
yield temp_chunk
temp_chunk = []
if temp_chunk:
yield temp_chunk
thread_pool = ThreadPoolExecutor(max_workers=self.source_config.threads)
queue = Queue()
futures = [
thread_pool.submit(
processor_fn,
chunk,
queue,
)
for chunk in chunk_generator()
]
while True:
if queue.has_tasks():
yield from queue.process()
else:
if not futures:
break
for i, future in enumerate(futures):
if future.done():
future.result()
futures.pop(i)
time.sleep(0.01)
def yield_table_query(self) -> Iterator[TableQuery]:
"""
@ -170,33 +205,38 @@ class LineageSource(QueryParserSource, ABC):
return fqn.get_query_checksum(table_query.query) in checksums or {}
def query_lineage_generator(
self, table_query: TableQuery
self, table_queries: List[TableQuery], queue: Queue
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
if not self._query_already_processed(table_query):
lineages: Iterable[Either[AddLineageRequest]] = get_lineage_by_query(
self.metadata,
query=table_query.query,
service_name=table_query.serviceName,
database_name=table_query.databaseName,
schema_name=table_query.databaseSchema,
dialect=self.dialect,
timeout_seconds=self.source_config.parsingTimeoutLimit,
)
for table_query in table_queries or []:
if not self._query_already_processed(table_query):
lineages: Iterable[Either[AddLineageRequest]] = get_lineage_by_query(
self.metadata,
query=table_query.query,
service_name=table_query.serviceName,
database_name=table_query.databaseName,
schema_name=table_query.databaseSchema,
dialect=self.dialect,
timeout_seconds=self.source_config.parsingTimeoutLimit,
)
for lineage_request in lineages or []:
yield lineage_request
for lineage_request in lineages or []:
queue.put(lineage_request)
# If we identified lineage properly, ingest the original query
if lineage_request.right:
yield Either(
right=CreateQueryRequest(
query=SqlQuery(table_query.query),
query_type=table_query.query_type,
duration=table_query.duration,
processedLineage=True,
service=FullyQualifiedEntityName(self.config.serviceName),
# If we identified lineage properly, ingest the original query
if lineage_request.right:
queue.put(
Either(
right=CreateQueryRequest(
query=SqlQuery(table_query.query),
query_type=table_query.query_type,
duration=table_query.duration,
processedLineage=True,
service=FullyQualifiedEntityName(
self.config.serviceName
),
)
)
)
)
def yield_query_lineage(
self,
@ -209,28 +249,33 @@ class LineageSource(QueryParserSource, ABC):
self.dialect = ConnectionTypeDialectMapper.dialect_of(connection_type)
producer_fn = self.get_table_query
processor_fn = self.query_lineage_generator
yield from self.generate_lineage_in_thread(producer_fn, processor_fn)
yield from self.generate_lineage_in_thread(
producer_fn, processor_fn, CHUNK_SIZE
)
def view_lineage_generator(
self, view: TableView
self, views: List[TableView], queue: Queue
) -> Iterable[Either[AddLineageRequest]]:
try:
for lineage in get_view_lineage(
view=view,
metadata=self.metadata,
service_name=self.config.serviceName,
connection_type=self.service_connection.type.value,
timeout_seconds=self.source_config.parsingTimeoutLimit,
):
if lineage.right is not None:
yield Either(
right=OMetaLineageRequest(
lineage_request=lineage.right,
override_lineage=self.source_config.overrideViewLineage,
for view in views:
for lineage in get_view_lineage(
view=view,
metadata=self.metadata,
service_name=self.config.serviceName,
connection_type=self.service_connection.type.value,
timeout_seconds=self.source_config.parsingTimeoutLimit,
):
if lineage.right is not None:
queue.put(
Either(
right=OMetaLineageRequest(
lineage_request=lineage.right,
override_lineage=self.source_config.overrideViewLineage,
)
)
)
)
else:
yield lineage
else:
queue.put(lineage)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(f"Error processing view {view}: {exc}")

View File

@ -38,10 +38,11 @@ from metadata.ingestion.api.models import Either
from metadata.ingestion.api.status import Status
from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper
from metadata.ingestion.lineage.sql_lineage import get_lineage_by_query
from metadata.ingestion.models.topology import Queue
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.logger import ingestion_logger
from metadata.utils.stored_procedures import get_procedure_name_from_call
from metadata.utils.time_utils import convert_timestamp_to_milliseconds
from metadata.utils.time_utils import datetime_to_timestamp
logger = ingestion_logger()
@ -176,8 +177,6 @@ class StoredProcedureLineageMixin(ABC):
timeout_seconds=self.source_config.parsingTimeoutLimit,
lineage_source=LineageSource.QueryLineage,
):
print("&& " * 100)
print(either_lineage)
if (
either_lineage.left is None
and either_lineage.right.edge.lineageDetails
@ -200,8 +199,8 @@ class StoredProcedureLineageMixin(ABC):
query_type=query_by_procedure.query_type,
duration=query_by_procedure.query_duration,
queryDate=Timestamp(
root=convert_timestamp_to_milliseconds(
int(query_by_procedure.query_start_time.timestamp())
root=datetime_to_timestamp(
query_by_procedure.query_start_time, True
)
),
triggeredBy=EntityReference(
@ -214,29 +213,31 @@ class StoredProcedureLineageMixin(ABC):
)
def procedure_lineage_processor(
self, procedure_and_query: ProcedureAndQuery
self, procedure_and_queries: List[ProcedureAndQuery], queue: Queue
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
try:
yield from self._yield_procedure_lineage(
query_by_procedure=procedure_and_query.query_by_procedure,
procedure=procedure_and_query.procedure,
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Could not get lineage for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]."
)
try:
yield from self.yield_procedure_query(
query_by_procedure=procedure_and_query.query_by_procedure,
procedure=procedure_and_query.procedure,
)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Could not get query for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]."
)
for procedure_and_query in procedure_and_queries:
try:
for lineage in self._yield_procedure_lineage(
query_by_procedure=procedure_and_query.query_by_procedure,
procedure=procedure_and_query.procedure,
):
queue.put(lineage)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Could not get lineage for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]."
)
try:
for lineage in self.yield_procedure_query(
query_by_procedure=procedure_and_query.query_by_procedure,
procedure=procedure_and_query.procedure,
):
queue.put(lineage)
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Could not get query for store procedure '{procedure_and_query.procedure.fullyQualifiedName}' due to [{exc}]."
)
def procedure_lineage_generator(self) -> Iterable[ProcedureAndQuery]:
query = {
@ -256,7 +257,9 @@ class StoredProcedureLineageMixin(ABC):
queries_dict = self.get_stored_procedure_queries_dict()
# Then for each procedure, iterate over all its queries
for procedure in (
self.metadata.paginate_es(entity=StoredProcedure, query_filter=query_filter)
self.metadata.paginate_es(
entity=StoredProcedure, query_filter=query_filter, size=10
)
or []
):
if procedure:

View File

@ -69,6 +69,10 @@ def get_view_lineage(
fqn=table_fqn,
)
if not view_definition:
logger.warning(f"View definition for view {table_fqn} not available")
return
try:
connection_type = str(connection_type)
dialect = ConnectionTypeDialectMapper.dialect_of(connection_type)

View File

@ -28,7 +28,6 @@ from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.type.queryParserData import QueryParserData
from metadata.generated.schema.type.tableQuery import TableQueries
from metadata.ingestion.api.models import Entity
from metadata.ingestion.lineage.masker import mask_query
from metadata.ingestion.models.delete_entity import DeleteEntity
from metadata.ingestion.models.life_cycle import OMetaLifeCycleData
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
@ -284,19 +283,13 @@ def _(record: PatchRequest) -> str:
@get_log_name.register
def _(record: TableQueries) -> str:
"""Get the log of the TableQuery"""
queries = "\n------\n".join(
mask_query(query.query, query.dialect) for query in record.queries
)
return f"Table Queries [{queries}]"
return f"Table Queries [{len(record.queries)}]"
@get_log_name.register
def _(record: QueryParserData) -> str:
"""Get the log of the ParsedData"""
queries = "\n------\n".join(
mask_query(query.sql, query.dialect) for query in record.parsedData
)
return f"Usage ParsedData [{queries}]"
return f"Usage ParsedData [{len(record.parsedData)}]"
def redacted_config(config: Dict[str, Union[str, dict]]) -> Dict[str, Union[str, dict]]:

View File

@ -229,14 +229,39 @@ class SqlLineageTest(TestCase):
def test_query_masker(self):
query_list = [
"""SELECT * FROM user WHERE id=1234 AND name='Alice' AND birthdate=DATE '2023-01-01';""",
"""insert into user values ('mayur',123,'my random address 1'), ('mayur',123,'my random address 1');""",
"""SELECT * FROM user WHERE address = '5th street' and name = 'john';""",
"""INSERT INTO user VALUE ('John', '19', '5TH Street');""",
"""SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user;""",
"""with test as (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user) select * from test;""",
"""select * from (select * from (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user));""",
"""select * from users where id > 2 and name <> 'pere';""",
(
"""SELECT * FROM user WHERE id=1234 AND name='Alice' AND birthdate=DATE '2023-01-01';""",
Dialect.MYSQL.value,
),
(
"""insert into user values ('mayur',123,'my random address 1'), ('mayur',123,'my random address 1');""",
Dialect.ANSI.value,
),
(
"""SELECT * FROM user WHERE address = '5th street' and name = 'john';""",
Dialect.ANSI.value,
),
(
"""INSERT INTO user VALUE ('John', '19', '5TH Street');""",
Dialect.ANSI.value,
),
(
"""SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user;""",
Dialect.ANSI.value,
),
(
"""with test as (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user) select * from test;""",
Dialect.ANSI.value,
),
(
"""select * from (select * from (SELECT CASE address WHEN '5th Street' THEN 'CEO' ELSE 'Unknown' END AS person FROM user));""",
Dialect.ANSI.value,
),
(
"""select * from users where id > 2 and name <> 'pere';""",
Dialect.ANSI.value,
),
("""select * from users where id > 2 and name <> 'pere';""", "random"),
]
expected_query_list = [
@ -248,7 +273,8 @@ class SqlLineageTest(TestCase):
"""with test as (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user) select * from test;""",
"""select * from (select * from (SELECT CASE address WHEN ? THEN ? ELSE ? END AS person FROM user));""",
"""select * from users where id > ? and name <> ?;""",
"""select * from users where id > ? and name <> ?;""",
]
for i, query in enumerate(query_list):
self.assertEqual(mask_query(query), expected_query_list[i])
self.assertEqual(mask_query(query[0], query[1]), expected_query_list[i])