Mirror of https://github.com/datahub-project/datahub.git (synced 2026-01-06 06:46:41 +00:00)
feat(ingest): add async batch mode to the rest sink (#10733)
Commit 724907b8f4 (parent 0dc0bc5761)
@@ -214,7 +214,7 @@ class DataHubGraph(DatahubRestEmitter):
     def _make_rest_sink_config(self) -> "DatahubRestSinkConfig":
         from datahub.ingestion.sink.datahub_rest import (
             DatahubRestSinkConfig,
-            SyncOrAsync,
+            RestSinkMode,
         )
 
         # This is a bit convoluted - this DataHubGraph class is a subclass of DatahubRestEmitter,
@@ -222,7 +222,7 @@ class DataHubGraph(DatahubRestEmitter):
         # TODO: We should refactor out the multithreading functionality of the sink
         # into a separate class that can be used by both the sink and the graph client
         # e.g. a DatahubBulkRestEmitter that both the sink and the graph client use.
-        return DatahubRestSinkConfig(**self.config.dict(), mode=SyncOrAsync.ASYNC)
+        return DatahubRestSinkConfig(**self.config.dict(), mode=RestSinkMode.ASYNC)
 
     @contextlib.contextmanager
     def make_rest_sink(
@@ -253,14 +253,10 @@ class DataHubGraph(DatahubRestEmitter):
     ) -> None:
         """Emit all items in the iterable using multiple threads."""
 
+        # The context manager also ensures that we raise an error if a failure occurs.
         with self.make_rest_sink(run_id=run_id) as sink:
             for item in items:
                 sink.emit_async(item)
-        if sink.report.failures:
-            raise OperationalError(
-                f"Failed to emit {len(sink.report.failures)} records",
-                info=sink.report.as_obj(),
-            )
 
     def get_aspect(
         self,
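For orientation, the rest-sink helper on DataHubGraph shown above is typically driven like the minimal sketch below. This is an illustration, not the PR's own example: make_rest_sink, emit_async, and run_id come from the diff, while the server URL, run id, and the URN/aspect are made-up placeholders.

```python
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.metadata.schema_classes import StatusClass

# Placeholder connection details.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

mcps = [
    MetadataChangeProposalWrapper(
        entityUrn="urn:li:corpuser:example",  # made-up URN for illustration
        aspect=StatusClass(removed=False),
    )
]

# make_rest_sink() yields a DatahubRestSink running in async mode; emit_async hands
# each item to the sink's worker threads, and (per the comment added above) the
# context manager raises an error if any emission failed.
with graph.make_rest_sink(run_id="demo-run") as sink:
    for mcp in mcps:
        sink.emit_async(mcp)
```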
@@ -7,7 +7,7 @@ import os
 import threading
 import uuid
 from enum import auto
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 
 from datahub.cli.cli_utils import set_env_variables_override_config
 from datahub.configuration.common import (
@@ -16,6 +16,7 @@ from datahub.configuration.common import (
     OperationalError,
 )
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.emitter.mcp_builder import mcps_from_mce
 from datahub.emitter.rest_emitter import DataHubRestEmitter
 from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
 from datahub.ingestion.api.sink import (
@@ -30,7 +31,10 @@ from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
     MetadataChangeEvent,
     MetadataChangeProposal,
 )
-from datahub.utilities.advanced_thread_executor import PartitionExecutor
+from datahub.utilities.partition_executor import (
+    BatchPartitionExecutor,
+    PartitionExecutor,
+)
 from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.server_config_util import set_gms_config
 
@@ -41,18 +45,26 @@ DEFAULT_REST_SINK_MAX_THREADS = int(
 )
 
 
-class SyncOrAsync(ConfigEnum):
+class RestSinkMode(ConfigEnum):
     SYNC = auto()
     ASYNC = auto()
 
+    # Uses the new ingestProposalBatch endpoint. Significantly more efficient than the other modes,
+    # but requires a server version that supports it.
+    # https://github.com/datahub-project/datahub/pull/10706
+    ASYNC_BATCH = auto()
+
 
 class DatahubRestSinkConfig(DatahubClientConfig):
-    mode: SyncOrAsync = SyncOrAsync.ASYNC
+    mode: RestSinkMode = RestSinkMode.ASYNC
 
-    # These only apply in async mode.
+    # These only apply in async modes.
     max_threads: int = DEFAULT_REST_SINK_MAX_THREADS
     max_pending_requests: int = 2000
 
+    # Only applies in async batch mode.
+    max_per_batch: int = 100
+
 
 @dataclasses.dataclass
 class DataHubRestSinkReport(SinkReport):
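To show the new configuration surface, here is a minimal sketch of building the sink config with batch mode turned on. The field names are exactly those of DatahubRestSinkConfig above; the server URL and tuning values are placeholders, and ASYNC_BATCH assumes a GMS recent enough to expose the ingestProposalBatch endpoint referenced in the comment.

```python
from datahub.ingestion.sink.datahub_rest import DatahubRestSinkConfig, RestSinkMode

# Placeholder server; token and the other DatahubClientConfig fields work as before.
config = DatahubRestSinkConfig(
    server="http://localhost:8080",
    mode=RestSinkMode.ASYNC_BATCH,
    # These apply to both async modes.
    max_threads=15,
    max_pending_requests=2000,
    # Only used by ASYNC_BATCH.
    max_per_batch=100,
)
```

The same keys would presumably be set under the rest sink's config block in an ingestion recipe.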
@@ -111,10 +123,20 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
         set_env_variables_override_config(self.config.server, self.config.token)
         logger.debug("Setting gms config")
         set_gms_config(gms_config)
-        self.executor = PartitionExecutor(
-            max_workers=self.config.max_threads,
-            max_pending=self.config.max_pending_requests,
-        )
+
+        self.executor: Union[PartitionExecutor, BatchPartitionExecutor]
+        if self.config.mode == RestSinkMode.ASYNC_BATCH:
+            self.executor = BatchPartitionExecutor(
+                max_workers=self.config.max_threads,
+                max_pending=self.config.max_pending_requests,
+                process_batch=self._emit_batch_wrapper,
+                max_per_batch=self.config.max_per_batch,
+            )
+        else:
+            self.executor = PartitionExecutor(
+                max_workers=self.config.max_threads,
+                max_pending=self.config.max_pending_requests,
+            )
 
     @classmethod
     def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
@@ -189,6 +211,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
                 self.report.report_warning({"warning": e.message, "info": e.info})
                 write_callback.on_failure(record_envelope, e, e.info)
             else:
+                logger.exception(f"Failure: {e}", exc_info=e)
                 self.report.report_failure({"e": e})
                 write_callback.on_failure(record_envelope, Exception(e), {})
 
@@ -203,6 +226,30 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
         # TODO: Add timing metrics
         self.emitter.emit(record)
 
+    def _emit_batch_wrapper(
+        self,
+        records: List[
+            Tuple[
+                Union[
+                    MetadataChangeEvent,
+                    MetadataChangeProposal,
+                    MetadataChangeProposalWrapper,
+                ],
+            ]
+        ],
+    ) -> None:
+        events: List[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]] = []
+        for record in records:
+            event = record[0]
+            if isinstance(event, MetadataChangeEvent):
+                # Unpack MCEs into MCPs.
+                mcps = mcps_from_mce(event)
+                events.extend(mcps)
+            else:
+                events.append(event)
+
+        self.emitter.emit_mcps(events)
+
     def write_record_async(
         self,
         record_envelope: RecordEnvelope[
@@ -218,7 +265,8 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
         # should only have a high value if the sink is actually a bottleneck.
         with self.report.main_thread_blocking_timer:
             record = record_envelope.record
-            if self.config.mode == SyncOrAsync.ASYNC:
+            if self.config.mode == RestSinkMode.ASYNC:
+                assert isinstance(self.executor, PartitionExecutor)
                 partition_key = _get_partition_key(record_envelope)
                 self.executor.submit(
                     partition_key,
@@ -229,6 +277,17 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
                     ),
                 )
                 self.report.pending_requests += 1
+            elif self.config.mode == RestSinkMode.ASYNC_BATCH:
+                assert isinstance(self.executor, BatchPartitionExecutor)
+                partition_key = _get_partition_key(record_envelope)
+                self.executor.submit(
+                    partition_key,
+                    record,
+                    done_callback=functools.partial(
+                        self._write_done_callback, record_envelope, write_callback
+                    ),
+                )
+                self.report.pending_requests += 1
             else:
                 # execute synchronously
                 try:
@@ -249,7 +308,8 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
             )
 
     def close(self):
-        self.executor.shutdown()
+        with self.report.main_thread_blocking_timer:
+            self.executor.shutdown()
 
     def __repr__(self) -> str:
         return self.emitter.__repr__()
@@ -102,7 +102,7 @@ from datahub.metadata.schema_classes import (
     OwnershipTypeClass,
     SubTypesClass,
 )
-from datahub.utilities.advanced_thread_executor import BackpressureAwareExecutor
+from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor
 
 logger = logging.getLogger(__name__)
 
@@ -1,231 +0,0 @@
from __future__ import annotations

import collections
import concurrent.futures
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import BoundedSemaphore
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

from datahub.ingestion.api.closeable import Closeable

logger = logging.getLogger(__name__)
_R = TypeVar("_R")
_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05


class PartitionExecutor(Closeable):
    def __init__(self, max_workers: int, max_pending: int) -> None:
        """A thread pool executor with partitioning and a pending request bound.

        It works similarly to a ThreadPoolExecutor, with the following changes:
        - At most one request per partition key will be executing at a time.
        - If the number of pending requests exceeds the threshold, the submit() call
          will block until the number of pending requests drops below the threshold.

        Due to the interaction between max_workers and max_pending, it is possible
        for execution to effectively be serialized when there's a large influx of
        requests with the same key. This can be mitigated by setting a reasonably
        large max_pending value.

        Args:
            max_workers: The maximum number of threads to use for executing requests.
            max_pending: The maximum number of pending (e.g. non-executing) requests to allow.
        """
        self.max_workers = max_workers
        self.max_pending = max_pending

        self._executor = ThreadPoolExecutor(max_workers=max_workers)

        # Each pending or executing request will acquire a permit from this semaphore.
        self._semaphore = BoundedSemaphore(max_pending + max_workers)

        # A key existing in this dict means that there is a submitted request for that key.
        # Any entries in the key's value e.g. the deque are requests that are waiting
        # to be submitted once the current request for that key completes.
        self._pending_by_key: Dict[
            str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]]
        ] = {}

    def submit(
        self,
        key: str,
        fn: Callable[..., _R],
        *args: Any,
        # Ideally, we would've used ParamSpec to annotate this method. However,
        # due to the limitations of PEP 612, we can't add a keyword argument here.
        # See https://peps.python.org/pep-0612/#concatenating-keyword-parameters
        # As such, we're using Any here, and won't validate the args to this method.
        # We might be able to work around it by moving the done_callback arg to be before
        # the *args, but that would mean making done_callback a required arg instead of
        # optional as it is now.
        done_callback: Optional[Callable[[Future], None]] = None,
        **kwargs: Any,
    ) -> None:
        """See concurrent.futures.Executor#submit"""

        self._semaphore.acquire()

        if key in self._pending_by_key:
            self._pending_by_key[key].append((fn, args, kwargs, done_callback))

        else:
            self._pending_by_key[key] = collections.deque()
            self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback)

    def _submit_nowait(
        self,
        key: str,
        fn: Callable[..., _R],
        args: tuple,
        kwargs: dict,
        done_callback: Optional[Callable[[Future], None]],
    ) -> Future:
        future = self._executor.submit(fn, *args, **kwargs)

        def _system_done_callback(future: Future) -> None:
            self._semaphore.release()

            # If there is another pending request for this key, submit it now.
            # The key must exist in the map.
            if self._pending_by_key[key]:
                fn, args, kwargs, user_done_callback = self._pending_by_key[
                    key
                ].popleft()

                try:
                    self._submit_nowait(key, fn, args, kwargs, user_done_callback)
                except RuntimeError as e:
                    if self._executor._shutdown:
                        # If we're in shutdown mode, then we can't submit any more requests.
                        # That means we'll need to drop requests on the floor, which is to
                        # be expected in shutdown mode.
                        # The only reason we'd normally be in shutdown here is during
                        # Python exit (e.g. KeyboardInterrupt), so this is reasonable.
                        logger.debug("Dropping request due to shutdown")
                    else:
                        raise e

            else:
                # If there are no pending requests for this key, mark the key
                # as no longer in progress.
                del self._pending_by_key[key]

        if done_callback:
            future.add_done_callback(done_callback)
        future.add_done_callback(_system_done_callback)
        return future

    def flush(self) -> None:
        """Wait for all pending requests to complete."""

        # Acquire all the semaphore permits so that no more requests can be submitted.
        for _i in range(self.max_pending):
            self._semaphore.acquire()

        # Now, wait for all the pending requests to complete.
        while len(self._pending_by_key) > 0:
            # TODO: There should be a better way to wait for all executor threads to be idle.
            # One option would be to just shutdown the existing executor and create a new one.
            time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)

        # Now allow new requests to be submitted.
        # TODO: With Python 3.9, release() can take a count argument.
        for _i in range(self.max_pending):
            self._semaphore.release()

    def shutdown(self) -> None:
        """See concurrent.futures.Executor#shutdown. Behaves as if wait=True."""

        self.flush()
        assert len(self._pending_by_key) == 0

        # Technically, the wait=True here is redundant, since all the threads should
        # be idle now.
        self._executor.shutdown(wait=True)

    def close(self) -> None:
        self.shutdown()


class BackpressureAwareExecutor:
    # This couldn't be a real executor because the semantics of submit wouldn't really make sense.
    # In this variant, if we blocked on submit, then we would also be blocking the thread that
    # we expect to be consuming the results. As such, I made it accept the full list of args
    # up front, and that way the consumer can read results at its own pace.

    @classmethod
    def map(
        cls,
        fn: Callable[..., _R],
        args_list: Iterable[Tuple[Any, ...]],
        max_workers: int,
        max_pending: Optional[int] = None,
    ) -> Iterator[Future[_R]]:
        """Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer.

        The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result
        objects in memory if the consumer is slow. Instead, the consumer can read the results
        at its own pace and the executor threads will idle if they need to.

        Args:
            fn: The function to apply to each input.
            args_list: The list of inputs. In contrast to the builtin map, this is a list
                of tuples, where each tuple is the arguments to fn.
            max_workers: The maximum number of threads to use.
            max_pending: The maximum number of pending results to keep in memory.
                If not set, it will be set to 2*max_workers.

        Returns:
            An iterable of futures.

            This differs from a traditional map because it returns futures
            instead of the actual results, so that the caller is required
            to handle exceptions.

            Additionally, it does not maintain the order of the arguments.
            If you want to know which result corresponds to which input,
            the mapped function should return some form of an identifier.
        """

        if max_pending is None:
            max_pending = 2 * max_workers
        assert max_pending >= max_workers

        pending_futures: Set[Future] = set()

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for args in args_list:
                # If the pending list is full, wait until one is done.
                if len(pending_futures) >= max_pending:
                    (done, _) = concurrent.futures.wait(
                        pending_futures, return_when=concurrent.futures.FIRST_COMPLETED
                    )
                    for future in done:
                        pending_futures.remove(future)

                        # We don't want to call result() here because we want the caller
                        # to handle exceptions/cancellation.
                        yield future

                # Now that there's space in the pending list, enqueue the next task.
                pending_futures.add(executor.submit(fn, *args))

            # Wait for all the remaining tasks to complete.
            for future in concurrent.futures.as_completed(pending_futures):
                pending_futures.remove(future)
                yield future

            assert not pending_futures
@@ -0,0 +1,78 @@
from __future__ import annotations

import concurrent.futures
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Iterable, Iterator, Optional, Set, Tuple, TypeVar

_R = TypeVar("_R")


class BackpressureAwareExecutor:
    # This couldn't be a real executor because the semantics of submit wouldn't really make sense.
    # In this variant, if we blocked on submit, then we would also be blocking the thread that
    # we expect to be consuming the results. As such, I made it accept the full list of args
    # up front, and that way the consumer can read results at its own pace.

    @classmethod
    def map(
        cls,
        fn: Callable[..., _R],
        args_list: Iterable[Tuple[Any, ...]],
        max_workers: int,
        max_pending: Optional[int] = None,
    ) -> Iterator[Future[_R]]:
        """Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer.

        The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result
        objects in memory if the consumer is slow. Instead, the consumer can read the results
        at its own pace and the executor threads will idle if they need to.

        Args:
            fn: The function to apply to each input.
            args_list: The list of inputs. In contrast to the builtin map, this is a list
                of tuples, where each tuple is the arguments to fn.
            max_workers: The maximum number of threads to use.
            max_pending: The maximum number of pending results to keep in memory.
                If not set, it will be set to 2*max_workers.

        Returns:
            An iterable of futures.

            This differs from a traditional map because it returns futures
            instead of the actual results, so that the caller is required
            to handle exceptions.

            Additionally, it does not maintain the order of the arguments.
            If you want to know which result corresponds to which input,
            the mapped function should return some form of an identifier.
        """

        if max_pending is None:
            max_pending = 2 * max_workers
        assert max_pending >= max_workers

        pending_futures: Set[Future] = set()

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for args in args_list:
                # If the pending list is full, wait until one is done.
                if len(pending_futures) >= max_pending:
                    (done, _) = concurrent.futures.wait(
                        pending_futures, return_when=concurrent.futures.FIRST_COMPLETED
                    )
                    for future in done:
                        pending_futures.remove(future)

                        # We don't want to call result() here because we want the caller
                        # to handle exceptions/cancellation.
                        yield future

                # Now that there's space in the pending list, enqueue the next task.
                pending_futures.add(executor.submit(fn, *args))

            # Wait for all the remaining tasks to complete.
            for future in concurrent.futures.as_completed(pending_futures):
                pending_futures.remove(future)
                yield future

            assert not pending_futures
metadata-ingestion/src/datahub/utilities/partition_executor.py (new file, 404 lines)
@@ -0,0 +1,404 @@
from __future__ import annotations

import collections
import functools
import logging
import queue
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from threading import BoundedSemaphore
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

from datahub.ingestion.api.closeable import Closeable

logger = logging.getLogger(__name__)
_R = TypeVar("_R")
_Args = TypeVar("_Args", bound=tuple)
_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05
_DEFAULT_BATCHER_MIN_PROCESS_INTERVAL = timedelta(seconds=30)


class PartitionExecutor(Closeable):
    def __init__(self, max_workers: int, max_pending: int) -> None:
        """A thread pool executor with partitioning and a pending request bound.

        It works similarly to a ThreadPoolExecutor, with the following changes:
        - At most one request per partition key will be executing at a time.
        - If the number of pending requests exceeds the threshold, the submit() call
          will block until the number of pending requests drops below the threshold.

        Due to the interaction between max_workers and max_pending, it is possible
        for execution to effectively be serialized when there's a large influx of
        requests with the same key. This can be mitigated by setting a reasonably
        large max_pending value.

        Args:
            max_workers: The maximum number of threads to use for executing requests.
            max_pending: The maximum number of pending (e.g. non-executing) requests to allow.
        """
        self.max_workers = max_workers
        self.max_pending = max_pending

        self._executor = ThreadPoolExecutor(max_workers=max_workers)

        # Each pending or executing request will acquire a permit from this semaphore.
        self._semaphore = BoundedSemaphore(max_pending + max_workers)

        # A key existing in this dict means that there is a submitted request for that key.
        # Any entries in the key's value e.g. the deque are requests that are waiting
        # to be submitted once the current request for that key completes.
        self._pending_by_key: Dict[
            str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]]
        ] = {}

    def submit(
        self,
        key: str,
        fn: Callable[..., _R],
        *args: Any,
        # Ideally, we would've used ParamSpec to annotate this method. However,
        # due to the limitations of PEP 612, we can't add a keyword argument here.
        # See https://peps.python.org/pep-0612/#concatenating-keyword-parameters
        # As such, we're using Any here, and won't validate the args to this method.
        # We might be able to work around it by moving the done_callback arg to be before
        # the *args, but that would mean making done_callback a required arg instead of
        # optional as it is now.
        done_callback: Optional[Callable[[Future], None]] = None,
        **kwargs: Any,
    ) -> None:
        """See concurrent.futures.Executor#submit"""

        self._semaphore.acquire()

        if key in self._pending_by_key:
            self._pending_by_key[key].append((fn, args, kwargs, done_callback))

        else:
            self._pending_by_key[key] = collections.deque()
            self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback)

    def _submit_nowait(
        self,
        key: str,
        fn: Callable[..., _R],
        args: tuple,
        kwargs: dict,
        done_callback: Optional[Callable[[Future], None]],
    ) -> Future:
        future = self._executor.submit(fn, *args, **kwargs)

        def _system_done_callback(future: Future) -> None:
            self._semaphore.release()

            # If there is another pending request for this key, submit it now.
            # The key must exist in the map.
            if self._pending_by_key[key]:
                fn, args, kwargs, user_done_callback = self._pending_by_key[
                    key
                ].popleft()

                try:
                    self._submit_nowait(key, fn, args, kwargs, user_done_callback)
                except RuntimeError as e:
                    if self._executor._shutdown:
                        # If we're in shutdown mode, then we can't submit any more requests.
                        # That means we'll need to drop requests on the floor, which is to
                        # be expected in shutdown mode.
                        # The only reason we'd normally be in shutdown here is during
                        # Python exit (e.g. KeyboardInterrupt), so this is reasonable.
                        logger.debug("Dropping request due to shutdown")
                    else:
                        raise e

            else:
                # If there are no pending requests for this key, mark the key
                # as no longer in progress.
                del self._pending_by_key[key]

        if done_callback:
            future.add_done_callback(done_callback)
        future.add_done_callback(_system_done_callback)
        return future

    def flush(self) -> None:
        """Wait for all pending requests to complete."""

        # Acquire all the semaphore permits so that no more requests can be submitted.
        for _i in range(self.max_pending):
            self._semaphore.acquire()

        # Now, wait for all the pending requests to complete.
        while len(self._pending_by_key) > 0:
            # TODO: There should be a better way to wait for all executor threads to be idle.
            # One option would be to just shutdown the existing executor and create a new one.
            time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)

        # Now allow new requests to be submitted.
        # TODO: With Python 3.9, release() can take a count argument.
        for _i in range(self.max_pending):
            self._semaphore.release()

    def shutdown(self) -> None:
        """See concurrent.futures.Executor#shutdown. Behaves as if wait=True."""

        self.flush()
        assert len(self._pending_by_key) == 0

        self._executor.shutdown(wait=True)

    def close(self) -> None:
        self.shutdown()


class _BatchPartitionWorkItem(NamedTuple):
    key: str
    args: tuple
    done_callback: Optional[Callable[[Future], None]]


def _now() -> datetime:
    return datetime.now(tz=timezone.utc)


class BatchPartitionExecutor(Closeable):
    def __init__(
        self,
        max_workers: int,
        max_pending: int,
        # Due to limitations of Python's typing, we can't express the type of the list
        # effectively. Ideally we'd use ParamSpec here, but that's not allowed in a
        # class context like this.
        process_batch: Callable[[List], None],
        max_per_batch: int = 100,
        min_process_interval: timedelta = _DEFAULT_BATCHER_MIN_PROCESS_INTERVAL,
    ) -> None:
        """Similar to PartitionExecutor, but with batching.

        This takes in the stream of requests, automatically segments them into partition-aware
        batches, and schedules them across a pool of worker threads.

        It maintains the invariant that multiple requests with the same key will not be in
        flight concurrently, except when part of the same batch. Requests for a given key
        will also be executed in the order they were submitted.

        Unlike the PartitionExecutor, this does not support return values or kwargs.

        Args:
            max_workers: The maximum number of threads to use for executing requests.
            max_pending: The maximum number of pending (e.g. non-executing) requests to allow.
            max_per_batch: The maximum number of requests to include in a batch.
            min_process_interval: When requests are coming in slowly, we will wait at least this long
                before submitting a non-full batch.
            process_batch: A function that takes in a list of argument tuples.
        """
        self.max_workers = max_workers
        self.max_pending = max_pending
        self.max_per_batch = max_per_batch
        self.process_batch = process_batch
        self.min_process_interval = min_process_interval
        assert self.max_workers > 1

        # We add one here to account for the clearinghouse worker thread.
        self._executor = ThreadPoolExecutor(max_workers=max_workers + 1)
        self._clearinghouse_started = False

        self._pending_count = BoundedSemaphore(max_pending)
        self._pending: "queue.Queue[Optional[_BatchPartitionWorkItem]]" = queue.Queue(
            maxsize=max_pending
        )

        # If this is true, that means shutdown() has been called and self._pending is empty.
        self._queue_empty_for_shutdown = False

    def _clearinghouse_worker(self) -> None:  # noqa: C901
        # This worker will pull items off the queue, and submit them into the executor
        # in batches. Only this worker will submit process commands to the executor thread pool.

        # The lock protects the function's internal state.
        clearinghouse_state_lock = threading.Lock()
        workers_available = self.max_workers
        keys_in_flight: Set[str] = set()
        keys_no_longer_in_flight: Set[str] = set()
        pending_key_completion: List[_BatchPartitionWorkItem] = []

        last_submit_time = _now()

        def _handle_batch_completion(
            batch: List[_BatchPartitionWorkItem], future: Future
        ) -> None:
            with clearinghouse_state_lock:
                for item in batch:
                    keys_no_longer_in_flight.add(item.key)
                    self._pending_count.release()

            # Separate from the above loop to avoid holding the lock while calling the callbacks.
            for item in batch:
                if item.done_callback:
                    item.done_callback(future)

        def _find_ready_items() -> List[_BatchPartitionWorkItem]:
            with clearinghouse_state_lock:
                # First, update the keys in flight.
                for key in keys_no_longer_in_flight:
                    keys_in_flight.remove(key)
                keys_no_longer_in_flight.clear()

                # Then, update the pending key completion and build the ready list.
                pending = pending_key_completion.copy()
                pending_key_completion.clear()

                ready: List[_BatchPartitionWorkItem] = []
                for item in pending:
                    if (
                        len(ready) < self.max_per_batch
                        and item.key not in keys_in_flight
                    ):
                        ready.append(item)
                    else:
                        pending_key_completion.append(item)

                return ready

        def _build_batch() -> List[_BatchPartitionWorkItem]:
            next_batch = _find_ready_items()

            while (
                not self._queue_empty_for_shutdown
                and len(next_batch) < self.max_per_batch
            ):
                blocking = True
                if (
                    next_batch
                    and _now() - last_submit_time > self.min_process_interval
                    and workers_available > 0
                ):
                    # If we're past the submit deadline, pull from the queue
                    # in a non-blocking way, and submit the batch once the queue
                    # is empty.
                    blocking = False

                try:
                    next_item: Optional[_BatchPartitionWorkItem] = self._pending.get(
                        block=blocking,
                        timeout=self.min_process_interval.total_seconds(),
                    )
                    if next_item is None:
                        self._queue_empty_for_shutdown = True
                        break

                    with clearinghouse_state_lock:
                        if next_item.key in keys_in_flight:
                            pending_key_completion.append(next_item)
                        else:
                            next_batch.append(next_item)
                except queue.Empty:
                    if not blocking:
                        break

            return next_batch

        def _submit_batch(next_batch: List[_BatchPartitionWorkItem]) -> None:
            with clearinghouse_state_lock:
                for item in next_batch:
                    keys_in_flight.add(item.key)

                nonlocal workers_available
                workers_available -= 1

                nonlocal last_submit_time
                last_submit_time = _now()

            future = self._executor.submit(
                self.process_batch, [item.args for item in next_batch]
            )
            future.add_done_callback(
                functools.partial(_handle_batch_completion, next_batch)
            )

        try:
            # Normal operation - submit batches as they become available.
            while not self._queue_empty_for_shutdown:
                next_batch = _build_batch()
                if next_batch:
                    _submit_batch(next_batch)

            # Shutdown time.
            # Invariant - at this point, we know self._pending is empty.
            # We just need to wait for the in-flight items to complete,
            # and submit any currently pending items once possible.
            while pending_key_completion:
                next_batch = _build_batch()
                if next_batch:
                    _submit_batch(next_batch)
                time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)

            # At this point, there are no more things to submit.
            # We could wait for the in-flight items to complete,
            # but the executor will take care of waiting for them to complete.
        except Exception as e:
            # This represents a fatal error that makes the entire executor defunct.
            logger.exception(
                "Threaded executor's clearinghouse worker failed.", exc_info=e
            )
        finally:
            self._clearinghouse_started = False

    def _ensure_clearinghouse_started(self) -> None:
        # Lazily start the clearinghouse worker.
        if not self._clearinghouse_started:
            self._clearinghouse_started = True
            self._executor.submit(self._clearinghouse_worker)

    def submit(
        self,
        key: str,
        *args: Any,
        done_callback: Optional[Callable[[Future], None]] = None,
    ) -> None:
        """See concurrent.futures.Executor#submit"""

        self._ensure_clearinghouse_started()

        self._pending_count.acquire()
        self._pending.put(_BatchPartitionWorkItem(key, args, done_callback))

    def shutdown(self) -> None:
        if not self._clearinghouse_started:
            # This is required to make shutdown() idempotent, which is important
            # when it's called explicitly and then also by a context manager.
            logger.debug("Shutting down: clearinghouse not started")
            return

        logger.debug(f"Shutting down {self.__class__.__name__}")

        # Send the shutdown signal.
        self._pending.put(None)

        # By acquiring all the permits, we ensure that no more tasks will be scheduled
        # and automatically wait until all existing tasks have completed.
        for _ in range(self.max_pending):
            self._pending_count.acquire()

        # We must wait for the clearinghouse worker to exit before calling shutdown
        # on the thread pool. Without this, the clearinghouse worker might fail to
        # enqueue pending tasks into the pool.
        while self._clearinghouse_started:
            time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)

        self._executor.shutdown(wait=False)

    def close(self) -> None:
        self.shutdown()
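As a quick orientation before the tests later in this diff, here is roughly how the new BatchPartitionExecutor is driven. The constructor arguments, submit() signature, and with-statement usage mirror the class definition above and the tests below; the process_batch callback, keys, and payload strings are invented for illustration.

```python
from datahub.utilities.partition_executor import BatchPartitionExecutor

def process_batch(batch):
    # Each element of `batch` is the tuple of *args that was passed to submit() for one request.
    print(f"processing {len(batch)} requests")

with BatchPartitionExecutor(
    max_workers=4,  # must be > 1; one extra thread is reserved for the clearinghouse worker
    max_pending=100,
    process_batch=process_batch,
    max_per_batch=10,
) as executor:
    # Requests sharing a key are never in flight in two different batches at once.
    executor.submit("urn:li:dataset:a", "payload-1")
    executor.submit("urn:li:dataset:a", "payload-2")
    executor.submit("urn:li:dataset:b", "payload-3")
# Leaving the block calls shutdown(), which drains the queue and waits for
# all submitted batches to finish.
```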
@@ -1,128 +0,0 @@
import time
from concurrent.futures import Future

from datahub.utilities.advanced_thread_executor import (
    BackpressureAwareExecutor,
    PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer


def test_partitioned_executor():
    executing_tasks = set()
    done_tasks = set()

    def task(key: str, id: str) -> None:
        executing_tasks.add((key, id))
        time.sleep(0.8)
        done_tasks.add(id)
        executing_tasks.remove((key, id))

    with PartitionExecutor(max_workers=2, max_pending=10) as executor:
        # Submit tasks with the same key. They should be executed sequentially.
        executor.submit("key1", task, "key1", "task1")
        executor.submit("key1", task, "key1", "task2")
        executor.submit("key1", task, "key1", "task3")

        # Submit a task with a different key. It should be executed in parallel.
        executor.submit("key2", task, "key2", "task4")

        saw_keys_in_parallel = False
        while executing_tasks or not done_tasks:
            keys_executing = [key for key, _ in executing_tasks]
            assert list(sorted(keys_executing)) == list(
                sorted(set(keys_executing))
            ), "partitioning not working"

            if len(keys_executing) == 2:
                saw_keys_in_parallel = True

            time.sleep(0.1)

        executor.flush()
        assert saw_keys_in_parallel
        assert not executing_tasks
        assert done_tasks == {"task1", "task2", "task3", "task4"}


def test_partitioned_executor_bounding():
    task_duration = 0.5
    done_tasks = set()

    def on_done(future: Future) -> None:
        done_tasks.add(future.result())

    def task(id: str) -> str:
        time.sleep(task_duration)
        return id

    with PartitionExecutor(
        max_workers=5, max_pending=10
    ) as executor, PerfTimer() as timer:
        # The first 15 submits should be non-blocking.
        for i in range(15):
            executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done)
        assert timer.elapsed_seconds() < task_duration

        # This submit should block.
        executor.submit("key-blocking", task, "task-blocking", done_callback=on_done)
        assert timer.elapsed_seconds() > task_duration

        # Wait for everything to finish.
        executor.flush()
        assert len(done_tasks) == 16


def test_backpressure_aware_executor_simple():
    def task(i):
        return i

    assert {
        res.result()
        for res in BackpressureAwareExecutor.map(
            task, ((i,) for i in range(10)), max_workers=2
        )
    } == set(range(10))


def test_backpressure_aware_executor_advanced():
    task_duration = 0.5
    started = set()
    executed = set()

    def task(x, y):
        assert x + 1 == y
        started.add(x)
        time.sleep(task_duration)
        executed.add(x)
        return x

    args_list = [(i, i + 1) for i in range(10)]

    with PerfTimer() as timer:
        results = BackpressureAwareExecutor.map(
            task, args_list, max_workers=2, max_pending=4
        )
        assert timer.elapsed_seconds() < task_duration

        # No tasks should have completed yet.
        assert len(executed) == 0

        # Consume the first result.
        first_result = next(results)
        assert 0 <= first_result.result() < 4
        assert timer.elapsed_seconds() > task_duration

        # By now, the first four tasks should have started.
        time.sleep(task_duration)
        assert {0, 1, 2, 3}.issubset(started)
        assert 2 <= len(executed) <= 4

        # Finally, consume the rest of the results.
        assert {r.result() for r in results} == {
            i for i in range(10) if i != first_result.result()
        }

        # Validate that the entire process took about 5-10x the task duration.
        # That's because we have 2 workers and 10 tasks.
        assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration
@@ -0,0 +1,59 @@
import time

from datahub.utilities.backpressure_aware_executor import BackpressureAwareExecutor
from datahub.utilities.perf_timer import PerfTimer


def test_backpressure_aware_executor_simple():
    def task(i):
        return i

    assert {
        res.result()
        for res in BackpressureAwareExecutor.map(
            task, ((i,) for i in range(10)), max_workers=2
        )
    } == set(range(10))


def test_backpressure_aware_executor_advanced():
    task_duration = 0.5
    started = set()
    executed = set()

    def task(x, y):
        assert x + 1 == y
        started.add(x)
        time.sleep(task_duration)
        executed.add(x)
        return x

    args_list = [(i, i + 1) for i in range(10)]

    with PerfTimer() as timer:
        results = BackpressureAwareExecutor.map(
            task, args_list, max_workers=2, max_pending=4
        )
        assert timer.elapsed_seconds() < task_duration

        # No tasks should have completed yet.
        assert len(executed) == 0

        # Consume the first result.
        first_result = next(results)
        assert 0 <= first_result.result() < 4
        assert timer.elapsed_seconds() > task_duration

        # By now, the first four tasks should have started.
        time.sleep(task_duration)
        assert {0, 1, 2, 3}.issubset(started)
        assert 2 <= len(executed) <= 4

        # Finally, consume the rest of the results.
        assert {r.result() for r in results} == {
            i for i in range(10) if i != first_result.result()
        }

        # Validate that the entire process took about 5-10x the task duration.
        # That's because we have 2 workers and 10 tasks.
        assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration
@@ -0,0 +1,150 @@
import logging
import time
from concurrent.futures import Future

from datahub.utilities.partition_executor import (
    BatchPartitionExecutor,
    PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer

logger = logging.getLogger(__name__)


def test_partitioned_executor():
    executing_tasks = set()
    done_tasks = set()

    def task(key: str, id: str) -> None:
        executing_tasks.add((key, id))
        time.sleep(0.8)
        done_tasks.add(id)
        executing_tasks.remove((key, id))

    with PartitionExecutor(max_workers=2, max_pending=10) as executor:
        # Submit tasks with the same key. They should be executed sequentially.
        executor.submit("key1", task, "key1", "task1")
        executor.submit("key1", task, "key1", "task2")
        executor.submit("key1", task, "key1", "task3")

        # Submit a task with a different key. It should be executed in parallel.
        executor.submit("key2", task, "key2", "task4")

        saw_keys_in_parallel = False
        while executing_tasks or not done_tasks:
            keys_executing = [key for key, _ in executing_tasks]
            assert list(sorted(keys_executing)) == list(
                sorted(set(keys_executing))
            ), "partitioning not working"

            if len(keys_executing) == 2:
                saw_keys_in_parallel = True

            time.sleep(0.1)

        executor.flush()
        assert saw_keys_in_parallel
        assert not executing_tasks
        assert done_tasks == {"task1", "task2", "task3", "task4"}


def test_partitioned_executor_bounding():
    task_duration = 0.5
    done_tasks = set()

    def on_done(future: Future) -> None:
        done_tasks.add(future.result())

    def task(id: str) -> str:
        time.sleep(task_duration)
        return id

    with PartitionExecutor(
        max_workers=5, max_pending=10
    ) as executor, PerfTimer() as timer:
        # The first 15 submits should be non-blocking.
        for i in range(15):
            executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done)
        assert timer.elapsed_seconds() < task_duration

        # This submit should block.
        executor.submit("key-blocking", task, "task-blocking", done_callback=on_done)
        assert timer.elapsed_seconds() > task_duration

        # Wait for everything to finish.
        executor.flush()
        assert len(done_tasks) == 16


def test_batch_partition_executor_sequential_key_execution():
    executing_tasks = set()
    done_tasks = set()
    done_task_batches = set()

    def process_batch(batch):
        for key, id in batch:
            assert (key, id) not in executing_tasks, "Task is already executing"
            executing_tasks.add((key, id))

        time.sleep(0.5)  # Simulate work

        for key, id in batch:
            executing_tasks.remove((key, id))
            done_tasks.add(id)

        done_task_batches.add(tuple(id for _, id in batch))

    with BatchPartitionExecutor(
        max_workers=2,
        max_pending=10,
        max_per_batch=2,
        process_batch=process_batch,
    ) as executor:
        # Submit tasks with the same key. The first two should get batched together.
        executor.submit("key1", "key1", "task1")
        executor.submit("key1", "key1", "task2")
        executor.submit("key1", "key1", "task3")

        # Submit tasks with a different key. These should get their own batch.
        executor.submit("key2", "key2", "task4")
        executor.submit("key2", "key2", "task5")

        # Test idempotency of shutdown().
        executor.shutdown()

    # Check if all tasks were executed and completed.
    assert done_tasks == {
        "task1",
        "task2",
        "task3",
        "task4",
        "task5",
    }, "Not all tasks completed"

    # Check the batching configuration.
    assert done_task_batches == {
        ("task1", "task2"),
        ("task4", "task5"),
        ("task3",),
    }


def test_batch_partition_executor_max_batch_size():
    batches_processed = []

    def process_batch(batch):
        batches_processed.append(batch)
        time.sleep(0.1)  # Simulate batch processing time

    with BatchPartitionExecutor(
        max_workers=5, max_pending=20, process_batch=process_batch, max_per_batch=2
    ) as executor:
        # Submit more tasks than the max_per_batch to test batching limits.
        for i in range(5):
            executor.submit("key3", "key3", f"task{i}")

    # Check the batches.
    logger.info(f"batches_processed: {batches_processed}")
    assert len(batches_processed) == 3
    for batch in batches_processed:
        assert len(batch) <= 2, "Batch size exceeded max_per_batch limit"