# Tests for datahub.utilities.serialized_lru_cache
# (source mirror: https://github.com/datahub-project/datahub.git)

import threading
import time

from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.serialized_lru_cache import serialized_lru_cache
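
# These tests exercise serialized_lru_cache's lru_cache-style bookkeeping
# (hits, misses, and currsize reported by cache_info()), LRU eviction once
# maxsize is exceeded, and its threading behavior: distinct keys may be
# computed in parallel, while concurrent callers of the same key are
# serialized so the wrapped function body runs only once per key.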


def test_cache_hit() -> None:
    @serialized_lru_cache(maxsize=2)
    def fetch_data(x):
        return x * 2

    assert fetch_data(1) == 2  # Cache miss
    assert fetch_data(1) == 2  # Cache hit
    assert fetch_data.cache_info().hits == 1  # type: ignore
    assert fetch_data.cache_info().misses == 1  # type: ignore


def test_cache_eviction() -> None:
    @serialized_lru_cache(maxsize=2)
    def compute(x):
        return x * 2

    compute(1)
    compute(2)
    compute(3)  # Should evict the first entry (1)
    assert compute.cache_info().currsize == 2  # type: ignore
    assert compute.cache_info().misses == 3  # type: ignore
    assert compute(1) == 2  # Cache miss, since it was evicted
    assert compute.cache_info().misses == 4  # type: ignore
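

# A hedged companion check, relying only on documented stdlib behavior: in a
# single thread, functools.lru_cache exhibits the same miss/eviction pattern
# that test_cache_eviction asserts above. (This test is an illustration added
# here for comparison; it is not part of the original datahub suite.)
def test_stdlib_lru_cache_parity() -> None:
    import functools

    @functools.lru_cache(maxsize=2)
    def compute(x: int) -> int:
        return x * 2

    compute(1)
    compute(2)
    compute(3)  # Evicts key 1, the least recently used entry
    assert compute.cache_info().currsize == 2
    assert compute.cache_info().misses == 3
    assert compute(1) == 2  # Miss: key 1 was evicted
    assert compute.cache_info().misses == 4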


def test_thread_safety() -> None:
    @serialized_lru_cache(maxsize=5)
    def compute(x):
        time.sleep(0.2)  # Simulate some delay
        return x * 2

    threads = []
    results = [None] * 10

    def thread_func(index, arg):
        results[index] = compute(arg)

    with PerfTimer() as timer:
        for i in range(10):
            thread = threading.Thread(target=thread_func, args=(i, i % 5))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

    assert len(set(results)) == 5  # Only 5 unique results should be there
    assert compute.cache_info().currsize <= 5  # type: ignore
    # Only the 5 unique arguments should miss the cache
    assert compute.cache_info().misses == 5  # type: ignore

    # Serial execution would take ~2s (10 calls x 0.2s each); with the 5
    # unique keys computed in parallel, this should finish well under 1s.
    assert timer.elapsed_seconds() < 1


def test_concurrent_access_to_same_key() -> None:
    @serialized_lru_cache(maxsize=3)
    def compute(x: int) -> int:
        time.sleep(0.2)  # Simulate some delay
        return x * 2

    threads = []
    results = []

    def thread_func():
        results.append(compute(1))

    with PerfTimer() as timer:
        for _ in range(10):
            thread = threading.Thread(target=thread_func)
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

    assert all(result == 2 for result in results)  # All should compute the same result

    # 9 hits: every thread after the first waits for, then reuses, its result
    assert compute.cache_info().hits == 9  # type: ignore
    # Only the first call is a miss
    assert compute.cache_info().misses == 1  # type: ignore

    # Serial execution would take ~2s; since the function body runs only once
    # for the shared key, this should finish in roughly 0.2s, well under 1s.
    assert timer.elapsed_seconds() < 1
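

# A minimal sketch of how a decorator with the behavior tested above can be
# built: per-key locks layered over OrderedDict-based LRU bookkeeping. This
# is illustrative only and is NOT datahub's actual implementation; the name
# _sketch_serialized_lru_cache and all of its internals are hypothetical,
# and cache_info() is omitted for brevity (the real decorator exposes it,
# as the tests above show).
import functools
from collections import OrderedDict
from typing import Any, Callable, Dict


def _sketch_serialized_lru_cache(maxsize: int) -> Callable:
    def decorator(func: Callable) -> Callable:
        cache: "OrderedDict[Any, Any]" = OrderedDict()
        key_locks: Dict[Any, threading.Lock] = {}
        state_lock = threading.Lock()  # guards cache, key_locks, and counters
        hits = misses = 0

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal hits, misses
            key = (args, tuple(sorted(kwargs.items())))
            with state_lock:
                if key in cache:
                    hits += 1
                    cache.move_to_end(key)  # refresh recency on a hit
                    return cache[key]
                # One lock per key: callers of the same key queue up here,
                # while different keys proceed in parallel. (Simplification:
                # key_locks is never pruned.)
                key_lock = key_locks.setdefault(key, threading.Lock())
            with key_lock:
                with state_lock:
                    if key in cache:  # another thread finished it meanwhile
                        hits += 1
                        return cache[key]
                    misses += 1
                result = func(*args, **kwargs)  # run without holding state_lock
                with state_lock:
                    cache[key] = result
                    if len(cache) > maxsize:
                        cache.popitem(last=False)  # evict least recently used
            return result

        return wrapper

    return decorator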