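"""Tests for serialized_lru_cache, a thread-safe LRU cache decorator.

These tests cover cache hits and misses, LRU eviction at maxsize, parallel
computation of distinct keys, and serialization of concurrent calls that
share a key.
"""
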
import threading
import time

from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.serialized_lru_cache import serialized_lru_cache


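# For orientation, a minimal sketch of how a decorator like
# serialized_lru_cache can be built. This is an illustrative assumption, not
# DataHub's actual implementation, and the name _sketch_serialized_lru_cache
# is hypothetical; the tests below exercise the real decorator imported
# above. The idea: functools.lru_cache supplies the LRU bookkeeping and
# cache_info(), while one lock per argument key makes concurrent callers of
# the same key wait for a single computation, which is exactly the behavior
# the tests assert (one miss, nine hits; distinct keys still run in
# parallel).
import functools
from collections import defaultdict
from typing import Any, Callable


def _sketch_serialized_lru_cache(
    maxsize: int,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        cached = functools.lru_cache(maxsize=maxsize)(func)
        key_locks: Any = defaultdict(threading.Lock)  # one lock per argument key
        locks_guard = threading.Lock()  # protects key_locks itself

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = (args, tuple(sorted(kwargs.items())))
            with locks_guard:
                key_lock = key_locks[key]
            # Serialize computation per key; distinct keys run concurrently.
            with key_lock:
                return cached(*args, **kwargs)

        wrapper.cache_info = cached.cache_info  # type: ignore[attr-defined]
        return wrapper

    return decorator

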
def test_cache_hit() -> None:
    """A repeated call with the same argument is served from the cache."""

    @serialized_lru_cache(maxsize=2)
    def fetch_data(x):
        return x * 2

    assert fetch_data(1) == 2  # Cache miss
    assert fetch_data(1) == 2  # Cache hit
    assert fetch_data.cache_info().hits == 1  # type: ignore
    assert fetch_data.cache_info().misses == 1  # type: ignore


def test_cache_eviction() -> None:
    """The least recently used entry is evicted once maxsize is exceeded."""

    @serialized_lru_cache(maxsize=2)
    def compute(x):
        return x * 2

    compute(1)
    compute(2)
    compute(3)  # Should evict the first entry (1)
    assert compute.cache_info().currsize == 2  # type: ignore
    assert compute.cache_info().misses == 3  # type: ignore
    assert compute(1) == 2  # Cache miss, since it was evicted
    assert compute.cache_info().misses == 4  # type: ignore


def test_thread_safety() -> None:
    """Calls with distinct keys run concurrently instead of being serialized."""

    @serialized_lru_cache(maxsize=5)
    def compute(x):
        time.sleep(0.2)  # Simulate some delay
        return x * 2

    threads = []
    results = [None] * 10

    def thread_func(index, arg):
        results[index] = compute(arg)

    with PerfTimer() as timer:
        for i in range(10):
            thread = threading.Thread(target=thread_func, args=(i, i % 5))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

    assert len(set(results)) == 5  # Only 5 unique results should be there
    assert compute.cache_info().currsize <= 5  # type: ignore
    # Only 5 unique calls should miss the cache
    assert compute.cache_info().misses == 5  # type: ignore

    # Should take less than 1 second. If not, it means all calls were run serially.
    assert timer.elapsed_seconds() < 1


def test_concurrent_access_to_same_key() -> None:
    """Ten concurrent calls sharing a key trigger only one computation."""

    @serialized_lru_cache(maxsize=3)
    def compute(x: int) -> int:
        time.sleep(0.2)  # Simulate some delay
        return x * 2

    threads = []
    results = []

    def thread_func():
        results.append(compute(1))

    with PerfTimer() as timer:
        for _ in range(10):
            thread = threading.Thread(target=thread_func)
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

    assert all(result == 2 for result in results)  # All should compute the same result

    # 9 hits, as the first call is the only miss
    assert compute.cache_info().hits == 9  # type: ignore
    # Only the first call is a miss
    assert compute.cache_info().misses == 1  # type: ignore

    # Should take less than 1 second. If not, it means all calls were run serially.
    assert timer.elapsed_seconds() < 1