mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-10-31 18:15:44 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			100 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			100 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import unittest
 | |
| import time
 | |
| import concurrent.futures
 | |
| from concurrent.futures import TimeoutError
 | |
| 
 | |
# CappedProcessPoolExecutor is provided by the project's pdelfin.cappedpool module.
 | |
| from pdelfin.cappedpool import CappedProcessPoolExecutor
 | |
| 
 | |
| # Define functions at the top level to ensure they are picklable by multiprocessing
 | |
| 
 | |
def square(x):
    """Return *x* multiplied by itself.

    Defined at module level so it can be pickled and sent to worker processes.
    """
    product = x * x
    return product
 | |
| 
 | |
def raise_exception():
    """Always raise ``ValueError("Test exception")``.

    Used to check that exceptions raised inside a worker propagate to the caller.
    """
    message = "Test exception"
    raise ValueError(message)
 | |
| 
 | |
def sleep_and_return(x, sleep_time):
    """Block for *sleep_time* seconds, then return *x* unchanged.

    Args:
        x: Value handed back to the caller.
        sleep_time: Seconds to pass to ``time.sleep`` before returning.
    """
    value = x
    time.sleep(sleep_time)
    return value
 | |
| 
 | |
def task(counter, max_counter, counter_lock, *, hold_time=0.5):
    """Occupy a worker for *hold_time* seconds while tracking peak concurrency.

    On entry the shared *counter* is incremented and *max_counter* is raised to
    the new count if it is a new peak; on exit *counter* is decremented again.
    All reads/writes of the shared values happen while holding *counter_lock*.

    Args:
        counter: Object with a mutable integer ``value`` attribute holding the
            number of tasks currently inside this function.
        max_counter: Object with a mutable integer ``value`` attribute recording
            the highest ``counter.value`` observed.
        counter_lock: Lock (context manager) guarding both shared values.
        hold_time: Seconds to sleep while "running" (default 0.5, the value
            previously hard-coded).

    Returns:
        True once the task has finished.
    """
    with counter_lock:
        counter.value += 1
        current = counter.value
        print(f"Task incrementing counter to {current}")
        if current > max_counter.value:
            max_counter.value = current
    try:
        time.sleep(hold_time)
    finally:
        # Always release our slot in the live count, even if the sleep is
        # interrupted, so the concurrency measurement is not left inflated.
        with counter_lock:
            counter.value -= 1
    return True
 | |
| 
 | |
class TestCappedProcessPoolExecutor(unittest.TestCase):
    """Behavioural tests for ``CappedProcessPoolExecutor``."""

    def test_basic_functionality(self):
        """Test that tasks are executed and results are correct."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            results = [f.result() for f in futures]
            expected = [i * i for i in range(10)]
            self.assertEqual(results, expected)

    def test_exception_handling(self):
        """Test that exceptions raised in a task are re-raised by ``result()``."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(raise_exception)
            with self.assertRaises(ValueError):
                future.result()

    def test_cancellation(self):
        """Test that tasks can be cancelled before execution."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(time.sleep, 5)
            # NOTE(review): this races with the executor handing the task to a
            # worker; once a task is running, Future.cancel() returns False.
            # The test relies on cancelling before dispatch happens.
            cancelled = future.cancel()
            self.assertTrue(cancelled)
            self.assertTrue(future.cancelled())
            # A cancelled future raises CancelledError from result().
            with self.assertRaises(concurrent.futures.CancelledError):
                future.result()

    def test_shutdown(self):
        """Test that shutdown completes pending work and then rejects new tasks."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        future = executor.submit(time.sleep, 1)
        executor.shutdown(wait=True)
        # Work submitted before shutdown must still finish successfully
        # (time.sleep returns None).
        self.assertIsNone(future.result())
        with self.assertRaises(RuntimeError):
            executor.submit(time.sleep, 1)

    def test_capping_behavior(self):
        """Test that concurrently running tasks never exceed max_unprocessed."""
        from multiprocessing import Manager

        max_unprocessed = 3
        num_tasks = 10
        # Using the Manager as a context manager shuts its server process down.
        # It is entered first so it outlives the executor (context managers
        # exit in reverse order), keeping the shared proxies valid until every
        # worker has finished with them.
        with Manager() as manager, CappedProcessPoolExecutor(
            max_unprocessed=max_unprocessed, max_workers=10
        ) as executor:
            counter = manager.Value('i', 0)
            max_counter = manager.Value('i', 0)
            counter_lock = manager.Lock()

            futures = [
                executor.submit(task, counter, max_counter, counter_lock)
                for _ in range(num_tasks)
            ]

            # result() already blocks until each task is done, so no extra
            # sleeping between checks is needed.
            for index, f in enumerate(futures):
                self.assertIs(f.result(), True, f"Future {index} did not return True")

            peak = max_counter.value
            # At least one task must have run, otherwise the cap check is vacuous.
            self.assertGreaterEqual(peak, 1)
            self.assertLessEqual(peak, max_unprocessed)

    def test_submit_after_shutdown(self):
        """Test that submitting tasks after shutdown raises an error."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(square, 2)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
 | 
