feat(redash): add parallelism support for ingestion (#5061)

commit f81ead366d (parent 2dd826626d)
Author: Aseem Bansal
Date:   2022-06-01 19:36:30 +05:30 (committed by GitHub)

@@ -3,6 +3,7 @@ import logging
 import math
 import sys
 from dataclasses import dataclass, field
+from multiprocessing.pool import ThreadPool
 from typing import Dict, Iterable, List, Optional, Set, Type
 
 import dateutil.parser as dp
@@ -37,6 +38,7 @@ from datahub.metadata.schema_classes import (
     ChartTypeClass,
     DashboardInfoClass,
 )
+from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.sql_parser import SQLParser
 
 logger = logging.getLogger(__name__)
@@ -253,6 +255,10 @@ class RedashConfig(ConfigModel):
         default=sys.maxsize,
         description="Limit on number of pages queried for ingesting dashboards and charts API during pagination.",
     )
+    parallelism: int = Field(
+        default=1,
+        description="Parallelism to use while processing.",
+    )
     parse_table_names_from_sql: bool = Field(
         default=False, description="See note below."
     )
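
Note: since RedashConfig is a pydantic model, the new option can be set like any other recipe key. A minimal sketch (the import path and values below are assumptions for illustration, not part of this diff):

    from datahub.ingestion.source.redash import RedashConfig  # path assumed

    config = RedashConfig.parse_obj(
        {
            "connect_uri": "http://localhost:5000",  # hypothetical Redash instance
            "api_key": "REDASH_API_KEY",  # placeholder credential
            "parallelism": 10,  # new in this change; defaults to 1
        }
    )
    assert config.parallelism == 10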
@@ -271,18 +277,19 @@ class RedashConfig(ConfigModel):
 class RedashSourceReport(SourceReport):
     items_scanned: int = 0
     filtered: List[str] = field(default_factory=list)
+    queries_problem_parsing: Set[str] = field(default_factory=set)
     queries_no_dataset: Set[str] = field(default_factory=set)
+    charts_no_input: Set[str] = field(default_factory=set)
     total_queries: Optional[int] = field(
         default=None,
     )
-    page_reached_queries: Optional[int] = field(default=None)
     max_page_queries: Optional[int] = field(default=None)
     total_dashboards: Optional[int] = field(
         default=None,
     )
-    page_reached_dashboards: Optional[int] = field(default=None)
     max_page_dashboards: Optional[int] = field(default=None)
     api_page_limit: Optional[float] = field(default=None)
+    timing: Dict[str, int] = field(default_factory=dict)
 
     def report_item_scanned(self) -> None:
         self.items_scanned += 1
@@ -293,7 +300,7 @@ class RedashSourceReport(SourceReport):
 
 @platform_name("Redash")
 @config_class(RedashConfig)
-@support_status(SupportStatus.CERTIFIED)
+@support_status(SupportStatus.INCUBATING)
 class RedashSource(Source):
     """
     This plugin extracts the following:
@@ -348,6 +355,10 @@ class RedashSource(Source):
         self.report.report_failure(key, reason)
         log.error(f"{key} => {reason}")
 
+    def warn(self, log: logging.Logger, key: str, reason: str) -> None:
+        self.report.report_warning(key, reason)
+        log.warning(f"{key} => {reason}")
+
     def test_connection(self) -> None:
         test_response = self.client._get(f"{self.config.connect_uri}/api")
         if test_response.status_code == 200:
@@ -441,23 +452,33 @@
 
         # Getting table lineage from SQL parsing
         if self.parse_table_names_from_sql and data_source_syntax == "sql":
+            dataset_urns = list()
             try:
-                dataset_urns = list()
                 sql_table_names = self._get_sql_table_names(
                     query, self.sql_parser_path
                 )
-                for sql_table_name in sql_table_names:
+            except Exception as e:
+                self.report.queries_problem_parsing.add(str(query_id))
+                self.error(
+                    logger,
+                    "sql-parsing",
+                    f"exception {e} in parsing query-{query_id}-datasource-{data_source_id}",
+                )
+                sql_table_names = []
+            for sql_table_name in sql_table_names:
+                try:
                     dataset_urns.append(
                         self._construct_datalineage_urn(
                             platform, database_name, sql_table_name
                         )
                     )
-            except Exception as e:
-                self.error(
-                    logger,
-                    f"sql-parsing-query-{query_id}-datasource-{data_source_id}",
-                    f"exception {e} in parsing {query}",
-                )
+                except Exception:
+                    self.report.queries_problem_parsing.add(str(query_id))
+                    self.warn(
+                        logger,
+                        "data-urn-invalid",
+                        f"Problem making URN for {sql_table_name} parsed from query {query_id}",
+                    )
 
             # make sure dataset_urns is not empty list
             return dataset_urns if len(dataset_urns) > 0 else None
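
Note: the rewrite above splits one big try/except into two scopes. Parsing the SQL can still fail the whole query, but a single bad table name now only produces a warning instead of discarding lineage for every other table in that query. A standalone sketch of the pattern (all names here are illustrative, not from the source):

    from typing import Callable, List

    def urns_for_query(
        sql: str,
        parse_tables: Callable[[str], List[str]],  # may raise on unparseable SQL
        make_urn: Callable[[str], str],  # may raise on one bad table name
    ) -> List[str]:
        try:
            tables = parse_tables(sql)
        except Exception:
            return []  # parse failure: report an error, no lineage at all
        urns = []
        for table in tables:
            try:
                urns.append(make_urn(table))
            except Exception:
                continue  # per-table failure: warn and keep the rest
        return urns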
@@ -549,47 +570,26 @@
 
         return dashboard_snapshot
 
-    def _emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
-        current_dashboards_page = 1
-        skip_draft = self.config.skip_draft
-
-        # Get total number of dashboards to calculate maximum page number
-        dashboards_response = self.client.dashboards(1, self.config.page_size)
-        total_dashboards = dashboards_response["count"]
-        max_page = math.ceil(total_dashboards / self.config.page_size)
-        logger.info(
-            f"/api/dashboards total count {total_dashboards} and max page {max_page}"
-        )
-        self.report.total_dashboards = total_dashboards
-        self.report.max_page_dashboards = max_page
-
-        while (
-            current_dashboards_page <= max_page
-            and current_dashboards_page <= self.api_page_limit
-        ):
-            self.report.page_reached_dashboards = current_dashboards_page
+    def _process_dashboard_response(
+        self, current_page: int
+    ) -> Iterable[MetadataWorkUnit]:
+        result: List[MetadataWorkUnit] = []
+        logger.info(f"Starting processing dashboard for page {current_page}")
+        if current_page > self.api_page_limit:
+            logger.info(f"{current_page} > {self.api_page_limit} so returning")
+            return result
+        with PerfTimer() as timer:
             dashboards_response = self.client.dashboards(
-                page=current_dashboards_page, page_size=self.config.page_size
+                page=current_page, page_size=self.config.page_size
             )
-            logger.info(
-                f"/api/dashboards on page {current_dashboards_page} / {max_page}"
-            )
-            current_dashboards_page += 1
-
             for dashboard_response in dashboards_response["results"]:
                 dashboard_name = dashboard_response["name"]
                 self.report.report_item_scanned()
                 if (not self.config.dashboard_patterns.allowed(dashboard_name)) or (
-                    skip_draft and dashboard_response["is_draft"]
+                    self.config.skip_draft and dashboard_response["is_draft"]
                 ):
                     self.report.report_dropped(dashboard_name)
                     continue
 
                 # Continue producing MCE
                 try:
                     # This is undocumented but checking the Redash source
@@ -610,7 +610,29 @@
 
                 wu = MetadataWorkUnit(id=dashboard_snapshot.urn, mce=mce)
                 self.report.report_workunit(wu)
-                yield wu
+                result.append(wu)
+        self.report.timing[f"dashboard-{current_page}"] = int(
+            timer.elapsed_seconds()
+        )
+        return result
+
+    def _emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
+        # Get total number of dashboards to calculate maximum page number
+        dashboards_response = self.client.dashboards(1, self.config.page_size)
+        total_dashboards = dashboards_response["count"]
+        max_page = math.ceil(total_dashboards / self.config.page_size)
+        logger.info(
+            f"/api/dashboards total count {total_dashboards} and max page {max_page}"
+        )
+        self.report.total_dashboards = total_dashboards
+        self.report.max_page_dashboards = max_page
+        dash_exec_pool = ThreadPool(self.config.parallelism)
+        for response in dash_exec_pool.imap_unordered(
+            self._process_dashboard_response, range(1, max_page + 1)
+        ):
+            yield from response
 
     def _get_chart_type_from_viz_data(self, viz_data: Dict) -> str:
         """
@@ -674,10 +696,11 @@
         datasource_urns = self._get_datasource_urns(data_source, query_data)
 
         if datasource_urns is None:
+            self.report.charts_no_input.add(chart_urn)
             self.report.queries_no_dataset.add(str(query_id))
             self.report.report_warning(
-                key=f"redash-chart-{viz_id}-query-{query_id}-datasource-{data_source_id}",
-                reason=f"data_source_type={data_source_type} not yet implemented. Setting inputs to None",
+                key="redash-chart-input-missing",
+                reason=f"For viz-id-{viz_id}-query-{query_id}-datasource-{data_source_id} data_source_type={data_source_type} no datasources found. Setting inputs to None",
             )
 
         chart_info = ChartInfoClass(
@@ -692,43 +715,25 @@
 
         return chart_snapshot
 
-    def _emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
-        current_queries_page: int = 1
-        skip_draft = self.config.skip_draft
-
-        # Get total number of queries to calculate maximum page number
-        _queries_response = self.client.queries(1, self.config.page_size)
-        total_queries = _queries_response["count"]
-        max_page = math.ceil(total_queries / self.config.page_size)
-        logger.info(f"/api/queries total count {total_queries} and max page {max_page}")
-        self.report.total_queries = total_queries
-        self.report.max_page_queries = max_page
-
-        while (
-            current_queries_page <= max_page
-            and current_queries_page <= self.api_page_limit
-        ):
-            self.report.page_reached_queries = current_queries_page
+    def _process_query_response(self, current_page: int) -> Iterable[MetadataWorkUnit]:
+        logger.info(f"Starting processing query for page {current_page}")
+        result: List[MetadataWorkUnit] = []
+        if current_page > self.api_page_limit:
+            logger.info(f"{current_page} > {self.api_page_limit} so returning")
+            return result
+        with PerfTimer() as timer:
             queries_response = self.client.queries(
-                page=current_queries_page, page_size=self.config.page_size
+                page=current_page, page_size=self.config.page_size
             )
-            logger.info(f"/api/queries on page {current_queries_page} / {max_page}")
-            current_queries_page += 1
-
             for query_response in queries_response["results"]:
                 chart_name = query_response["name"]
                 self.report.report_item_scanned()
                 if (not self.config.chart_patterns.allowed(chart_name)) or (
-                    skip_draft and query_response["is_draft"]
+                    self.config.skip_draft and query_response["is_draft"]
                 ):
                     self.report.report_dropped(chart_name)
                     continue
 
                 query_id = query_response["id"]
                 query_data = self.client._get(f"/api/queries/{query_id}").json()
                 logger.debug(query_data)
@@ -740,7 +745,24 @@
 
                 wu = MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
                 self.report.report_workunit(wu)
-                yield wu
+                result.append(wu)
+        self.report.timing[f"query-{current_page}"] = int(timer.elapsed_seconds())
+        logger.info(f"Ending processing query for {current_page}")
+        return result
+
+    def _emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
+        # Get total number of queries to calculate maximum page number
+        queries_response = self.client.queries(1, self.config.page_size)
+        total_queries = queries_response["count"]
+        max_page = math.ceil(total_queries / self.config.page_size)
+        logger.info(f"/api/queries total count {total_queries} and max page {max_page}")
+        self.report.total_queries = total_queries
+        self.report.max_page_queries = max_page
+        chart_exec_pool = ThreadPool(self.config.parallelism)
+        for response in chart_exec_pool.imap_unordered(
+            self._process_query_response, range(1, max_page + 1)
+        ):
+            yield from response
 
     def add_config_to_report(self) -> None:
         self.report.api_page_limit = self.config.api_page_limit
@@ -748,8 +770,12 @@
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         self.test_connection()
         self.add_config_to_report()
-        yield from self._emit_chart_mces()
-        yield from self._emit_dashboard_mces()
+        with PerfTimer() as timer:
+            yield from self._emit_chart_mces()
+            self.report.timing["time-all-charts"] = int(timer.elapsed_seconds())
+        with PerfTimer() as timer:
+            yield from self._emit_dashboard_mces()
+            self.report.timing["time-all-dashboards"] = int(timer.elapsed_seconds())
 
     def get_report(self) -> SourceReport:
         return self.report
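
Note: the new timing entries make the ingestion report double as a coarse profile: one wall-clock total per phase plus one entry per page. A minimal sketch of how PerfTimer is used here, assuming only that it is a context manager exposing elapsed_seconds() (which is how this diff uses it):

    from datahub.utilities.perf_timer import PerfTimer

    timing = {}
    with PerfTimer() as timer:
        sum(range(1_000_000))  # stand-in for emitting a phase's work units
        timing["time-all-charts"] = int(timer.elapsed_seconds())
    print(timing)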