feat(ingest): add parallelism to looker source and datahub rest sink (#3431)

Swaroop Jagadish, 2021-10-21 11:27:27 -07:00, committed by GitHub
parent 7dbd10f072
commit 6704765d41
7 changed files with 230 additions and 99 deletions


@@ -38,6 +38,7 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
| `timeout_sec` | | 30 | Per-HTTP request timeout. |
| `token` | | | Bearer token used for authentication. |
| `extra_headers` | | | Extra headers which will be added to the request. |
| `max_threads` | | `1` | Experimental: Max parallelism for REST API calls. |
## DataHub Kafka

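For reference, the new `max_threads` row above can be set like any other `datahub-rest` sink field. A minimal sketch, assuming the ingestion `Pipeline` API accepts the recipe as a plain dict (the `file` source block here is only a placeholder):

```python
from datahub.ingestion.run.pipeline import Pipeline

# Sketch only: the fields mirror the table above; the source block is a stand-in.
pipeline = Pipeline.create(
    {
        "source": {
            "type": "file",
            "config": {"filename": "./mces.json"},
        },
        "sink": {
            "type": "datahub-rest",
            "config": {
                "server": "http://localhost:8080",
                "max_threads": 4,  # experimental: number of parallel REST calls
            },
        },
    }
)
pipeline.run()
```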

@@ -91,6 +91,7 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
| `view_browse_pattern` | | `/{env}/{platform}/{project}/views/{name}` | Pattern for providing browse paths to views. Allowed variables are `{project}`, `{model}`, `{name}`, `{platform}` and `{env}` |
| `explore_naming_pattern` | | `{model}.explore.{name}` | Pattern for providing dataset names to explores. Allowed variables are `{project}`, `{model}`, `{name}` |
| `explore_browse_pattern` | | `/{env}/{platform}/{project}/explores/{model}.{name}` | Pattern for providing browse paths to explores. Allowed variables are `{project}`, `{model}`, `{name}`, `{platform}` and `{env}` |
| `max_threads` | | `os.cpu_count() or 40` | Max parallelism for Looker API calls |
## Compatibility

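The default above mirrors the source code: one thread per CPU, falling back to 40 when the CPU count cannot be determined. A small sketch of the default and of overriding it (the credential fields are the usual Looker connection settings, shown here only as assumptions):

```python
import os

# Default used by the Looker source: one thread per CPU, or 40 when
# os.cpu_count() returns None (it can on some platforms).
default_max_threads = os.cpu_count() or 40

# Overriding it in a recipe, expressed as a plain dict for illustration:
looker_source_config = {
    "base_url": "https://company.looker.com",  # assumed connection settings
    "client_id": "...",
    "client_secret": "...",
    "max_threads": 10,
}
```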

@@ -1,11 +1,13 @@
import datetime
import itertools
import json
import logging
import shlex
from json.decoder import JSONDecodeError
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import requests
import requests.adapters
from requests.exceptions import HTTPError, RequestException
from datahub import __package_name__
@@ -66,6 +68,10 @@ class DatahubRestEmitter:
self._token = token
self._session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100)
self._session.mount("http://", adapter)
self._session.mount("https://", adapter)
self._session.headers.update(
{
"X-RestLi-Protocol-Version": "2.0.0",
@@ -106,13 +112,15 @@
MetadataChangeProposalWrapper,
UsageAggregation,
],
) -> None:
) -> Tuple[datetime.datetime, datetime.datetime]:
start_time = datetime.datetime.now()
if isinstance(item, UsageAggregation):
return self.emit_usage(item)
self.emit_usage(item)
elif isinstance(item, (MetadataChangeProposal, MetadataChangeProposalWrapper)):
return self.emit_mcp(item)
self.emit_mcp(item)
else:
return self.emit_mce(item)
self.emit_mce(item)
return start_time, datetime.datetime.now()
def emit_mce(self, mce: MetadataChangeEvent) -> None:
url = f"{self._gms_server}/entities?action=ingest"

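The return type of `emit()` changing from `None` to a `(start_time, end_time)` tuple is what lets the sink report downstream latency. A minimal sketch of a caller using the new contract (`timed_emit` is illustrative, not part of the codebase):

```python
def timed_emit(emitter, item) -> float:
    """Emit `item` via a DatahubRestEmitter and return the round trip in seconds.

    Relies only on the new contract that emit() returns a
    (start_time, end_time) pair of datetime objects.
    """
    start_time, end_time = emitter.emit(item)
    return (end_time - start_time).total_seconds()
```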

@@ -1,6 +1,7 @@
import datetime
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List
from typing import Any, List, Optional
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope, WorkUnit
@@ -12,6 +13,9 @@ class SinkReport(Report):
records_written: int = 0
warnings: List[Any] = field(default_factory=list)
failures: List[Any] = field(default_factory=list)
downstream_start_time: Optional[datetime.datetime] = None
downstream_end_time: Optional[datetime.datetime] = None
downstream_total_latency_in_seconds: Optional[float] = None
def report_record_written(self, record_envelope: RecordEnvelope) -> None:
self.records_written += 1
@@ -22,6 +26,20 @@
def report_failure(self, info: Any) -> None:
self.failures.append(info)
def report_downstream_latency(
self, start_time: datetime.datetime, end_time: datetime.datetime
) -> None:
if (
self.downstream_start_time is None
or self.downstream_start_time > start_time
):
self.downstream_start_time = start_time
if self.downstream_end_time is None or self.downstream_end_time < end_time:
self.downstream_end_time = end_time
self.downstream_total_latency_in_seconds = (
self.downstream_end_time - self.downstream_start_time
).total_seconds()
class WriteCallback(metaclass=ABCMeta):
@abstractmethod

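`report_downstream_latency` keeps a running `[earliest start, latest end]` window instead of summing per-request durations, so overlapping parallel requests are not double counted. A usage sketch, assuming the class is importable as `datahub.ingestion.api.sink.SinkReport`:

```python
import datetime

from datahub.ingestion.api.sink import SinkReport  # assumed import path

report = SinkReport()
t0 = datetime.datetime.now()
t1 = t0 + datetime.timedelta(seconds=2)
t2 = t0 + datetime.timedelta(seconds=5)

# Two requests: one spanning [t0, t1], another [t1, t2].
report.report_downstream_latency(t0, t1)
report.report_downstream_latency(t1, t2)

# Wall-clock span covered by all downstream calls, not the sum of durations.
assert report.downstream_total_latency_in_seconds == 5.0
```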

@@ -1,3 +1,5 @@
import concurrent.futures
import functools
import logging
from dataclasses import dataclass
from typing import Dict, Optional, Union, cast
@@ -24,6 +26,7 @@ class DatahubRestSinkConfig(ConfigModel):
token: Optional[str]
timeout_sec: Optional[int]
extra_headers: Optional[Dict[str, str]]
max_threads: int = 1
@dataclass
@@ -45,6 +48,9 @@ class DatahubRestSink(Sink):
extra_headers=self.config.extra_headers,
)
self.emitter.test_connection()
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self.config.max_threads
)
@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "DatahubRestSink":
@@ -60,6 +66,53 @@
def handle_work_unit_end(self, workunit: WorkUnit) -> None:
pass
def _write_done_callback(
self,
record_envelope: RecordEnvelope,
write_callback: WriteCallback,
future: concurrent.futures.Future,
) -> None:
if future.cancelled():
self.report.report_failure({"error": "future was cancelled"})
write_callback.on_failure(
record_envelope, OperationalError("future was cancelled"), {}
)
elif future.done():
e = future.exception()
if not e:
self.report.report_record_written(record_envelope)
start_time, end_time = future.result()
self.report.report_downstream_latency(start_time, end_time)
write_callback.on_success(record_envelope, {})
elif isinstance(e, OperationalError):
# only OperationalErrors should be ignored
if not self.treat_errors_as_warnings:
self.report.report_failure({"error": e.message, "info": e.info})
else:
# trim exception stacktraces when reporting warnings
if "stackTrace" in e.info:
try:
e.info["stackTrace"] = "\n".join(
e.info["stackTrace"].split("\n")[0:2]
)
except Exception:
# ignore failures in trimming
pass
record = record_envelope.record
if isinstance(record, MetadataChangeProposalWrapper):
# include information about the entity that failed
entity_id = cast(
MetadataChangeProposalWrapper, record
).entityUrn
e.info["id"] = entity_id
else:
entity_id = None
self.report.report_warning({"warning": e.message, "info": e.info})
write_callback.on_failure(record_envelope, e, e.info)
else:
self.report.report_failure({"e": e})
write_callback.on_failure(record_envelope, Exception(e), {})
def write_record_async(
self,
record_envelope: RecordEnvelope[
@@ -74,38 +127,15 @@
) -> None:
record = record_envelope.record
try:
self.emitter.emit(record)
self.report.report_record_written(record_envelope)
write_callback.on_success(record_envelope, {})
except OperationalError as e:
# only OperationalErrors should be ignored
if not self.treat_errors_as_warnings:
self.report.report_failure({"error": e.message, "info": e.info})
else:
# trim exception stacktraces when reporting warnings
if "stackTrace" in e.info:
try:
e.info["stackTrace"] = "\n".join(
e.info["stackTrace"].split("\n")[0:2]
)
except Exception:
# ignore failures in trimming
pass
if isinstance(record, MetadataChangeProposalWrapper):
# include information about the entity that failed
entity_id = cast(MetadataChangeProposalWrapper, record).entityUrn
e.info["id"] = entity_id
else:
entity_id = None
self.report.report_warning({"warning": e.message, "info": e.info})
write_callback.on_failure(record_envelope, e, e.info)
except Exception as e:
self.report.report_failure({"e": e})
write_callback.on_failure(record_envelope, e, {})
write_future = self.executor.submit(self.emitter.emit, record)
write_future.add_done_callback(
functools.partial(
self._write_done_callback, record_envelope, write_callback
)
)
def get_report(self) -> SinkReport:
return self.report
def close(self):
pass
self.executor.shutdown(wait=True)

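The async write path is the standard `concurrent.futures` submit-plus-done-callback pattern: `functools.partial` pre-binds the per-record context so the callback only receives the future. A generic sketch of that pattern (the names below are illustrative, not the sink's API):

```python
import concurrent.futures
import functools


def on_done(label: str, future: concurrent.futures.Future) -> None:
    # Runs once the future settles; mirrors the three branches of _write_done_callback.
    if future.cancelled():
        print(f"{label}: cancelled")
    elif future.exception() is not None:
        print(f"{label}: failed: {future.exception()!r}")
    else:
        print(f"{label}: ok, result={future.result()}")


with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    for i in range(5):
        fut = executor.submit(pow, i, 2)
        # partial() binds the per-item context, much like the sink binds
        # (record_envelope, write_callback) before submitting.
        fut.add_done_callback(functools.partial(on_done, f"record-{i}"))
# Leaving the with-block waits for outstanding futures, like
# close() calling executor.shutdown(wait=True).
```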

@@ -1,3 +1,4 @@
import concurrent.futures
import datetime
import json
import logging
@@ -76,6 +77,7 @@ class LookerAPI:
os.environ["LOOKERSDK_BASE_URL"] = config.base_url
self.client = looker_sdk.init31()
# try authenticating current user to check connectivity
# (since it's possible to initialize an invalid client without any complaints)
try:
@@ -99,6 +101,7 @@ class LookerDashboardSourceConfig(LookerAPIConfig, LookerCommonConfig):
extract_owners: bool = True
strip_user_ids_from_email: bool = False
skip_personal_folders: bool = False
max_threads: int = os.cpu_count() or 40
@dataclass
@@ -107,6 +110,9 @@ class LookerDashboardSourceReport(SourceReport):
charts_scanned: int = 0
filtered_dashboards: List[str] = dataclass_field(default_factory=list)
filtered_charts: List[str] = dataclass_field(default_factory=list)
upstream_start_time: Optional[datetime.datetime] = None
upstream_end_time: Optional[datetime.datetime] = None
upstream_total_latency_in_seconds: Optional[float] = None
def report_dashboards_scanned(self) -> None:
self.dashboards_scanned += 1
@@ -120,6 +126,17 @@
def report_charts_dropped(self, view: str) -> None:
self.filtered_charts.append(view)
def report_upstream_latency(
self, start_time: datetime.datetime, end_time: datetime.datetime
) -> None:
if self.upstream_start_time is None or self.upstream_start_time > start_time:
self.upstream_start_time = start_time
if self.upstream_end_time is None or self.upstream_end_time < end_time:
self.upstream_end_time = end_time
self.upstream_total_latency_in_seconds = (
self.upstream_end_time - self.upstream_start_time
).total_seconds()
@dataclass
class LookerDashboardElement:
@@ -539,18 +556,45 @@ class LookerDashboardSource(Source):
explore_events: List[
Union[MetadataChangeEvent, MetadataChangeProposalWrapper]
] = []
for (model, explore) in self.explore_set:
logger.info("Will process model: {}, explore: {}".format(model, explore))
looker_explore = LookerExplore.from_api(
model, explore, self.client, self.reporter
)
if looker_explore is not None:
events = looker_explore._to_metadata_events(
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.source_config.max_threads
) as async_executor:
explore_futures = [
async_executor.submit(self.fetch_one_explore, model, explore)
for (model, explore) in self.explore_set
]
for future in concurrent.futures.as_completed(explore_futures):
events, explore_id, start_time, end_time = future.result()
explore_events.extend(events)
self.reporter.report_upstream_latency(start_time, end_time)
logger.info(
f"Running time of fetch_one_explore for {explore_id}: {(end_time-start_time).total_seconds()}"
)
return explore_events
def fetch_one_explore(
self, model: str, explore: str
) -> Tuple[
List[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]],
str,
datetime.datetime,
datetime.datetime,
]:
start_time = datetime.datetime.now()
events: List[Union[MetadataChangeEvent, MetadataChangeProposalWrapper]] = []
looker_explore = LookerExplore.from_api(
model, explore, self.client, self.reporter
)
if looker_explore is not None:
events = (
looker_explore._to_metadata_events(
self.source_config, self.reporter, self.source_config.base_url
)
if events is not None:
explore_events.extend(events)
return explore_events
or events
)
return events, f"{model}:{explore}", start_time, datetime.datetime.now()
def _make_dashboard_and_chart_mces(
self, looker_dashboard: LookerDashboard
@@ -613,13 +657,17 @@
return ownership
return None
folder_path_cache: Dict[str, str] = {}
def _get_folder_path(self, folder: FolderBase, client: Looker31SDK) -> str:
assert folder.id is not None
ancestors = [
ancestor.name
for ancestor in client.folder_ancestors(folder.id, fields="name")
]
return "/".join(ancestors + [folder.name])
assert folder.id
if not self.folder_path_cache.get(folder.id):
ancestors = [
ancestor.name
for ancestor in client.folder_ancestors(folder.id, fields="name")
]
self.folder_path_cache[folder.id] = "/".join(ancestors + [folder.name])
return self.folder_path_cache[folder.id]
def _get_looker_dashboard(
self, dashboard: Dashboard, client: Looker31SDK
@@ -677,6 +725,57 @@
)
return looker_dashboard
def process_dashboard(
self, dashboard_id: str
) -> Tuple[List[MetadataWorkUnit], str, datetime.datetime, datetime.datetime]:
start_time = datetime.datetime.now()
assert dashboard_id is not None
self.reporter.report_dashboards_scanned()
if not self.source_config.dashboard_pattern.allowed(dashboard_id):
self.reporter.report_dashboards_dropped(dashboard_id)
return [], dashboard_id, start_time, datetime.datetime.now()
try:
fields = [
"id",
"title",
"dashboard_elements",
"dashboard_filters",
"deleted",
"description",
"folder",
"user_id",
]
dashboard_object = self.client.dashboard(
dashboard_id=dashboard_id, fields=",".join(fields)
)
except SDKError:
# A looker dashboard could be deleted in between the list and the get
self.reporter.report_warning(
dashboard_id,
f"Error occurred while loading dashboard {dashboard_id}. Skipping.",
)
return [], dashboard_id, start_time, datetime.datetime.now()
if self.source_config.skip_personal_folders:
if dashboard_object.folder is not None and (
dashboard_object.folder.is_personal
or dashboard_object.folder.is_personal_descendant
):
self.reporter.report_warning(
dashboard_id, "Dropped due to being a personal folder"
)
self.reporter.report_dashboards_dropped(dashboard_id)
return [], dashboard_id, start_time, datetime.datetime.now()
looker_dashboard = self._get_looker_dashboard(dashboard_object, self.client)
mces = self._make_dashboard_and_chart_mces(looker_dashboard)
# for mce in mces:
workunits = [
MetadataWorkUnit(id=f"looker-{mce.proposedSnapshot.urn}", mce=mce)
for mce in mces
]
return workunits, dashboard_id, start_time, datetime.datetime.now()
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
dashboards = self.client.all_dashboards(fields="id")
deleted_dashboards = (
@@ -692,53 +791,22 @@
[deleted_dashboard.id for deleted_dashboard in deleted_dashboards]
)
for dashboard_id in dashboard_ids:
assert dashboard_id is not None
self.reporter.report_dashboards_scanned()
if not self.source_config.dashboard_pattern.allowed(dashboard_id):
self.reporter.report_dashboards_dropped(dashboard_id)
continue
try:
fields = [
"id",
"title",
"dashboard_elements",
"dashboard_filters",
"deleted",
"description",
"folder",
"user_id",
]
dashboard_object = self.client.dashboard(
dashboard_id=dashboard_id, fields=",".join(fields)
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.source_config.max_threads
) as async_executor:
async_workunits = [
async_executor.submit(self.process_dashboard, dashboard_id)
for dashboard_id in dashboard_ids
]
for async_workunit in concurrent.futures.as_completed(async_workunits):
work_units, dashboard_id, start_time, end_time = async_workunit.result()
logger.info(
f"Running time of process_dashboard for {dashboard_id} = {(end_time-start_time).total_seconds()}"
)
except SDKError:
# A looker dashboard could be deleted in between the list and the get
self.reporter.report_warning(
dashboard_id,
f"Error occurred while loading dashboard {dashboard_id}. Skipping.",
)
continue
if self.source_config.skip_personal_folders:
if dashboard_object.folder is not None and (
dashboard_object.folder.is_personal
or dashboard_object.folder.is_personal_descendant
):
self.reporter.report_warning(
dashboard_id, "Dropped due to being a personal folder"
)
self.reporter.report_dashboards_dropped(dashboard_id)
continue
looker_dashboard = self._get_looker_dashboard(dashboard_object, self.client)
mces = self._make_dashboard_and_chart_mces(looker_dashboard)
for mce in mces:
workunit = MetadataWorkUnit(
id=f"looker-{mce.proposedSnapshot.urn}", mce=mce
)
self.reporter.report_workunit(workunit)
yield workunit
self.reporter.report_upstream_latency(start_time, end_time)
for mwu in work_units:
yield mwu
self.reporter.report_workunit(mwu)
if (
self.source_config.extract_owners

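The explore-fetching code and `get_workunits` now share the same fan-out/fan-in shape: submit one job per explore or dashboard, drain results with `as_completed` so one slow item does not block the rest, and feed the returned timestamps into `report_upstream_latency`. A condensed sketch of that shape (`fetch_one` stands in for `fetch_one_explore` / `process_dashboard`):

```python
import concurrent.futures
import datetime
from typing import List, Tuple


def fetch_one(item_id: str) -> Tuple[List[str], str, datetime.datetime, datetime.datetime]:
    start = datetime.datetime.now()
    results = [f"workunit-for-{item_id}"]  # placeholder for the real Looker API calls
    return results, item_id, start, datetime.datetime.now()


item_ids = ["dash-1", "dash-2", "dash-3"]
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    futures = [pool.submit(fetch_one, item_id) for item_id in item_ids]
    # as_completed yields futures in completion order, so a slow dashboard
    # does not delay emitting work units for the others.
    for future in concurrent.futures.as_completed(futures):
        results, item_id, start, end = future.result()
        print(item_id, (end - start).total_seconds(), results)
```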

@@ -575,12 +575,17 @@ class LookerExplore:
upstream_views=list(views),
source_file=explore.source_file,
)
except SDKError:
except SDKError as e:
logger.warn(
"Failed to extract explore {} from model {}.".format(
explore_name, model
)
)
logger.debug(
"Failed to extract explore {} from model {} with {}".format(
explore_name, model, e
)
)
except AssertionError:
reporter.report_warning(
key="chart-",