datahub/metadata-ingestion/tests/unit/utilities/test_serialized_lru_cache.py
2024-10-16 19:18:32 -07:00

93 lines
2.7 KiB
Python

import threading
import time
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.serialized_lru_cache import serialized_lru_cache
def test_cache_hit() -> None:
    """Call a cached function twice with the same argument.

    The second call is expected to be counted as a hit, leaving the
    counters at exactly one miss and one hit.
    """

    @serialized_lru_cache(maxsize=2)
    def double(value):
        return value * 2

    assert double(1) == 2  # First lookup for this key: miss
    assert double(1) == 2  # Same key again: hit

    stats = double.cache_info()  # type: ignore
    assert stats.hits == 1
    assert stats.misses == 1
def test_cache_eviction() -> None:
    """Overfill a two-entry cache and check that the first key is dropped.

    Three distinct keys are inserted into a cache with ``maxsize=2``; the
    first key is then requested again and must register as another miss.
    """

    @serialized_lru_cache(maxsize=2)
    def double(value):
        return value * 2

    # The third distinct key should push out the first entry (1).
    for key in (1, 2, 3):
        double(key)

    stats = double.cache_info()  # type: ignore
    assert stats.currsize == 2
    assert stats.misses == 3

    # Key 1 was evicted, so asking for it again is a miss, not a hit.
    assert double(1) == 2
    assert double.cache_info().misses == 4  # type: ignore
def test_thread_safety() -> None:
    """Hit the cache from ten threads spread over five distinct keys.

    Each key is requested by two threads. The test expects five distinct
    results, exactly five misses, and a total wall time under one second.
    """

    @serialized_lru_cache(maxsize=5)
    def slow_double(value):
        time.sleep(0.2)  # Simulate some delay
        return value * 2

    call_count = 10
    results = [None] * call_count

    def worker(slot, arg):
        # Every thread writes to its own slot, so no extra locking is needed here.
        results[slot] = slow_double(arg)

    with PerfTimer() as timer:
        workers = [
            threading.Thread(target=worker, args=(slot, slot % 5))
            for slot in range(call_count)
        ]
        for current in workers:
            current.start()
        for current in workers:
            current.join()

    stats = slow_double.cache_info()  # type: ignore

    # Only 5 unique results should be there
    assert len(set(results)) == 5
    assert stats.currsize <= 5
    # Only 5 unique calls should miss the cache
    assert stats.misses == 5

    # Five misses of 0.2s each run back to back would already reach one second,
    # so finishing faster means distinct keys were computed concurrently.
    assert timer.elapsed_seconds() < 1
def test_concurrent_access_to_same_key() -> None:
    """Request one key from ten threads at once.

    The test expects a single miss, nine hits, every thread seeing the
    same value, and a total wall time under one second.
    """

    @serialized_lru_cache(maxsize=3)
    def slow_double(value: int) -> int:
        time.sleep(0.2)  # Simulate some delay
        return value * 2

    results = []

    def worker():
        results.append(slow_double(1))

    with PerfTimer() as timer:
        workers = [threading.Thread(target=worker) for _ in range(10)]
        for current in workers:
            current.start()
        for current in workers:
            current.join()

    # All threads should have computed the same result
    assert all(value == 2 for value in results)

    stats = slow_double.cache_info()  # type: ignore
    # 9 hits, as the first one is a miss
    assert stats.hits == 9
    # Only the first call is a miss
    assert stats.misses == 1

    # Ten executions of the 0.2s body one after another would take about two
    # seconds, so staying under one second rules that out.
    assert timer.elapsed_seconds() < 1