# Mirror of https://github.com/datahub-project/datahub.git
# Synced 2025-07-04 23:57:03 +00:00
import threading
import time

from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.serialized_lru_cache import serialized_lru_cache


def test_cache_hit() -> None:
    """A repeated call with the same argument is served from the cache."""

    @serialized_lru_cache(maxsize=2)
    def fetch_data(x):
        return x * 2

    first = fetch_data(1)  # populates the cache (miss)
    second = fetch_data(1)  # served from the cache (hit)
    assert first == 2
    assert second == 2

    info = fetch_data.cache_info()  # type: ignore
    assert info.hits == 1
    assert info.misses == 1


def test_cache_eviction() -> None:
    """With maxsize=2, a third distinct key evicts the least-recently-used entry."""

    @serialized_lru_cache(maxsize=2)
    def compute(x):
        return x * 2

    # Three distinct keys; the third insertion evicts the entry for 1.
    for value in (1, 2, 3):
        compute(value)

    info = compute.cache_info()  # type: ignore
    assert info.currsize == 2
    assert info.misses == 3

    # The entry for 1 was evicted, so this call must recompute (a fresh miss).
    assert compute(1) == 2
    assert compute.cache_info().misses == 4  # type: ignore


def test_thread_safety() -> None:
    """Concurrent calls with distinct keys should overlap rather than run serially."""

    @serialized_lru_cache(maxsize=5)
    def compute(x):
        time.sleep(0.2)  # simulate a slow computation
        return x * 2

    results = [None] * 10

    def worker(slot, arg):
        results[slot] = compute(arg)

    with PerfTimer() as timer:
        # 10 threads over 5 distinct keys (i % 5), so each key is requested twice.
        workers = [
            threading.Thread(target=worker, args=(i, i % 5)) for i in range(10)
        ]
        for t in workers:
            t.start()
        for t in workers:
            t.join()

    assert len(set(results)) == 5  # exactly 5 distinct outputs expected
    assert compute.cache_info().currsize <= 5  # type: ignore
    # Only the 5 unique keys should miss the cache.
    assert compute.cache_info().misses == 5  # type: ignore

    # Should take less than 1 second. If not, it means all calls were run serially.
    assert timer.elapsed_seconds() < 1


def test_concurrent_access_to_same_key() -> None:
    """Many threads asking for one key: a single computation, the rest cache hits."""

    @serialized_lru_cache(maxsize=3)
    def compute(x: int) -> int:
        time.sleep(0.2)  # simulate a slow computation
        return x * 2

    results = []

    def worker():
        results.append(compute(1))

    with PerfTimer() as timer:
        workers = [threading.Thread(target=worker) for _ in range(10)]
        for t in workers:
            t.start()
        for t in workers:
            t.join()

    # Every thread must observe the same computed value.
    assert all(value == 2 for value in results)

    # 9 hits, as the first one is a miss
    assert compute.cache_info().hits == 9  # type: ignore
    # Only the first call is a miss
    assert compute.cache_info().misses == 1  # type: ignore

    # Should take less than 1 second. If not, it means all calls were run serially.
    assert timer.elapsed_seconds() < 1
|