mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-14 11:52:03 +00:00
Some crazy idea I had to simplify futures and memory limits
This commit is contained in:
parent
f6ac591fe9
commit
a1a4798ce7
116
pdelfin/cappedpool.py
Normal file
116
pdelfin/cappedpool.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
|
||||||
|
class CappedFuture(concurrent.futures.Future):
    """Future that holds one semaphore slot until its outcome is consumed.

    The owning executor acquires a slot from ``semaphore`` before it hands
    the task to a real pool and attaches the pool's future with
    :meth:`set_underlying_future`.  The slot is given back exactly once, as
    soon as the caller has observed the outcome through :meth:`result` or
    :meth:`exception` (whether that outcome is a value, an exception or a
    cancellation), or when the attached task is cancelled.

    A future that never had an underlying future attached does not own a
    slot, so it never releases one.

    Args:
        semaphore: Object with a ``release()`` method, shared with the
            executor that created this future.
    """

    def __init__(self, semaphore):
        super().__init__()
        self._semaphore = semaphore
        self._result_retrieved = False
        self._underlying_future = None
        # Dedicated re-entrant lock for the wrapper's own state.  The base
        # class keeps its private condition untouched.  It must be
        # re-entrant because done-callbacks can fire synchronously while
        # the lock is already held (see ``cancel``).
        self._capped_lock = threading.RLock()

    def set_underlying_future(self, underlying_future):
        """Attach the pool future whose outcome this future mirrors."""
        with self._capped_lock:
            self._underlying_future = underlying_future
            # Transfer the result when the underlying future completes.
            # If it is already done the callback runs immediately.
            underlying_future.add_done_callback(self._transfer_result)

    def _transfer_result(self, underlying_future):
        """Copy the outcome of ``underlying_future`` onto this future."""
        if underlying_future.cancelled():
            # ``Future`` has no ``set_cancelled``; cancelling through the
            # base class is the supported way to reach the cancelled state.
            super().cancel()
            # A cancelled task has no result left to retrieve.
            self._release_semaphore()
            return

        if self.done():
            # The caller cancelled this future before the task was
            # attached; the task's outcome is discarded, so free the slot.
            self._release_semaphore()
            return

        error = underlying_future.exception()
        if error is not None:
            self.set_exception(error)
        else:
            self.set_result(underlying_future.result())

    def result(self, timeout=None):
        """Return the task's result and free the slot.

        The slot is freed whenever the future is finished, including when
        the task raised or was cancelled.  It is kept when the wait merely
        timed out, because the outcome is still outstanding.
        """
        try:
            return super().result(timeout)
        finally:
            if self.done():
                self._release_semaphore()

    def exception(self, timeout=None):
        """Return the task's exception (or ``None``) and free the slot.

        Uses the same release rule as :meth:`result`.
        """
        try:
            return super().exception(timeout)
        finally:
            if self.done():
                self._release_semaphore()

    def _release_semaphore(self):
        """Give the slot back once, and only if this future owns one."""
        with self._capped_lock:
            if self._result_retrieved or self._underlying_future is None:
                return
            self._result_retrieved = True
        self._semaphore.release()

    def cancel(self):
        """Try to cancel the task; return True if it will not run."""
        with self._capped_lock:
            underlying = self._underlying_future
            if underlying is None:
                # Task has not been submitted yet; cancel directly.
                return super().cancel()
            if not underlying.cancel():
                return False
            # The done-callback normally cancelled us already; this call
            # covers an underlying future that was cancelled earlier.
            super().cancel()
            return True

    def running(self):
        """Return True if the attached task is currently executing."""
        with self._capped_lock:
            underlying = self._underlying_future
        return underlying is not None and underlying.running()
|
||||||
|
|
||||||
|
class CappedProcessPoolExecutor(concurrent.futures.Executor):
    """Process pool that caps the number of tasks whose result is unread.

    ``submit`` never blocks: it queues the task and returns a
    ``CappedFuture`` immediately.  A background dispatcher thread takes
    tasks off the queue and acquires one slot of a bounded semaphore before
    handing each task to an inner ``ProcessPoolExecutor``.  The future
    gives the slot back once the caller has consumed its outcome, so at
    most ``max_unprocessed`` tasks are running or finished-but-unread at
    any time.

    Args:
        max_unprocessed: Number of semaphore slots.
        max_workers: Forwarded to ``ProcessPoolExecutor``.
    """

    def __init__(self, max_unprocessed=100, max_workers=None):
        self._max_unprocessed = max_unprocessed
        self._semaphore = threading.BoundedSemaphore(max_unprocessed)
        self._task_queue = queue.Queue()
        self._shutdown = threading.Event()
        self._shutdown_lock = threading.Lock()
        self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        self._worker_thread = threading.Thread(target=self._worker)
        self._worker_thread.daemon = True
        self._worker_thread.start()

    def submit(self, fn, *args, **kwargs):
        """Queue ``fn(*args, **kwargs)`` and return its future.

        Raises:
            RuntimeError: If the executor has been shut down.
        """
        # Check and enqueue under the shutdown lock.  Otherwise a task
        # could be queued after the dispatcher has observed "shut down and
        # queue empty" and exited, leaving a future that never completes.
        with self._shutdown_lock:
            if self._shutdown.is_set():
                raise RuntimeError('Cannot submit new tasks after shutdown')
            # Create a CappedFuture to return to the user
            user_future = CappedFuture(self._semaphore)
            # Put the task in the queue
            self._task_queue.put((user_future, fn, args, kwargs))
        return user_future

    def _worker(self):
        """Dispatcher loop: move queued tasks into the pool under the cap."""
        while not (self._shutdown.is_set() and self._task_queue.empty()):
            try:
                # Short timeout so the shutdown flag is re-checked regularly.
                user_future, fn, args, kwargs = self._task_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            # Blocks until an earlier result has been consumed.
            self._semaphore.acquire()

            if user_future.cancelled():
                # Cancelled while waiting in the queue: nothing to run.
                self._semaphore.release()
                continue

            # Submit the task to the underlying executor
            try:
                underlying_future = self._executor.submit(fn, *args, **kwargs)
            except Exception as e:
                # The task never reached the pool, so the slot is returned
                # here rather than by the future.
                self._semaphore.release()
                # Setting an exception on a future the caller cancelled in
                # the meantime would raise and kill this thread.
                if not user_future.cancelled():
                    user_future.set_exception(e)
                continue

            user_future.set_underlying_future(underlying_future)

    def shutdown(self, wait=True):
        """Stop accepting tasks, dispatch what is queued, shut the pool down.

        The dispatcher thread is always joined, regardless of ``wait``, so
        every queued task reaches the pool before the pool is shut down.
        ``wait`` is forwarded to the inner pool's ``shutdown``.

        NOTE(review): the join cannot finish while the dispatcher is
        blocked on a full cap, i.e. when queued tasks remain and earlier
        results are never retrieved.
        """
        with self._shutdown_lock:
            self._shutdown.set()
        self._worker_thread.join()
        self._executor.shutdown(wait=wait)
|
@ -99,13 +99,16 @@ def sample_mm_requests_qwen2vl(
|
|||||||
return_tensors="np",
|
return_tensors="np",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# print(inputs)
|
||||||
|
|
||||||
tokens = inputs["input_ids"][0]
|
tokens = inputs["input_ids"][0]
|
||||||
prompt_len = len(tokens)
|
prompt_len = len(tokens)
|
||||||
|
|
||||||
result.append((TokensPrompt(
|
result.append((TokensPrompt(
|
||||||
dict(
|
dict(
|
||||||
prompt_token_ids=tokens,
|
prompt_token_ids=tokens,
|
||||||
multi_modal_data=dict(image=main_image),
|
multi_modal_data=dict(image=dict(image_embeds=torch.randn(1036, 3584), image_grid_thw=torch.tensor([[1, 74, 56]]))),
|
||||||
|
# multi_modal_data=dict(image=main_image)
|
||||||
)
|
)
|
||||||
), prompt_len, fixed_output_len))
|
), prompt_len, fixed_output_len))
|
||||||
|
|
||||||
@ -467,7 +470,7 @@ def main(args: argparse.Namespace):
|
|||||||
else:
|
else:
|
||||||
# requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
# requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||||
# args.output_len)
|
# args.output_len)
|
||||||
requests = sample_mm_requests_molmo(args.dataset, args.num_prompts, tokenizer,
|
requests = sample_mm_requests_qwen2vl(args.dataset, args.num_prompts, tokenizer,
|
||||||
args.output_len)
|
args.output_len)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
|
99
tests/test_cappedpool.py
Normal file
99
tests/test_cappedpool.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import unittest
|
||||||
|
import time
|
||||||
|
import concurrent.futures
|
||||||
|
from concurrent.futures import TimeoutError
|
||||||
|
|
||||||
|
# CappedProcessPoolExecutor is provided by the pdelfin.cappedpool module
|
||||||
|
from pdelfin.cappedpool import CappedProcessPoolExecutor
|
||||||
|
|
||||||
|
# Define functions at the top level to ensure they are picklable by multiprocessing
|
||||||
|
|
||||||
|
def square(x):
    """Return ``x`` multiplied by itself."""
    squared = x * x
    return squared
|
||||||
|
|
||||||
|
def raise_exception():
    """Always fail with a ``ValueError``; used to test error propagation."""
    failure = ValueError("Test exception")
    raise failure
|
||||||
|
|
||||||
|
def sleep_and_return(x, sleep_time):
    """Block for ``sleep_time`` seconds, then hand ``x`` back unchanged."""
    # Sleep first so the caller only sees the value after the delay.
    time.sleep(sleep_time)
    return x
|
||||||
|
|
||||||
|
def task(counter, max_counter, counter_lock):
    """Count itself as active for half a second and record the peak.

    ``counter.value`` is incremented on entry and decremented on exit, both
    under ``counter_lock``.  ``max_counter.value`` is raised to the highest
    number of simultaneously active tasks seen so far.  Always returns True.
    """
    # Entry: register this task as active and update the high-water mark.
    with counter_lock:
        counter.value = counter.value + 1
        print(f"Task incrementing counter to {counter.value}")
        if max_counter.value < counter.value:
            max_counter.value = counter.value

    # Stay "active" long enough for other tasks to overlap with this one.
    time.sleep(0.5)

    # Exit: this task is no longer active.
    with counter_lock:
        counter.value = counter.value - 1

    return True
|
||||||
|
|
||||||
|
class TestCappedProcessPoolExecutor(unittest.TestCase):
    """Behavioural tests for ``CappedProcessPoolExecutor``."""

    def test_basic_functionality(self):
        """Test that tasks are executed and results are correct."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            results = [f.result() for f in futures]
        expected = [i * i for i in range(10)]
        self.assertEqual(results, expected)

    def test_exception_handling(self):
        """Test that exceptions in tasks are properly raised."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(raise_exception)
            with self.assertRaises(ValueError):
                future.result()

    def test_cancellation(self):
        """Test that tasks can be cancelled before execution."""
        # With a cap of one, the second task cannot reach the pool until
        # the first task's result has been read, so cancelling it is
        # deterministic instead of racing the dispatcher thread.
        with CappedProcessPoolExecutor(max_unprocessed=1, max_workers=1) as executor:
            blocker = executor.submit(sleep_and_return, "done", 0.1)
            future = executor.submit(time.sleep, 5)
            try:
                cancelled = future.cancel()
                self.assertTrue(cancelled)
                self.assertTrue(future.cancelled())
                # Attempt to get result; should raise CancelledError
                with self.assertRaises(concurrent.futures.CancelledError):
                    future.result()
            finally:
                # Always free the only slot, otherwise the dispatcher stays
                # blocked and leaving the ``with`` block would hang.
                self.assertEqual(blocker.result(), "done")

    def test_shutdown(self):
        """Test that the executor shuts down properly and does not accept new tasks."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        future = executor.submit(time.sleep, 1)
        executor.shutdown(wait=True)
        # A waiting shutdown lets the in-flight task finish.
        self.assertTrue(future.done())
        with self.assertRaises(RuntimeError):
            executor.submit(time.sleep, 1)

    def test_capping_behavior(self):
        """Test that the number of concurrent tasks does not exceed max_unprocessed."""
        from multiprocessing import Manager

        max_unprocessed = 3
        # The manager is a context manager so its server process is shut
        # down when the test ends.  The executor exits first.
        with Manager() as manager, \
                CappedProcessPoolExecutor(max_unprocessed=max_unprocessed, max_workers=10) as executor:
            counter = manager.Value('i', 0)
            max_counter = manager.Value('i', 0)
            counter_lock = manager.Lock()

            futures = [executor.submit(task, counter, max_counter, counter_lock) for _ in range(10)]

            # Reading each result frees its slot and lets the next task in.
            for f in futures:
                self.assertTrue(f.result())

            # Read the proxy while the manager is still alive.
            observed_peak = max_counter.value

        self.assertLessEqual(observed_peak, max_unprocessed)

    def test_submit_after_shutdown(self):
        """Test that submitting tasks after shutdown raises an error."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(square, 2)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the whole test module when this file is executed as a script.
if '__main__' == __name__:
    unittest.main()
|
Loading…
x
Reference in New Issue
Block a user