feat(ingest): add async batch mode to the rest sink (#10733)

Harshal Sheth 2024-06-25 15:49:00 -07:00 committed by GitHub
parent 0dc0bc5761
commit 724907b8f4
9 changed files with 766 additions and 378 deletions
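
At a glance, the new mode is opted into via the rest sink config. Below is a minimal sketch (not part of the diff) of the relevant settings, assuming the config classes introduced in this commit; the server URL and the specific values are placeholders.

from datahub.ingestion.sink.datahub_rest import (
    DatahubRestSinkConfig,
    RestSinkMode,
)

# Illustrative values only. `mode`, `max_per_batch`, and the async knobs are
# the settings touched by this commit.
config = DatahubRestSinkConfig(
    server="http://localhost:8080",
    mode=RestSinkMode.ASYNC_BATCH,  # requires a server that supports ingestProposalBatch
    max_threads=16,              # applies in both async modes
    max_pending_requests=2000,   # applies in both async modes
    max_per_batch=100,           # only applies in ASYNC_BATCH mode
)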

View File

@@ -214,7 +214,7 @@ class DataHubGraph(DatahubRestEmitter):
def _make_rest_sink_config(self) -> "DatahubRestSinkConfig":
from datahub.ingestion.sink.datahub_rest import (
DatahubRestSinkConfig,
SyncOrAsync,
RestSinkMode,
)
# This is a bit convoluted - this DataHubGraph class is a subclass of DatahubRestEmitter,
@@ -222,7 +222,7 @@ class DataHubGraph(DatahubRestEmitter):
# TODO: We should refactor out the multithreading functionality of the sink
# into a separate class that can be used by both the sink and the graph client
# e.g. a DatahubBulkRestEmitter that both the sink and the graph client use.
return DatahubRestSinkConfig(**self.config.dict(), mode=SyncOrAsync.ASYNC)
return DatahubRestSinkConfig(**self.config.dict(), mode=RestSinkMode.ASYNC)
@contextlib.contextmanager
def make_rest_sink(
@@ -253,14 +253,10 @@ class DataHubGraph(DatahubRestEmitter):
) -> None:
"""Emit all items in the iterable using multiple threads."""
# The context manager also ensures that we raise an error if a failure occurs.
with self.make_rest_sink(run_id=run_id) as sink:
for item in items:
sink.emit_async(item)
if sink.report.failures:
raise OperationalError(
f"Failed to emit {len(sink.report.failures)} records",
info=sink.report.as_obj(),
)
def get_aspect(
self,
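
For orientation, the sink configured above is what the graph client's convenience emit path uses under the hood. A hedged usage sketch of make_rest_sink() and emit_async() follows; the DataHubGraph construction, server URL, and urn are illustrative assumptions and not part of this diff.

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
from datahub.metadata.schema_classes import StatusClass

# Placeholder connection details; run_id is omitted and left to its default.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

mcp = MetadataChangeProposalWrapper(
    entityUrn="urn:li:dataset:(urn:li:dataPlatform:hive,example.table,PROD)",
    aspect=StatusClass(removed=False),
)

# Per the comment in this hunk, the context manager raises if any record
# failed to emit; with this commit the sink it builds uses RestSinkMode.ASYNC.
with graph.make_rest_sink() as sink:
    sink.emit_async(mcp)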

View File

@@ -7,7 +7,7 @@ import os
import threading
import uuid
from enum import auto
from typing import Optional, Union
from typing import List, Optional, Tuple, Union
from datahub.cli.cli_utils import set_env_variables_override_config
from datahub.configuration.common import (
@@ -16,6 +16,7 @@ from datahub.configuration.common import (
OperationalError,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import mcps_from_mce
from datahub.emitter.rest_emitter import DataHubRestEmitter
from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
from datahub.ingestion.api.sink import (
@@ -30,7 +31,10 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
MetadataChangeEvent,
MetadataChangeProposal,
)
from datahub.utilities.advanced_thread_executor import PartitionExecutor
from datahub.utilities.partition_executor import (
BatchPartitionExecutor,
PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.server_config_util import set_gms_config
@@ -41,18 +45,26 @@ DEFAULT_REST_SINK_MAX_THREADS = int(
)
class SyncOrAsync(ConfigEnum):
class RestSinkMode(ConfigEnum):
SYNC = auto()
ASYNC = auto()
# Uses the new ingestProposalBatch endpoint. Significantly more efficient than the other modes,
# but requires a server version that supports it.
# https://github.com/datahub-project/datahub/pull/10706
ASYNC_BATCH = auto()
class DatahubRestSinkConfig(DatahubClientConfig):
mode: SyncOrAsync = SyncOrAsync.ASYNC
mode: RestSinkMode = RestSinkMode.ASYNC
# These only apply in async mode.
# These only apply in async modes.
max_threads: int = DEFAULT_REST_SINK_MAX_THREADS
max_pending_requests: int = 2000
# Only applies in async batch mode.
max_per_batch: int = 100
@dataclasses.dataclass
class DataHubRestSinkReport(SinkReport):
@@ -111,10 +123,20 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
set_env_variables_override_config(self.config.server, self.config.token)
logger.debug("Setting gms config")
set_gms_config(gms_config)
self.executor = PartitionExecutor(
max_workers=self.config.max_threads,
max_pending=self.config.max_pending_requests,
)
self.executor: Union[PartitionExecutor, BatchPartitionExecutor]
if self.config.mode == RestSinkMode.ASYNC_BATCH:
self.executor = BatchPartitionExecutor(
max_workers=self.config.max_threads,
max_pending=self.config.max_pending_requests,
process_batch=self._emit_batch_wrapper,
max_per_batch=self.config.max_per_batch,
)
else:
self.executor = PartitionExecutor(
max_workers=self.config.max_threads,
max_pending=self.config.max_pending_requests,
)
@classmethod
def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
@@ -189,6 +211,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
self.report.report_warning({"warning": e.message, "info": e.info})
write_callback.on_failure(record_envelope, e, e.info)
else:
logger.exception(f"Failure: {e}", exc_info=e)
self.report.report_failure({"e": e})
write_callback.on_failure(record_envelope, Exception(e), {})
@@ -203,6 +226,30 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
# TODO: Add timing metrics
self.emitter.emit(record)
def _emit_batch_wrapper(
self,
records: List[
Tuple[
Union[
MetadataChangeEvent,
MetadataChangeProposal,
MetadataChangeProposalWrapper,
],
]
],
) -> None:
events: List[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]] = []
for record in records:
event = record[0]
if isinstance(event, MetadataChangeEvent):
# Unpack MCEs into MCPs.
mcps = mcps_from_mce(event)
events.extend(mcps)
else:
events.append(event)
self.emitter.emit_mcps(events)
def write_record_async(
self,
record_envelope: RecordEnvelope[
@@ -218,7 +265,8 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
# should only have a high value if the sink is actually a bottleneck.
with self.report.main_thread_blocking_timer:
record = record_envelope.record
if self.config.mode == SyncOrAsync.ASYNC:
if self.config.mode == RestSinkMode.ASYNC:
assert isinstance(self.executor, PartitionExecutor)
partition_key = _get_partition_key(record_envelope)
self.executor.submit(
partition_key,
@@ -229,6 +277,17 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
),
)
self.report.pending_requests += 1
elif self.config.mode == RestSinkMode.ASYNC_BATCH:
assert isinstance(self.executor, BatchPartitionExecutor)
partition_key = _get_partition_key(record_envelope)
self.executor.submit(
partition_key,
record,
done_callback=functools.partial(
self._write_done_callback, record_envelope, write_callback
),
)
self.report.pending_requests += 1
else:
# execute synchronously
try:
@@ -249,7 +308,8 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
)
def close(self):
self.executor.shutdown()
with self.report.main_thread_blocking_timer:
self.executor.shutdown()
def __repr__(self) -> str:
return self.emitter.__repr__()
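
To make the MCE-unpacking step in _emit_batch_wrapper() concrete, here is a rough sketch of mcps_from_mce() applied to a hand-built MCE. The snapshot urn and aspect are placeholders; only the import path and the unpacking behavior are taken from this diff.

from datahub.emitter.mcp_builder import mcps_from_mce
from datahub.metadata.schema_classes import (
    DatasetSnapshotClass,
    MetadataChangeEventClass,
    StatusClass,
)

mce = MetadataChangeEventClass(
    proposedSnapshot=DatasetSnapshotClass(
        urn="urn:li:dataset:(urn:li:dataPlatform:hive,example.table,PROD)",
        aspects=[StatusClass(removed=False)],
    )
)

# mcps_from_mce() yields one MCP-style object per aspect in the snapshot;
# _emit_batch_wrapper() collects these before a single emit_mcps() call.
mcps = list(mcps_from_mce(mce))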

View File

@@ -102,7 +102,7 @@ from datahub.metadata.schema_classes import (
OwnershipTypeClass,
SubTypesClass,
)
from datahub.utilities.advanced_thread_executor import BackpressureAwareExecutor
from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor
logger = logging.getLogger(__name__)

View File

@@ -1,231 +0,0 @@
from __future__ import annotations
import collections
import concurrent.futures
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import BoundedSemaphore
from typing import (
Any,
Callable,
Deque,
Dict,
Iterable,
Iterator,
Optional,
Set,
Tuple,
TypeVar,
)
from datahub.ingestion.api.closeable import Closeable
logger = logging.getLogger(__name__)
_R = TypeVar("_R")
_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05
class PartitionExecutor(Closeable):
def __init__(self, max_workers: int, max_pending: int) -> None:
"""A thread pool executor with partitioning and a pending request bound.
It works similarly to a ThreadPoolExecutor, with the following changes:
- At most one request per partition key will be executing at a time.
- If the number of pending requests exceeds the threshold, the submit() call
will block until the number of pending requests drops below the threshold.
Due to the interaction between max_workers and max_pending, it is possible
for execution to effectively be serialized when there's a large influx of
requests with the same key. This can be mitigated by setting a reasonably
large max_pending value.
Args:
max_workers: The maximum number of threads to use for executing requests.
max_pending: The maximum number of pending (e.g. non-executing) requests to allow.
"""
self.max_workers = max_workers
self.max_pending = max_pending
self._executor = ThreadPoolExecutor(max_workers=max_workers)
# Each pending or executing request will acquire a permit from this semaphore.
self._semaphore = BoundedSemaphore(max_pending + max_workers)
# A key existing in this dict means that there is a submitted request for that key.
# Any entries in the key's value e.g. the deque are requests that are waiting
# to be submitted once the current request for that key completes.
self._pending_by_key: Dict[
str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]]
] = {}
def submit(
self,
key: str,
fn: Callable[..., _R],
*args: Any,
# Ideally, we would've used ParamSpec to annotate this method. However,
# due to the limitations of PEP 612, we can't add a keyword argument here.
# See https://peps.python.org/pep-0612/#concatenating-keyword-parameters
# As such, we're using Any here, and won't validate the args to this method.
# We might be able to work around it by moving the done_callback arg to be before
# the *args, but that would mean making done_callback a required arg instead of
# optional as it is now.
done_callback: Optional[Callable[[Future], None]] = None,
**kwargs: Any,
) -> None:
"""See concurrent.futures.Executor#submit"""
self._semaphore.acquire()
if key in self._pending_by_key:
self._pending_by_key[key].append((fn, args, kwargs, done_callback))
else:
self._pending_by_key[key] = collections.deque()
self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback)
def _submit_nowait(
self,
key: str,
fn: Callable[..., _R],
args: tuple,
kwargs: dict,
done_callback: Optional[Callable[[Future], None]],
) -> Future:
future = self._executor.submit(fn, *args, **kwargs)
def _system_done_callback(future: Future) -> None:
self._semaphore.release()
# If there is another pending request for this key, submit it now.
# The key must exist in the map.
if self._pending_by_key[key]:
fn, args, kwargs, user_done_callback = self._pending_by_key[
key
].popleft()
try:
self._submit_nowait(key, fn, args, kwargs, user_done_callback)
except RuntimeError as e:
if self._executor._shutdown:
# If we're in shutdown mode, then we can't submit any more requests.
# That means we'll need to drop requests on the floor, which is to
# be expected in shutdown mode.
# The only reason we'd normally be in shutdown here is during
# Python exit (e.g. KeyboardInterrupt), so this is reasonable.
logger.debug("Dropping request due to shutdown")
else:
raise e
else:
# If there are no pending requests for this key, mark the key
# as no longer in progress.
del self._pending_by_key[key]
if done_callback:
future.add_done_callback(done_callback)
future.add_done_callback(_system_done_callback)
return future
def flush(self) -> None:
"""Wait for all pending requests to complete."""
# Acquire all the semaphore permits so that no more requests can be submitted.
for _i in range(self.max_pending):
self._semaphore.acquire()
# Now, wait for all the pending requests to complete.
while len(self._pending_by_key) > 0:
# TODO: There should be a better way to wait for all executor threads to be idle.
# One option would be to just shutdown the existing executor and create a new one.
time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)
# Now allow new requests to be submitted.
# TODO: With Python 3.9, release() can take a count argument.
for _i in range(self.max_pending):
self._semaphore.release()
def shutdown(self) -> None:
"""See concurrent.futures.Executor#shutdown. Behaves as if wait=True."""
self.flush()
assert len(self._pending_by_key) == 0
# Technically, the wait=True here is redundant, since all the threads should
# be idle now.
self._executor.shutdown(wait=True)
def close(self) -> None:
self.shutdown()
class BackpressureAwareExecutor:
# This couldn't be a real executor because the semantics of submit wouldn't really make sense.
# In this variant, if we blocked on submit, then we would also be blocking the thread that
# we expect to be consuming the results. As such, I made it accept the full list of args
# up front, and that way the consumer can read results at its own pace.
@classmethod
def map(
cls,
fn: Callable[..., _R],
args_list: Iterable[Tuple[Any, ...]],
max_workers: int,
max_pending: Optional[int] = None,
) -> Iterator[Future[_R]]:
"""Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer.
The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result
objects in memory if the consumer is slow. Instead, the consumer can read the results
at its own pace and the executor threads will idle if they need to.
Args:
fn: The function to apply to each input.
args_list: The list of inputs. In contrast to the builtin map, this is a list
of tuples, where each tuple is the arguments to fn.
max_workers: The maximum number of threads to use.
max_pending: The maximum number of pending results to keep in memory.
If not set, it will be set to 2*max_workers.
Returns:
An iterable of futures.
This differs from a traditional map because it returns futures
instead of the actual results, so that the caller is required
to handle exceptions.
Additionally, it does not maintain the order of the arguments.
If you want to know which result corresponds to which input,
the mapped function should return some form of an identifier.
"""
if max_pending is None:
max_pending = 2 * max_workers
assert max_pending >= max_workers
pending_futures: Set[Future] = set()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for args in args_list:
# If the pending list is full, wait until one is done.
if len(pending_futures) >= max_pending:
(done, _) = concurrent.futures.wait(
pending_futures, return_when=concurrent.futures.FIRST_COMPLETED
)
for future in done:
pending_futures.remove(future)
# We don't want to call result() here because we want the caller
# to handle exceptions/cancellation.
yield future
# Now that there's space in the pending list, enqueue the next task.
pending_futures.add(executor.submit(fn, *args))
# Wait for all the remaining tasks to complete.
for future in concurrent.futures.as_completed(pending_futures):
pending_futures.remove(future)
yield future
assert not pending_futures

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import concurrent.futures
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Iterable, Iterator, Optional, Set, Tuple, TypeVar
_R = TypeVar("_R")
class BackpressureAwareExecutor:
# This couldn't be a real executor because the semantics of submit wouldn't really make sense.
# In this variant, if we blocked on submit, then we would also be blocking the thread that
# we expect to be consuming the results. As such, I made it accept the full list of args
# up front, and that way the consumer can read results at its own pace.
@classmethod
def map(
cls,
fn: Callable[..., _R],
args_list: Iterable[Tuple[Any, ...]],
max_workers: int,
max_pending: Optional[int] = None,
) -> Iterator[Future[_R]]:
"""Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer.
The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result
objects in memory if the consumer is slow. Instead, the consumer can read the results
at its own pace and the executor threads will idle if they need to.
Args:
fn: The function to apply to each input.
args_list: The list of inputs. In contrast to the builtin map, this is a list
of tuples, where each tuple is the arguments to fn.
max_workers: The maximum number of threads to use.
max_pending: The maximum number of pending results to keep in memory.
If not set, it will be set to 2*max_workers.
Returns:
An iterable of futures.
This differs from a traditional map because it returns futures
instead of the actual results, so that the caller is required
to handle exceptions.
Additionally, it does not maintain the order of the arguments.
If you want to know which result corresponds to which input,
the mapped function should return some form of an identifier.
"""
if max_pending is None:
max_pending = 2 * max_workers
assert max_pending >= max_workers
pending_futures: Set[Future] = set()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for args in args_list:
# If the pending list is full, wait until one is done.
if len(pending_futures) >= max_pending:
(done, _) = concurrent.futures.wait(
pending_futures, return_when=concurrent.futures.FIRST_COMPLETED
)
for future in done:
pending_futures.remove(future)
# We don't want to call result() here because we want the caller
# to handle exceptions/cancellation.
yield future
# Now that there's space in the pending list, enqueue the next task.
pending_futures.add(executor.submit(fn, *args))
# Wait for all the remaining tasks to complete.
for future in concurrent.futures.as_completed(pending_futures):
pending_futures.remove(future)
yield future
assert not pending_futures
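
The class itself is moved verbatim from advanced_thread_executor.py into its own module. A short usage sketch follows, mirroring the tests further down; the fetch() function and its inputs are illustrative.

from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor

def fetch(i: int) -> int:
    # Placeholder for an expensive call, e.g. an HTTP request.
    return i * i

# map() yields futures as results become available; the worker threads idle
# once max_pending unconsumed results have accumulated, so a slow consumer
# does not cause results to pile up in memory.
for future in BackpressureAwareExecutor.map(
    fetch, [(i,) for i in range(100)], max_workers=4, max_pending=8
):
    print(future.result())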

View File

@@ -0,0 +1,404 @@
from __future__ import annotations
import collections
import functools
import logging
import queue
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from threading import BoundedSemaphore
from typing import (
Any,
Callable,
Deque,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
TypeVar,
)
from datahub.ingestion.api.closeable import Closeable
logger = logging.getLogger(__name__)
_R = TypeVar("_R")
_Args = TypeVar("_Args", bound=tuple)
_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05
_DEFAULT_BATCHER_MIN_PROCESS_INTERVAL = timedelta(seconds=30)
class PartitionExecutor(Closeable):
def __init__(self, max_workers: int, max_pending: int) -> None:
"""A thread pool executor with partitioning and a pending request bound.
It works similarly to a ThreadPoolExecutor, with the following changes:
- At most one request per partition key will be executing at a time.
- If the number of pending requests exceeds the threshold, the submit() call
will block until the number of pending requests drops below the threshold.
Due to the interaction between max_workers and max_pending, it is possible
for execution to effectively be serialized when there's a large influx of
requests with the same key. This can be mitigated by setting a reasonably
large max_pending value.
Args:
max_workers: The maximum number of threads to use for executing requests.
max_pending: The maximum number of pending (e.g. non-executing) requests to allow.
"""
self.max_workers = max_workers
self.max_pending = max_pending
self._executor = ThreadPoolExecutor(max_workers=max_workers)
# Each pending or executing request will acquire a permit from this semaphore.
self._semaphore = BoundedSemaphore(max_pending + max_workers)
# A key existing in this dict means that there is a submitted request for that key.
# Any entries in the key's value e.g. the deque are requests that are waiting
# to be submitted once the current request for that key completes.
self._pending_by_key: Dict[
str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]]
] = {}
def submit(
self,
key: str,
fn: Callable[..., _R],
*args: Any,
# Ideally, we would've used ParamSpec to annotate this method. However,
# due to the limitations of PEP 612, we can't add a keyword argument here.
# See https://peps.python.org/pep-0612/#concatenating-keyword-parameters
# As such, we're using Any here, and won't validate the args to this method.
# We might be able to work around it by moving the done_callback arg to be before
# the *args, but that would mean making done_callback a required arg instead of
# optional as it is now.
done_callback: Optional[Callable[[Future], None]] = None,
**kwargs: Any,
) -> None:
"""See concurrent.futures.Executor#submit"""
self._semaphore.acquire()
if key in self._pending_by_key:
self._pending_by_key[key].append((fn, args, kwargs, done_callback))
else:
self._pending_by_key[key] = collections.deque()
self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback)
def _submit_nowait(
self,
key: str,
fn: Callable[..., _R],
args: tuple,
kwargs: dict,
done_callback: Optional[Callable[[Future], None]],
) -> Future:
future = self._executor.submit(fn, *args, **kwargs)
def _system_done_callback(future: Future) -> None:
self._semaphore.release()
# If there is another pending request for this key, submit it now.
# The key must exist in the map.
if self._pending_by_key[key]:
fn, args, kwargs, user_done_callback = self._pending_by_key[
key
].popleft()
try:
self._submit_nowait(key, fn, args, kwargs, user_done_callback)
except RuntimeError as e:
if self._executor._shutdown:
# If we're in shutdown mode, then we can't submit any more requests.
# That means we'll need to drop requests on the floor, which is to
# be expected in shutdown mode.
# The only reason we'd normally be in shutdown here is during
# Python exit (e.g. KeyboardInterrupt), so this is reasonable.
logger.debug("Dropping request due to shutdown")
else:
raise e
else:
# If there are no pending requests for this key, mark the key
# as no longer in progress.
del self._pending_by_key[key]
if done_callback:
future.add_done_callback(done_callback)
future.add_done_callback(_system_done_callback)
return future
def flush(self) -> None:
"""Wait for all pending requests to complete."""
# Acquire all the semaphore permits so that no more requests can be submitted.
for _i in range(self.max_pending):
self._semaphore.acquire()
# Now, wait for all the pending requests to complete.
while len(self._pending_by_key) > 0:
# TODO: There should be a better way to wait for all executor threads to be idle.
# One option would be to just shutdown the existing executor and create a new one.
time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)
# Now allow new requests to be submitted.
# TODO: With Python 3.9, release() can take a count argument.
for _i in range(self.max_pending):
self._semaphore.release()
def shutdown(self) -> None:
"""See concurrent.futures.Executor#shutdown. Behaves as if wait=True."""
self.flush()
assert len(self._pending_by_key) == 0
self._executor.shutdown(wait=True)
def close(self) -> None:
self.shutdown()
class _BatchPartitionWorkItem(NamedTuple):
key: str
args: tuple
done_callback: Optional[Callable[[Future], None]]
def _now() -> datetime:
return datetime.now(tz=timezone.utc)
class BatchPartitionExecutor(Closeable):
def __init__(
self,
max_workers: int,
max_pending: int,
# Due to limitations of Python's typing, we can't express the type of the list
# effectively. Ideally we'd use ParamSpec here, but that's not allowed in a
# class context like this.
process_batch: Callable[[List], None],
max_per_batch: int = 100,
min_process_interval: timedelta = _DEFAULT_BATCHER_MIN_PROCESS_INTERVAL,
) -> None:
"""Similar to PartitionExecutor, but with batching.
This takes in the stream of requests, automatically segments them into partition-aware
batches, and schedules them across a pool of worker threads.
It maintains the invariant that multiple requests with the same key will not be in
flight concurrently, except when part of the same batch. Requests for a given key
will also be executed in the order they were submitted.
Unlike the PartitionExecutor, this does not support return values or kwargs.
Args:
max_workers: The maximum number of threads to use for executing requests.
max_pending: The maximum number of pending (e.g. non-executing) requests to allow.
max_per_batch: The maximum number of requests to include in a batch.
min_process_interval: When requests are coming in slowly, we will wait at least this long
before submitting a non-full batch.
process_batch: A function that takes in a list of argument tuples.
"""
self.max_workers = max_workers
self.max_pending = max_pending
self.max_per_batch = max_per_batch
self.process_batch = process_batch
self.min_process_interval = min_process_interval
assert self.max_workers > 1
# We add one here to account for the clearinghouse worker thread.
self._executor = ThreadPoolExecutor(max_workers=max_workers + 1)
self._clearinghouse_started = False
self._pending_count = BoundedSemaphore(max_pending)
self._pending: "queue.Queue[Optional[_BatchPartitionWorkItem]]" = queue.Queue(
maxsize=max_pending
)
# If this is true, that means shutdown() has been called and self._pending is empty.
self._queue_empty_for_shutdown = False
def _clearinghouse_worker(self) -> None: # noqa: C901
# This worker will pull items off the queue, and submit them into the executor
# in batches. Only this worker will submit process commands to the executor thread pool.
# The lock protects the function's internal state.
clearinghouse_state_lock = threading.Lock()
workers_available = self.max_workers
keys_in_flight: Set[str] = set()
keys_no_longer_in_flight: Set[str] = set()
pending_key_completion: List[_BatchPartitionWorkItem] = []
last_submit_time = _now()
def _handle_batch_completion(
batch: List[_BatchPartitionWorkItem], future: Future
) -> None:
with clearinghouse_state_lock:
for item in batch:
keys_no_longer_in_flight.add(item.key)
self._pending_count.release()
# Separate from the above loop to avoid holding the lock while calling the callbacks.
for item in batch:
if item.done_callback:
item.done_callback(future)
def _find_ready_items() -> List[_BatchPartitionWorkItem]:
with clearinghouse_state_lock:
# First, update the keys in flight.
for key in keys_no_longer_in_flight:
keys_in_flight.remove(key)
keys_no_longer_in_flight.clear()
# Then, update the pending key completion and build the ready list.
pending = pending_key_completion.copy()
pending_key_completion.clear()
ready: List[_BatchPartitionWorkItem] = []
for item in pending:
if (
len(ready) < self.max_per_batch
and item.key not in keys_in_flight
):
ready.append(item)
else:
pending_key_completion.append(item)
return ready
def _build_batch() -> List[_BatchPartitionWorkItem]:
next_batch = _find_ready_items()
while (
not self._queue_empty_for_shutdown
and len(next_batch) < self.max_per_batch
):
blocking = True
if (
next_batch
and _now() - last_submit_time > self.min_process_interval
and workers_available > 0
):
# If we're past the submit deadline, pull from the queue
# in a non-blocking way, and submit the batch once the queue
# is empty.
blocking = False
try:
next_item: Optional[_BatchPartitionWorkItem] = self._pending.get(
block=blocking,
timeout=self.min_process_interval.total_seconds(),
)
if next_item is None:
self._queue_empty_for_shutdown = True
break
with clearinghouse_state_lock:
if next_item.key in keys_in_flight:
pending_key_completion.append(next_item)
else:
next_batch.append(next_item)
except queue.Empty:
if not blocking:
break
return next_batch
def _submit_batch(next_batch: List[_BatchPartitionWorkItem]) -> None:
with clearinghouse_state_lock:
for item in next_batch:
keys_in_flight.add(item.key)
nonlocal workers_available
workers_available -= 1
nonlocal last_submit_time
last_submit_time = _now()
future = self._executor.submit(
self.process_batch, [item.args for item in next_batch]
)
future.add_done_callback(
functools.partial(_handle_batch_completion, next_batch)
)
try:
# Normal operation - submit batches as they become available.
while not self._queue_empty_for_shutdown:
next_batch = _build_batch()
if next_batch:
_submit_batch(next_batch)
# Shutdown time.
# Invariant - at this point, we know self._pending is empty.
# We just need to wait for the in-flight items to complete,
# and submit any currently pending items once possible.
while pending_key_completion:
next_batch = _build_batch()
if next_batch:
_submit_batch(next_batch)
time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)
# At this point, there are no more things to submit.
# We could wait for the in-flight items to complete,
# but the executor will take care of waiting for them to complete.
except Exception as e:
# This represents a fatal error that makes the entire executor defunct.
logger.exception(
"Threaded executor's clearinghouse worker failed.", exc_info=e
)
finally:
self._clearinghouse_started = False
def _ensure_clearinghouse_started(self) -> None:
# Lazily start the clearinghouse worker.
if not self._clearinghouse_started:
self._clearinghouse_started = True
self._executor.submit(self._clearinghouse_worker)
def submit(
self,
key: str,
*args: Any,
done_callback: Optional[Callable[[Future], None]] = None,
) -> None:
"""See concurrent.futures.Executor#submit"""
self._ensure_clearinghouse_started()
self._pending_count.acquire()
self._pending.put(_BatchPartitionWorkItem(key, args, done_callback))
def shutdown(self) -> None:
if not self._clearinghouse_started:
# This is required to make shutdown() idempotent, which is important
# when it's called explicitly and then also by a context manager.
logger.debug("Shutting down: clearinghouse not started")
return
logger.debug(f"Shutting down {self.__class__.__name__}")
# Send the shutdown signal.
self._pending.put(None)
# By acquiring all the permits, we ensure that no more tasks will be scheduled
# and automatically wait until all existing tasks have completed.
for _ in range(self.max_pending):
self._pending_count.acquire()
# We must wait for the clearinghouse worker to exit before calling shutdown
# on the thread pool. Without this, the clearinghouse worker might fail to
# enqueue pending tasks into the pool.
while self._clearinghouse_started:
time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)
self._executor.shutdown(wait=False)
def close(self) -> None:
self.shutdown()
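
A minimal usage sketch of the new BatchPartitionExecutor, assuming a process_batch callback that receives lists of the argument tuples passed to submit(); the keys and payloads below are placeholders.

from typing import List, Tuple

from datahub.utilities.partition_executor import BatchPartitionExecutor

def process_batch(batch: List[Tuple[str, str]]) -> None:
    # Each entry is the (key, payload) tuple that was passed to submit().
    print(f"processing {len(batch)} items")

with BatchPartitionExecutor(
    max_workers=4,
    max_pending=100,
    process_batch=process_batch,
    max_per_batch=10,
) as executor:
    for i in range(25):
        # Items sharing a key are never in flight in two batches at once,
        # and per-key submission order is preserved.
        executor.submit(
            f"urn:li:example:{i % 5}",  # partition key
            f"urn:li:example:{i % 5}",
            f"payload-{i}",
        )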

View File

@@ -1,128 +0,0 @@
import time
from concurrent.futures import Future
from datahub.utilities.advanced_thread_executor import (
BackpressureAwareExecutor,
PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer
def test_partitioned_executor():
executing_tasks = set()
done_tasks = set()
def task(key: str, id: str) -> None:
executing_tasks.add((key, id))
time.sleep(0.8)
done_tasks.add(id)
executing_tasks.remove((key, id))
with PartitionExecutor(max_workers=2, max_pending=10) as executor:
# Submit tasks with the same key. They should be executed sequentially.
executor.submit("key1", task, "key1", "task1")
executor.submit("key1", task, "key1", "task2")
executor.submit("key1", task, "key1", "task3")
# Submit a task with a different key. It should be executed in parallel.
executor.submit("key2", task, "key2", "task4")
saw_keys_in_parallel = False
while executing_tasks or not done_tasks:
keys_executing = [key for key, _ in executing_tasks]
assert list(sorted(keys_executing)) == list(
sorted(set(keys_executing))
), "partitioning not working"
if len(keys_executing) == 2:
saw_keys_in_parallel = True
time.sleep(0.1)
executor.flush()
assert saw_keys_in_parallel
assert not executing_tasks
assert done_tasks == {"task1", "task2", "task3", "task4"}
def test_partitioned_executor_bounding():
task_duration = 0.5
done_tasks = set()
def on_done(future: Future) -> None:
done_tasks.add(future.result())
def task(id: str) -> str:
time.sleep(task_duration)
return id
with PartitionExecutor(
max_workers=5, max_pending=10
) as executor, PerfTimer() as timer:
# The first 15 submits should be non-blocking.
for i in range(15):
executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done)
assert timer.elapsed_seconds() < task_duration
# This submit should block.
executor.submit("key-blocking", task, "task-blocking", done_callback=on_done)
assert timer.elapsed_seconds() > task_duration
# Wait for everything to finish.
executor.flush()
assert len(done_tasks) == 16
def test_backpressure_aware_executor_simple():
def task(i):
return i
assert {
res.result()
for res in BackpressureAwareExecutor.map(
task, ((i,) for i in range(10)), max_workers=2
)
} == set(range(10))
def test_backpressure_aware_executor_advanced():
task_duration = 0.5
started = set()
executed = set()
def task(x, y):
assert x + 1 == y
started.add(x)
time.sleep(task_duration)
executed.add(x)
return x
args_list = [(i, i + 1) for i in range(10)]
with PerfTimer() as timer:
results = BackpressureAwareExecutor.map(
task, args_list, max_workers=2, max_pending=4
)
assert timer.elapsed_seconds() < task_duration
# No tasks should have completed yet.
assert len(executed) == 0
# Consume the first result.
first_result = next(results)
assert 0 <= first_result.result() < 4
assert timer.elapsed_seconds() > task_duration
# By now, the first four tasks should have started.
time.sleep(task_duration)
assert {0, 1, 2, 3}.issubset(started)
assert 2 <= len(executed) <= 4
# Finally, consume the rest of the results.
assert {r.result() for r in results} == {
i for i in range(10) if i != first_result.result()
}
# Validate that the entire process took about 5-10x the task duration.
# That's because we have 2 workers and 10 tasks.
assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration

View File

@@ -0,0 +1,59 @@
import time
from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor
from datahub.utilities.perf_timer import PerfTimer
def test_backpressure_aware_executor_simple():
def task(i):
return i
assert {
res.result()
for res in BackpressureAwareExecutor.map(
task, ((i,) for i in range(10)), max_workers=2
)
} == set(range(10))
def test_backpressure_aware_executor_advanced():
task_duration = 0.5
started = set()
executed = set()
def task(x, y):
assert x + 1 == y
started.add(x)
time.sleep(task_duration)
executed.add(x)
return x
args_list = [(i, i + 1) for i in range(10)]
with PerfTimer() as timer:
results = BackpressureAwareExecutor.map(
task, args_list, max_workers=2, max_pending=4
)
assert timer.elapsed_seconds() < task_duration
# No tasks should have completed yet.
assert len(executed) == 0
# Consume the first result.
first_result = next(results)
assert 0 <= first_result.result() < 4
assert timer.elapsed_seconds() > task_duration
# By now, the first four tasks should have started.
time.sleep(task_duration)
assert {0, 1, 2, 3}.issubset(started)
assert 2 <= len(executed) <= 4
# Finally, consume the rest of the results.
assert {r.result() for r in results} == {
i for i in range(10) if i != first_result.result()
}
# Validate that the entire process took about 5-10x the task duration.
# That's because we have 2 workers and 10 tasks.
assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration

View File

@@ -0,0 +1,150 @@
import logging
import time
from concurrent.futures import Future
from datahub.utilities.partition_executor import (
BatchPartitionExecutor,
PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer
logger = logging.getLogger(__name__)
def test_partitioned_executor():
executing_tasks = set()
done_tasks = set()
def task(key: str, id: str) -> None:
executing_tasks.add((key, id))
time.sleep(0.8)
done_tasks.add(id)
executing_tasks.remove((key, id))
with PartitionExecutor(max_workers=2, max_pending=10) as executor:
# Submit tasks with the same key. They should be executed sequentially.
executor.submit("key1", task, "key1", "task1")
executor.submit("key1", task, "key1", "task2")
executor.submit("key1", task, "key1", "task3")
# Submit a task with a different key. It should be executed in parallel.
executor.submit("key2", task, "key2", "task4")
saw_keys_in_parallel = False
while executing_tasks or not done_tasks:
keys_executing = [key for key, _ in executing_tasks]
assert list(sorted(keys_executing)) == list(
sorted(set(keys_executing))
), "partitioning not working"
if len(keys_executing) == 2:
saw_keys_in_parallel = True
time.sleep(0.1)
executor.flush()
assert saw_keys_in_parallel
assert not executing_tasks
assert done_tasks == {"task1", "task2", "task3", "task4"}
def test_partitioned_executor_bounding():
task_duration = 0.5
done_tasks = set()
def on_done(future: Future) -> None:
done_tasks.add(future.result())
def task(id: str) -> str:
time.sleep(task_duration)
return id
with PartitionExecutor(
max_workers=5, max_pending=10
) as executor, PerfTimer() as timer:
# The first 15 submits should be non-blocking.
for i in range(15):
executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done)
assert timer.elapsed_seconds() < task_duration
# This submit should block.
executor.submit("key-blocking", task, "task-blocking", done_callback=on_done)
assert timer.elapsed_seconds() > task_duration
# Wait for everything to finish.
executor.flush()
assert len(done_tasks) == 16
def test_batch_partition_executor_sequential_key_execution():
executing_tasks = set()
done_tasks = set()
done_task_batches = set()
def process_batch(batch):
for key, id in batch:
assert (key, id) not in executing_tasks, "Task is already executing"
executing_tasks.add((key, id))
time.sleep(0.5) # Simulate work
for key, id in batch:
executing_tasks.remove((key, id))
done_tasks.add(id)
done_task_batches.add(tuple(id for _, id in batch))
with BatchPartitionExecutor(
max_workers=2,
max_pending=10,
max_per_batch=2,
process_batch=process_batch,
) as executor:
# Submit tasks with the same key. The first two should get batched together.
executor.submit("key1", "key1", "task1")
executor.submit("key1", "key1", "task2")
executor.submit("key1", "key1", "task3")
# Submit tasks with a different key. These should get their own batch.
executor.submit("key2", "key2", "task4")
executor.submit("key2", "key2", "task5")
# Test idempotency of shutdown().
executor.shutdown()
# Check if all tasks were executed and completed.
assert done_tasks == {
"task1",
"task2",
"task3",
"task4",
"task5",
}, "Not all tasks completed"
# Check the batching configuration.
assert done_task_batches == {
("task1", "task2"),
("task4", "task5"),
("task3",),
}
def test_batch_partition_executor_max_batch_size():
batches_processed = []
def process_batch(batch):
batches_processed.append(batch)
time.sleep(0.1) # Simulate batch processing time
with BatchPartitionExecutor(
max_workers=5, max_pending=20, process_batch=process_batch, max_per_batch=2
) as executor:
# Submit more tasks than the max_per_batch to test batching limits.
for i in range(5):
executor.submit("key3", "key3", f"task{i}")
# Check the batches.
logger.info(f"batches_processed: {batches_processed}")
assert len(batches_processed) == 3
for batch in batches_processed:
assert len(batch) <= 2, "Batch size exceeded max_per_batch limit"