# Patch a1a4798 ("Some crazy idea I had to simplify futures and memory limits",
# Jake Poznanski, 2024-10-23) -- reconstructed source of pdelfin/cappedpool.py.
#
# NOTE(review): the original patch also edited scripts/benchmark_throughput.py
# (switching sample_mm_requests_molmo -> sample_mm_requests_qwen2vl and
# hard-coding qwen2vl image embeddings); that hunk belongs to a file not
# visible here and is intentionally not reproduced.

import concurrent.futures
import queue
import threading


class CappedFuture(concurrent.futures.Future):
    """Future handed back by CappedProcessPoolExecutor.submit().

    It proxies the real ProcessPoolExecutor future once the task is actually
    submitted, and returns one slot of the executor's bounded semaphore when
    the caller consumes the outcome via result() or exception().
    """

    def __init__(self, semaphore):
        super().__init__()
        self._semaphore = semaphore
        self._result_retrieved = False
        self._underlying_future = None
        # Deliberately NOT named `_condition`: the base Future already owns a
        # `_condition` attribute for its own state machine, and rebinding it
        # (as the original code did) shadows base-class internals.
        self._underlying_lock = threading.Lock()

    def set_underlying_future(self, underlying_future):
        """Attach the real pool future once the feeder thread submits the task."""
        with self._underlying_lock:
            self._underlying_future = underlying_future
        # Propagate the outcome to this future when the underlying one finishes.
        underlying_future.add_done_callback(self._transfer_result)

    def _transfer_result(self, underlying_future):
        # Copy the underlying future's terminal state onto this future.
        try:
            if underlying_future.cancelled():
                # Future has no set_cancelled(); cancel() moves a pending
                # future to CANCELLED and notifies waiters and callbacks.
                super().cancel()
            else:
                exc = underlying_future.exception()
                if exc is not None:
                    self.set_exception(exc)
                else:
                    self.set_result(underlying_future.result())
        except concurrent.futures.InvalidStateError:
            # The caller cancelled this future after submission; discard the
            # underlying outcome instead of crashing the callback.
            pass

    def result(self, timeout=None):
        """Return the task result, releasing the cap slot once consumed.

        A TimeoutError keeps the slot (the outcome was not consumed); any
        other exit -- value, task exception, or CancelledError -- releases it.
        """
        try:
            res = super().result(timeout)
        except concurrent.futures.TimeoutError:
            raise  # still pending: the slot is not consumed yet
        except BaseException:
            self._release_semaphore()
            raise
        self._release_semaphore()
        return res

    def exception(self, timeout=None):
        """Return the task exception (or None), releasing the cap slot once consumed."""
        try:
            exc = super().exception(timeout)
        except concurrent.futures.TimeoutError:
            raise
        except BaseException:
            self._release_semaphore()
            raise
        self._release_semaphore()
        return exc

    def _release_semaphore(self):
        # Release exactly once even under concurrent result()/exception()
        # calls; an extra release would raise on the BoundedSemaphore.
        with self._underlying_lock:
            if self._result_retrieved:
                return
            self._result_retrieved = True
        self._semaphore.release()

    def cancel(self):
        """Cancel the task; returns True if it had not started running."""
        with self._underlying_lock:
            underlying = self._underlying_future
        if underlying is not None:
            cancelled = underlying.cancel()
            if cancelled:
                super().cancel()
            return cancelled
        # Task has not been handed to the process pool yet; cancel directly.
        # The feeder thread re-checks cancelled() before submitting.
        return super().cancel()

    def cancelled(self):
        return super().cancelled()

    def running(self):
        with self._underlying_lock:
            if self._underlying_future is not None:
                return self._underlying_future.running()
        return False

    def done(self):
        return super().done()


class CappedProcessPoolExecutor(concurrent.futures.Executor):
    """ProcessPoolExecutor wrapper that caps unprocessed work.

    At most `max_unprocessed` tasks may be submitted-to-the-pool but not yet
    consumed via result()/exception(); further tasks wait in an internal
    queue. This bounds memory held by completed-but-unread results.
    """

    def __init__(self, max_unprocessed=100, max_workers=None):
        self._max_unprocessed = max_unprocessed
        self._semaphore = threading.BoundedSemaphore(max_unprocessed)
        self._task_queue = queue.Queue()
        self._shutdown = threading.Event()
        self._shutdown_lock = threading.Lock()
        self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        self._worker_thread = threading.Thread(target=self._worker, daemon=True)
        self._worker_thread.start()

    def submit(self, fn, *args, **kwargs):
        """Queue fn(*args, **kwargs) and return a CappedFuture immediately.

        Raises RuntimeError after shutdown() has been called.
        """
        if self._shutdown.is_set():
            raise RuntimeError('Cannot submit new tasks after shutdown')
        user_future = CappedFuture(self._semaphore)
        self._task_queue.put((user_future, fn, args, kwargs))
        return user_future

    def _worker(self):
        """Feeder loop: move queued tasks into the pool, one cap slot each."""
        while True:
            # Exit only once shutdown was requested AND the queue is drained.
            if self._shutdown.is_set() and self._task_queue.empty():
                break
            try:
                user_future, fn, args, kwargs = self._task_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            # NOTE(review): this blocks until a slot frees up; if callers never
            # consume results, shutdown(wait=True) can wait here indefinitely.
            self._semaphore.acquire()
            # Skip tasks the caller cancelled while they were still queued.
            if user_future.cancelled():
                self._semaphore.release()
                continue
            try:
                underlying_future = self._executor.submit(fn, *args, **kwargs)
            except Exception as e:
                # e.g. the pool was shut down or is broken; surface the error
                # on the user's future and give the slot back.
                user_future.set_exception(e)
                self._semaphore.release()
                continue
            user_future.set_underlying_future(underlying_future)

    def shutdown(self, wait=True):
        """Stop accepting tasks; optionally wait for queued work to finish."""
        with self._shutdown_lock:
            self._shutdown.set()
        # Honor wait=False: the original code joined unconditionally, which
        # made a non-blocking shutdown block on the feeder thread.
        if wait:
            self._worker_thread.join()
        self._executor.shutdown(wait=wait)
# tests/test_cappedpool.py -- unit tests for CappedProcessPoolExecutor.
import unittest
import time
import concurrent.futures
from concurrent.futures import TimeoutError

from pdelfin.cappedpool import CappedProcessPoolExecutor

# Helpers live at module level so multiprocessing can pickle them.


def square(x):
    """Return x squared."""
    return x * x


def raise_exception():
    """Always raise, to exercise exception propagation."""
    raise ValueError("Test exception")


def sleep_and_return(x, sleep_time):
    """Sleep for sleep_time seconds, then return x."""
    time.sleep(sleep_time)
    return x


def task(counter, max_counter, counter_lock):
    """Record peak concurrency in max_counter via a shared counter."""
    with counter_lock:
        counter.value += 1
        if counter.value > max_counter.value:
            max_counter.value = counter.value
    time.sleep(0.5)
    with counter_lock:
        counter.value -= 1
    return True


class TestCappedProcessPoolExecutor(unittest.TestCase):
    """Behavioral tests for CappedProcessPoolExecutor."""

    def test_basic_functionality(self):
        """Tasks are executed and results are correct."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            results = [f.result() for f in futures]
            self.assertEqual(results, [i * i for i in range(10)])

    def test_exception_handling(self):
        """Exceptions raised in tasks propagate through result()."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(raise_exception)
            with self.assertRaises(ValueError):
                future.result()

    def test_cancellation(self):
        """A task that has not started yet can be cancelled."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            # Keep every worker (and the pool's small internal call queue)
            # busy so the victim task cannot start before cancel() is called;
            # cancelling an immediately-submitted task with idle workers is racy.
            blockers = [executor.submit(sleep_and_return, i, 0.5) for i in range(8)]
            future = executor.submit(time.sleep, 5)
            self.assertTrue(future.cancel())
            self.assertTrue(future.cancelled())
            # Getting the result of a cancelled future raises CancelledError.
            with self.assertRaises(concurrent.futures.CancelledError):
                future.result()
            # Drain the blockers so the executor shuts down promptly.
            for b in blockers:
                self.assertIsNotNone(b.result())

    def test_shutdown(self):
        """shutdown() waits for in-flight work and rejects new submissions."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        future = executor.submit(time.sleep, 1)
        executor.shutdown(wait=True)
        # The in-flight task completed; time.sleep returns None.
        self.assertIsNone(future.result(timeout=10))
        with self.assertRaises(RuntimeError):
            executor.submit(time.sleep, 1)

    def test_capping_behavior(self):
        """No more than max_unprocessed tasks are in flight at once."""
        from multiprocessing import Manager

        max_unprocessed = 3
        with CappedProcessPoolExecutor(max_unprocessed=max_unprocessed, max_workers=10) as executor:
            manager = Manager()
            counter = manager.Value('i', 0)
            max_counter = manager.Value('i', 0)
            counter_lock = manager.Lock()

            futures = [
                executor.submit(task, counter, max_counter, counter_lock)
                for _ in range(10)
            ]
            for f in futures:
                self.assertTrue(f.result())

            self.assertLessEqual(max_counter.value, max_unprocessed)

    def test_submit_after_shutdown(self):
        """Submitting after shutdown raises RuntimeError."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(square, 2)


if __name__ == '__main__':
    unittest.main()