fix(ingest/partition-executor): Fix deadlock by recomputing ready items (#11853)

This commit is contained in:
Andrew Sikowitz 2024-11-13 23:48:30 -08:00 committed by GitHub
parent 383a70ac0a
commit 5ff6295b0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 12 deletions

View File

@ -573,6 +573,7 @@ mypy_stubs = {
test_api_requirements = {
"pytest>=6.2.2",
"pytest-timeout",
# deepdiff 8.0.0 is missing its numpy requirement, so exclude that release.
"deepdiff!=8.0.0",
"PyYAML",

View File

@ -237,6 +237,11 @@ class BatchPartitionExecutor(Closeable):
process_batch: Callable[[List], None],
max_per_batch: int = 100,
min_process_interval: timedelta = _DEFAULT_BATCHER_MIN_PROCESS_INTERVAL,
# Why 3 seconds? It's somewhat arbitrary.
# We don't want it to be too high, since then liveness suffers,
# particularly during a dirty shutdown. If it's too low, then we'll
# waste CPU cycles rechecking the timer, only to call get again.
read_from_pending_interval: timedelta = timedelta(seconds=3),
) -> None:
"""Similar to PartitionExecutor, but with batching.
@ -262,8 +267,10 @@ class BatchPartitionExecutor(Closeable):
self.max_per_batch = max_per_batch
self.process_batch = process_batch
self.min_process_interval = min_process_interval
self.read_from_pending_interval = read_from_pending_interval
assert self.max_workers > 1
self.state_lock = threading.Lock()
self._executor = ThreadPoolExecutor(
# We add one here to account for the clearinghouse worker thread.
max_workers=max_workers + 1,
@ -362,12 +369,8 @@ class BatchPartitionExecutor(Closeable):
if not blocking:
next_item = self._pending.get_nowait()
else:
# Why 3 seconds? It's somewhat arbitrary.
# We don't want it to be too high, since then liveness suffers,
# particularly during a dirty shutdown. If it's too low, then we'll
# waste CPU cycles rechecking the timer, only to call get again.
next_item = self._pending.get(
timeout=3, # seconds
timeout=self.read_from_pending_interval.total_seconds(),
)
if next_item is None: # None is the shutdown signal
@ -379,6 +382,9 @@ class BatchPartitionExecutor(Closeable):
pending_key_completion.append(next_item)
else:
next_batch.append(next_item)
if not next_batch:
next_batch = _find_ready_items()
except queue.Empty:
if not blocking:
break
@ -452,10 +458,11 @@ class BatchPartitionExecutor(Closeable):
f"{self.__class__.__name__} is shutting down; cannot submit new work items."
)
# Lazily start the clearinghouse worker.
if not self._clearinghouse_started:
self._clearinghouse_started = True
self._executor.submit(self._clearinghouse_worker)
with self.state_lock:
# Lazily start the clearinghouse worker.
if not self._clearinghouse_started:
self._clearinghouse_started = True
self._executor.submit(self._clearinghouse_worker)
def submit(
self,

View File

@ -1,7 +1,11 @@
import logging
import math
import time
from concurrent.futures import Future
import pytest
from pydantic.schema import timedelta
from datahub.utilities.partition_executor import (
BatchPartitionExecutor,
PartitionExecutor,
@ -129,7 +133,9 @@ def test_batch_partition_executor_sequential_key_execution():
}
@pytest.mark.timeout(10)
def test_batch_partition_executor_max_batch_size():
n = 20 # Submit more items than max_pending to test for deadlocks when max_pending is exceeded
batches_processed = []
def process_batch(batch):
@ -137,15 +143,20 @@ def test_batch_partition_executor_max_batch_size():
time.sleep(0.1) # Simulate batch processing time
with BatchPartitionExecutor(
max_workers=5, max_pending=20, process_batch=process_batch, max_per_batch=2
max_workers=5,
max_pending=10,
process_batch=process_batch,
max_per_batch=2,
min_process_interval=timedelta(seconds=1),
read_from_pending_interval=timedelta(seconds=1),
) as executor:
# Submit more tasks than the max_per_batch to test batching limits.
for i in range(5):
for i in range(n):
executor.submit("key3", "key3", f"task{i}")
# Check the batches.
logger.info(f"batches_processed: {batches_processed}")
assert len(batches_processed) == 3
assert len(batches_processed) == math.ceil(n / 2), "Incorrect number of batches"
for batch in batches_processed:
assert len(batch) <= 2, "Batch size exceeded max_per_batch limit"