fix(cache): update search cache when skipped, but enabled (#7936)

This commit is contained in:
RyanHolstien 2023-05-08 17:50:50 -05:00 committed by GitHub
parent ef1ada118d
commit ee5480bcbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 17 deletions

View File

@ -90,20 +90,25 @@ public class CacheableSearcher<K> {
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "getBatch").time()) {
QueryPagination batch = getBatchQuerySize(batchId);
SearchResult result;
if (enableCache()) {
try (Timer.Context ignored2 = MetricUtils.timer(this.getClass(), "getBatch_cache").time()) {
Timer.Context cacheAccess = MetricUtils.timer(this.getClass(), "getBatch_cache_access").time();
K cacheKey = cacheKeyGenerator.apply(batch);
String json = cache.get(cacheKey, String.class);
result = json != null ? toRecordTemplate(SearchResult.class, json) : null;
cacheAccess.stop();
if (result == null) {
Timer.Context cacheMiss = MetricUtils.timer(this.getClass(), "getBatch_cache_miss").time();
result = searcher.apply(batch);
cache.put(cacheKey, toJsonString(result));
cacheMiss.stop();
MetricUtils.counter(this.getClass(), "getBatch_cache_miss_count").inc();
if (enableCache) {
K cacheKey = cacheKeyGenerator.apply(batch);
if ((searchFlags == null || !searchFlags.isSkipCache())) {
try (Timer.Context ignored2 = MetricUtils.timer(this.getClass(), "getBatch_cache").time()) {
Timer.Context cacheAccess = MetricUtils.timer(this.getClass(), "getBatch_cache_access").time();
String json = cache.get(cacheKey, String.class);
result = json != null ? toRecordTemplate(SearchResult.class, json) : null;
cacheAccess.stop();
if (result == null) {
Timer.Context cacheMiss = MetricUtils.timer(this.getClass(), "getBatch_cache_miss").time();
result = searcher.apply(batch);
cache.put(cacheKey, toJsonString(result));
cacheMiss.stop();
MetricUtils.counter(this.getClass(), "getBatch_cache_miss_count").inc();
}
}
} else {
result = searcher.apply(batch);
cache.put(cacheKey, toJsonString(result));
}
} else {
result = searcher.apply(batch);
@ -111,8 +116,4 @@ public class CacheableSearcher<K> {
return result;
}
}
private boolean enableCache() {
return enableCache && (searchFlags == null || !searchFlags.isSkipCache());
}
}

View File

@ -3,6 +3,7 @@ package com.linkedin.metadata.search.cache;
import com.google.common.collect.Streams;
import com.linkedin.common.urn.TestEntityUrn;
import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.search.AggregationMetadataArray;
import com.linkedin.metadata.search.SearchEntity;
import com.linkedin.metadata.search.SearchEntityArray;
@ -11,6 +12,8 @@ import com.linkedin.metadata.search.SearchResultMetadata;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.mockito.Mockito;
import org.springframework.cache.Cache;
import org.springframework.cache.CacheManager;
import org.springframework.cache.concurrent.ConcurrentMapCacheManager;
import org.testng.annotations.Test;
@ -86,6 +89,64 @@ public class CacheableSearcherTest {
getUrns(0, 40).stream(), getUrns(0, 5).stream()).collect(Collectors.toList()));
}
@Test
public void testCacheableSearcherEnabled() {
// Verify cache is not interacted with when cache disabled
Cache mockCache = Mockito.mock(Cache.class);
CacheableSearcher<Integer> cacheDisabled =
new CacheableSearcher<>(mockCache, 10,
qs -> getSearchResult(qs, qs.getFrom() + qs.getSize()), CacheableSearcher.QueryPagination::getFrom, null,
false);
SearchResult result = cacheDisabled.getSearchResults(0, 10);
assertEquals(result.getNumEntities().intValue(), 1000);
assertEquals(result.getEntities().size(), 10);
assertEquals(result.getEntities().stream().map(SearchEntity::getEntity).collect(Collectors.toList()),
getUrns(0, 10));
Mockito.verifyNoInteractions(mockCache);
Mockito.reset(mockCache);
// Verify cache is updated when cache enabled, but skip cache passed through
CacheableSearcher<Integer> skipCache =
new CacheableSearcher<>(mockCache, 10,
qs -> getSearchResult(qs, qs.getFrom() + qs.getSize()), CacheableSearcher.QueryPagination::getFrom,
new SearchFlags().setSkipCache(true), true);
result = skipCache.getSearchResults(0, 10);
assertEquals(result.getNumEntities().intValue(), 1000);
assertEquals(result.getEntities().size(), 10);
assertEquals(result.getEntities().stream().map(SearchEntity::getEntity).collect(Collectors.toList()),
getUrns(0, 10));
Mockito.verify(mockCache, Mockito.times(1)).put(Mockito.any(), Mockito.any());
Mockito.verify(mockCache, Mockito.times(0)).get(Mockito.any(), Mockito.any(Class.class));
Mockito.reset(mockCache);
// Test cache hit when searchFlags is null
CacheableSearcher<Integer> nullFlags =
new CacheableSearcher<>(mockCache, 10,
qs -> getSearchResult(qs, qs.getFrom() + qs.getSize()), CacheableSearcher.QueryPagination::getFrom,
null, true);
result = nullFlags.getSearchResults(0, 10);
assertEquals(result.getNumEntities().intValue(), 1000);
assertEquals(result.getEntities().size(), 10);
assertEquals(result.getEntities().stream().map(SearchEntity::getEntity).collect(Collectors.toList()),
getUrns(0, 10));
Mockito.verify(mockCache, Mockito.times(1)).put(Mockito.any(), Mockito.any());
Mockito.verify(mockCache, Mockito.times(1)).get(Mockito.any(), Mockito.any(Class.class));
Mockito.reset(mockCache);
// Test cache hit when skipCache is false
CacheableSearcher<Integer> useCache =
new CacheableSearcher<>(mockCache, 10,
qs -> getSearchResult(qs, qs.getFrom() + qs.getSize()), CacheableSearcher.QueryPagination::getFrom,
new SearchFlags().setSkipCache(false), true);
result = useCache.getSearchResults(0, 10);
assertEquals(result.getNumEntities().intValue(), 1000);
assertEquals(result.getEntities().size(), 10);
assertEquals(result.getEntities().stream().map(SearchEntity::getEntity).collect(Collectors.toList()),
getUrns(0, 10));
Mockito.verify(mockCache, Mockito.times(1)).put(Mockito.any(), Mockito.any());
Mockito.verify(mockCache, Mockito.times(1)).get(Mockito.any(), Mockito.any(Class.class));
}
private SearchResult getEmptySearchResult(CacheableSearcher.QueryPagination queryPagination) {
return new SearchResult().setEntities(new SearchEntityArray())
.setNumEntities(0)