feat(redash): add parallelism support for ingestion (#5061)

Authored by Aseem Bansal on 2022-06-01 19:36:30 +05:30; committed by GitHub
parent 2dd826626d
commit f81ead366d


@@ -3,6 +3,7 @@ import logging
 import math
 import sys
 from dataclasses import dataclass, field
+from multiprocessing.pool import ThreadPool
 from typing import Dict, Iterable, List, Optional, Set, Type
 import dateutil.parser as dp
@@ -37,6 +38,7 @@ from datahub.metadata.schema_classes import (
     ChartTypeClass,
     DashboardInfoClass,
 )
+from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.sql_parser import SQLParser

 logger = logging.getLogger(__name__)
@@ -253,6 +255,10 @@ class RedashConfig(ConfigModel):
         default=sys.maxsize,
         description="Limit on number of pages queried for ingesting dashboards and charts API during pagination.",
     )
+    parallelism: int = Field(
+        default=1,
+        description="Parallelism to use while processing.",
+    )
     parse_table_names_from_sql: bool = Field(
         default=False, description="See note below."
     )
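The new `parallelism` option rides on the existing config model, so it can be set from any ingestion recipe. A minimal sketch of how a recipe value lands on it, assuming ConfigModel's usual pydantic v1 `parse_obj` behavior (the URI and key below are illustrative placeholders, not from this commit):

from datahub.ingestion.source.redash import RedashConfig

config = RedashConfig.parse_obj(
    {
        "connect_uri": "http://localhost:5000",  # placeholder Redash instance
        "api_key": "REDASH_API_KEY",  # placeholder credential
        "parallelism": 5,  # worker threads used while paging the API
    }
)
assert config.parallelism == 5  # the default of 1 keeps the old sequential behavior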
@@ -271,18 +277,19 @@
 class RedashSourceReport(SourceReport):
     items_scanned: int = 0
     filtered: List[str] = field(default_factory=list)
+    queries_problem_parsing: Set[str] = field(default_factory=set)
+    queries_no_dataset: Set[str] = field(default_factory=set)
     charts_no_input: Set[str] = field(default_factory=set)
     total_queries: Optional[int] = field(
         default=None,
     )
     page_reached_queries: Optional[int] = field(default=None)
     max_page_queries: Optional[int] = field(default=None)
     total_dashboards: Optional[int] = field(
         default=None,
     )
     page_reached_dashboards: Optional[int] = field(default=None)
     max_page_dashboards: Optional[int] = field(default=None)
     api_page_limit: Optional[float] = field(default=None)
+    timing: Dict[str, int] = field(default_factory=dict)

     def report_item_scanned(self) -> None:
         self.items_scanned += 1
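The new `timing` dict collects the PerfTimer measurements taken further down in this diff, keyed by page or phase. A self-contained sketch of the pattern, assuming only the context-manager interface of PerfTimer that the diff itself uses (the key name and sleep are illustrative):

import time

from datahub.utilities.perf_timer import PerfTimer

timing = {}
with PerfTimer() as timer:
    time.sleep(0.2)  # stand-in for fetching and processing one API page
timing["dashboard-1"] = int(timer.elapsed_seconds())

Note that the int() cast truncates, so a page that completes in under a second is recorded as 0.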
@@ -293,7 +300,7 @@
 @platform_name("Redash")
 @config_class(RedashConfig)
-@support_status(SupportStatus.CERTIFIED)
+@support_status(SupportStatus.INCUBATING)
 class RedashSource(Source):
     """
     This plugin extracts the following:
@@ -348,6 +355,10 @@
         self.report.report_failure(key, reason)
         log.error(f"{key} => {reason}")

+    def warn(self, log: logging.Logger, key: str, reason: str) -> None:
+        self.report.report_warning(key, reason)
+        log.warning(f"{key} => {reason}")
+
     def test_connection(self) -> None:
         test_response = self.client._get(f"{self.config.connect_uri}/api")
         if test_response.status_code == 200:
@@ -441,23 +452,33 @@
         # Getting table lineage from SQL parsing
         if self.parse_table_names_from_sql and data_source_syntax == "sql":
+            dataset_urns = list()
             try:
-                dataset_urns = list()
                 sql_table_names = self._get_sql_table_names(
                     query, self.sql_parser_path
                 )
-                for sql_table_name in sql_table_names:
+            except Exception as e:
+                self.report.queries_problem_parsing.add(str(query_id))
+                self.error(
+                    logger,
+                    "sql-parsing",
+                    f"exception {e} in parsing query-{query_id}-datasource-{data_source_id}",
+                )
+                sql_table_names = []
+            for sql_table_name in sql_table_names:
+                try:
                     dataset_urns.append(
                         self._construct_datalineage_urn(
                             platform, database_name, sql_table_name
                         )
                     )
-            except Exception as e:
-                self.error(
-                    logger,
-                    f"sql-parsing-query-{query_id}-datasource-{data_source_id}",
-                    f"exception {e} in parsing {query}",
-                )
+                except Exception:
+                    self.report.queries_problem_parsing.add(str(query_id))
+                    self.warn(
+                        logger,
+                        "data-urn-invalid",
+                        f"Problem making URN for {sql_table_name} parsed from query {query_id}",
+                    )

             # make sure dataset_urns is not empty list
             return dataset_urns if len(dataset_urns) > 0 else None
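This hunk splits one broad try block into two failure domains: a parse failure now skips only SQL parsing, and a bad table name drops only that table's lineage rather than the whole query's. A generic, runnable sketch of the per-item isolation pattern (every name here is illustrative, not from the source):

def make_urn(table_name: str) -> str:
    if not table_name or ".." in table_name:
        raise ValueError(f"malformed table name: {table_name}")
    return f"urn:li:dataset:{table_name}"

def build_urns(table_names):
    urns, failed = [], []
    for name in table_names:
        try:
            urns.append(make_urn(name))
        except Exception:
            failed.append(name)  # record the bad item and keep going
    return urns, failed

urns, failed = build_urns(["db.good_table", "db..bad"])
assert urns == ["urn:li:dataset:db.good_table"] and failed == ["db..bad"]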
@@ -549,47 +570,26 @@
         return dashboard_snapshot

-    def _emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
-        current_dashboards_page = 1
-        skip_draft = self.config.skip_draft
-        # Get total number of dashboards to calculate maximum page number
-        dashboards_response = self.client.dashboards(1, self.config.page_size)
-        total_dashboards = dashboards_response["count"]
-        max_page = math.ceil(total_dashboards / self.config.page_size)
-        logger.info(
-            f"/api/dashboards total count {total_dashboards} and max page {max_page}"
-        )
-        self.report.total_dashboards = total_dashboards
-        self.report.max_page_dashboards = max_page
-        while (
-            current_dashboards_page <= max_page
-            and current_dashboards_page <= self.api_page_limit
-        ):
-            self.report.page_reached_dashboards = current_dashboards_page
+    def _process_dashboard_response(
+        self, current_page: int
+    ) -> Iterable[MetadataWorkUnit]:
+        result: List[MetadataWorkUnit] = []
+        logger.info(f"Starting processing dashboard for page {current_page}")
+        if current_page > self.api_page_limit:
+            logger.info(f"{current_page} > {self.api_page_limit} so returning")
+            return result
+        with PerfTimer() as timer:
             dashboards_response = self.client.dashboards(
-                page=current_dashboards_page, page_size=self.config.page_size
+                page=current_page, page_size=self.config.page_size
             )
-            logger.info(
-                f"/api/dashboards on page {current_dashboards_page} / {max_page}"
-            )
-            current_dashboards_page += 1
             for dashboard_response in dashboards_response["results"]:
                 dashboard_name = dashboard_response["name"]
                 self.report.report_item_scanned()
                 if (not self.config.dashboard_patterns.allowed(dashboard_name)) or (
-                    skip_draft and dashboard_response["is_draft"]
+                    self.config.skip_draft and dashboard_response["is_draft"]
                 ):
                     self.report.report_dropped(dashboard_name)
                     continue

                 # Continue producing MCE
                 try:
                     # This is undocumented but checking the Redash source
@@ -610,7 +610,29 @@
                 wu = MetadataWorkUnit(id=dashboard_snapshot.urn, mce=mce)
                 self.report.report_workunit(wu)
-                yield wu
+                result.append(wu)
+        self.report.timing[f"dashboard-{current_page}"] = int(
+            timer.elapsed_seconds()
+        )
+        return result
+
+    def _emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
+        # Get total number of dashboards to calculate maximum page number
+        dashboards_response = self.client.dashboards(1, self.config.page_size)
+        total_dashboards = dashboards_response["count"]
+        max_page = math.ceil(total_dashboards / self.config.page_size)
+        logger.info(
+            f"/api/dashboards total count {total_dashboards} and max page {max_page}"
+        )
+        self.report.total_dashboards = total_dashboards
+        self.report.max_page_dashboards = max_page
+
+        dash_exec_pool = ThreadPool(self.config.parallelism)
+        for response in dash_exec_pool.imap_unordered(
+            self._process_dashboard_response, range(1, max_page + 1)
+        ):
+            yield from response
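This is the heart of the change: page numbers are fanned out to a ThreadPool, and imap_unordered hands back each page's work units as soon as that page finishes, in whatever order pages complete. A runnable reduction of the pattern (the worker body and pool size are illustrative):

from multiprocessing.pool import ThreadPool

def process_page(page: int) -> list:
    # stand-in for one paginated API call plus per-item processing
    return [f"workunit-{page}-{i}" for i in range(3)]

pool = ThreadPool(4)  # mirrors ThreadPool(self.config.parallelism)
try:
    for page_items in pool.imap_unordered(process_page, range(1, 6)):
        for item in page_items:
            print(item)  # pages may arrive out of order; items within a page stay ordered
finally:
    pool.close()
    pool.join()

A thread pool (rather than a process pool) fits here because the work is dominated by blocking HTTP calls, which release the GIL while waiting, and imap_unordered keeps the pipeline lazy where a plain map would buffer every page before yielding anything.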
     def _get_chart_type_from_viz_data(self, viz_data: Dict) -> str:
         """
@@ -674,10 +696,11 @@
         datasource_urns = self._get_datasource_urns(data_source, query_data)

         if datasource_urns is None:
             self.report.charts_no_input.add(chart_urn)
+            self.report.queries_no_dataset.add(str(query_id))
             self.report.report_warning(
-                key=f"redash-chart-{viz_id}-query-{query_id}-datasource-{data_source_id}",
-                reason=f"data_source_type={data_source_type} not yet implemented. Setting inputs to None",
+                key="redash-chart-input-missing",
+                reason=f"For viz-id-{viz_id}-query-{query_id}-datasource-{data_source_id} data_source_type={data_source_type} no datasources found. Setting inputs to None",
             )

         chart_info = ChartInfoClass(
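The key change above is deliberate: SourceReport groups warnings by key, so the old per-chart key created a separate report entry for every affected chart, while the stable key collects them all under one entry. A simplified sketch of that grouping (the dict-based aggregation is an assumption about report_warning, reduced to its essence):

warnings: dict = {}

def report_warning(key: str, reason: str) -> None:
    warnings.setdefault(key, []).append(reason)

report_warning("redash-chart-input-missing", "For viz-id-1-query-7 ... no datasources found")
report_warning("redash-chart-input-missing", "For viz-id-2-query-9 ... no datasources found")
assert list(warnings) == ["redash-chart-input-missing"]  # one entry, two reasons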
@@ -692,43 +715,25 @@
         return chart_snapshot

-    def _emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
-        current_queries_page: int = 1
-        skip_draft = self.config.skip_draft
-        # Get total number of queries to calculate maximum page number
-        _queries_response = self.client.queries(1, self.config.page_size)
-        total_queries = _queries_response["count"]
-        max_page = math.ceil(total_queries / self.config.page_size)
-        logger.info(f"/api/queries total count {total_queries} and max page {max_page}")
-        self.report.total_queries = total_queries
-        self.report.max_page_queries = max_page
-        while (
-            current_queries_page <= max_page
-            and current_queries_page <= self.api_page_limit
-        ):
-            self.report.page_reached_queries = current_queries_page
+    def _process_query_response(self, current_page: int) -> Iterable[MetadataWorkUnit]:
+        logger.info(f"Starting processing query for page {current_page}")
+        result: List[MetadataWorkUnit] = []
+        if current_page > self.api_page_limit:
+            logger.info(f"{current_page} > {self.api_page_limit} so returning")
+            return result
+        with PerfTimer() as timer:
             queries_response = self.client.queries(
-                page=current_queries_page, page_size=self.config.page_size
+                page=current_page, page_size=self.config.page_size
             )
-            logger.info(f"/api/queries on page {current_queries_page} / {max_page}")
-            current_queries_page += 1
             for query_response in queries_response["results"]:
                 chart_name = query_response["name"]
                 self.report.report_item_scanned()
                 if (not self.config.chart_patterns.allowed(chart_name)) or (
-                    skip_draft and query_response["is_draft"]
+                    self.config.skip_draft and query_response["is_draft"]
                 ):
                     self.report.report_dropped(chart_name)
                     continue

                 query_id = query_response["id"]
                 query_data = self.client._get(f"/api/queries/{query_id}").json()
                 logger.debug(query_data)
@@ -740,7 +745,24 @@
                 wu = MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
                 self.report.report_workunit(wu)
-                yield wu
+                result.append(wu)
+        self.report.timing[f"query-{current_page}"] = int(timer.elapsed_seconds())
+        logger.info(f"Ending processing query for {current_page}")
+        return result
+
+    def _emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
+        # Get total number of queries to calculate maximum page number
+        queries_response = self.client.queries(1, self.config.page_size)
+        total_queries = queries_response["count"]
+        max_page = math.ceil(total_queries / self.config.page_size)
+        logger.info(f"/api/queries total count {total_queries} and max page {max_page}")
+        self.report.total_queries = total_queries
+        self.report.max_page_queries = max_page
+
+        chart_exec_pool = ThreadPool(self.config.parallelism)
+        for response in chart_exec_pool.imap_unordered(
+            self._process_query_response, range(1, max_page + 1)
+        ):
+            yield from response

     def add_config_to_report(self) -> None:
         self.report.api_page_limit = self.config.api_page_limit
@@ -748,8 +770,12 @@
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         self.test_connection()
         self.add_config_to_report()
-        yield from self._emit_chart_mces()
-        yield from self._emit_dashboard_mces()
+        with PerfTimer() as timer:
+            yield from self._emit_chart_mces()
+        self.report.timing["time-all-charts"] = int(timer.elapsed_seconds())
+        with PerfTimer() as timer:
+            yield from self._emit_dashboard_mces()
+        self.report.timing["time-all-dashboards"] = int(timer.elapsed_seconds())
     def get_report(self) -> SourceReport:
         return self.report