RedisCache serializes value loading in get(key, valueLoader)
This is a follow-up to https://github.com/spring-projects/spring-data-redis/issues/2079
I have a case where the valueLoader that populates a RedisCache calls an external service with a long response time. All value loads for the same cache are synchronized, so when N different values are loaded in parallel, the last one has to wait N × the external service's response time. What are the possible risks if value loading is done concurrently?
I do not feel ready to submit a PR (concurrent programming is hard), but I have created an example (Java 11) that seems to load values concurrently :) It tracks running loads by the private (converted) cache key, as discussed here: https://github.com/spring-projects/spring-data-redis/issues/2079#issuecomment-858730434
package example;
import java.time.Duration;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
// see also https://github.com/spring-projects/spring-data-redis/issues/2079
/**
 * Simulation comparing two cache-miss strategies.
 *
 * <p>{@link RedisCache} loads every missing value while holding one monitor, so loads for
 * different keys run one after another. {@link ConcurrentRedisCache} only deduplicates loads of
 * the same key, so different keys are loaded in parallel.
 */
public class RedisCacheSimulation {

    /** Keys requested by the worker threads; worker {@code i} requests {@code KEYS[i % KEYS.length]}. */
    private static final String[] KEYS = {"key 1", "key 2", "key 3", "key 4"};
    private static final int THREAD_COUNT = KEYS.length;
    /** Artificial latency of the value loader, standing in for a slow external service. */
    private static final long LOADER_WAIT_TIME_MS = 2000;
    private static final RedisCache CACHE = new RedisCache();
    private static final RedisCache CONCURRENT_CACHE = new ConcurrentRedisCache();
    /**
     * Start gate for the workers, recreated by every {@link #runSimulation} call. It is assigned
     * before the workers are started, and {@link Thread#start()} makes the assignment visible to them.
     */
    private static CountDownLatch BARRIER;

    /**
     * Runs one simulation round against {@code cache}.
     *
     * <p>Starts {@link #THREAD_COUNT} workers that each request one key, releases them at the same
     * moment, waits for all of them and prints the elapsed time.
     *
     * @param cache the cache implementation under test
     */
    static void runSimulation(RedisCache cache) {
        BARRIER = new CountDownLatch(1);
        var workers = new ArrayList<Worker>();
        for (var i = 0; i < THREAD_COUNT; i++) {
            var worker = new Worker(i, KEYS[i % KEYS.length], cache);
            worker.start();
            workers.add(worker);
        }
        var start = LocalTime.now();
        // LocalTime is only used for the log lines. It wraps at midnight, which would make
        // Duration.between(start, end) negative, so elapsed time is measured with nanoTime.
        var startNanos = System.nanoTime();
        System.out.println(start + " started " + THREAD_COUNT + " threads");
        BARRIER.countDown();
        for (var worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
        var end = LocalTime.now();
        var elapsed = Duration.ofNanos(System.nanoTime() - startNanos);
        System.out.println(end + " all done");
        System.out.println("Took " + elapsed.getSeconds() + " seconds");
    }

    /** Runs the simulation first with the serializing cache, then with the concurrent one. */
    public static void main(String[] args) {
        runSimulation(CACHE);
        System.out.println("\nConcurrent implementation\n");
        runSimulation(CONCURRENT_CACHE);
    }

    /**
     * Baseline cache backed by a {@link ConcurrentHashMap}.
     *
     * <p>Lookups do not lock. Every cache miss is loaded inside a {@code synchronized} method of
     * this instance, so misses for different keys are loaded one at a time.
     */
    static class RedisCache {
        private final Map<Object, ValueWrapper> backend = new ConcurrentHashMap<>();

        /**
         * Returns the cached value for {@code key}, loading and storing it on a miss.
         *
         * @param key the cache key
         * @param valueLoader called on a miss to produce the value
         * @return the cached or freshly loaded value
         * @throws ValueRetrievalException if {@code valueLoader} throws
         */
        @SuppressWarnings("unchecked")
        public <T> T get(Object key, Callable<T> valueLoader) {
            ValueWrapper result = get(key);
            if (result != null) {
                return (T) result.get();
            }
            return getSynchronized(key, valueLoader);
        }

        /**
         * Loads the value for {@code key} while holding this instance's monitor.
         *
         * <p>The backend is checked again after acquiring the lock. This reuses a value stored by
         * a thread that held the lock earlier instead of loading it a second time.
         */
        @SuppressWarnings("unchecked")
        private synchronized <T> T getSynchronized(Object key, Callable<T> valueLoader) {
            ValueWrapper result = get(key);
            if (result != null) {
                return (T) result.get();
            }
            T value;
            try {
                value = valueLoader.call();
            } catch (Exception e) {
                throw new ValueRetrievalException(key, valueLoader, e);
            }
            put(key, value);
            return value;
        }

        /** Returns the stored wrapper for {@code key}, or {@code null} if nothing is stored. */
        protected ValueWrapper get(Object key) {
            return backend.get(key);
        }

        /** Stores {@code value} (which may be {@code null}) for {@code key}. */
        protected void put(Object key, Object value) {
            backend.put(key, new ValueWrapper(value));
        }
    }

    /**
     * Cache that only serializes loads of the same key.
     *
     * <p>The first caller that misses a key creates and runs a {@link FutureTask} for it. Later
     * callers for that key wait on the same task. Callers for other keys proceed in parallel.
     * Running loads are tracked by {@link #convertKey(Object)}.
     */
    static class ConcurrentRedisCache extends RedisCache {
        /** Loads currently in progress, keyed by the converted key; guarded by its own monitor. */
        private final Map<String, RunnableFuture<?>> runningLoaders = new ConcurrentHashMap<>();

        @Override
        @SuppressWarnings("unchecked")
        public <T> T get(Object key, Callable<T> valueLoader) {
            ValueWrapper result = get(key);
            if (result != null) {
                return (T) result.get();
            }
            return getConcurrent(key, valueLoader);
        }

        /**
         * Joins a running load of {@code key} or starts one.
         *
         * <p>Only the thread that started the load stores the result. A failed load is not cached,
         * so a later call retries it.
         */
        @SuppressWarnings("unchecked")
        private <T> T getConcurrent(Object key, Callable<T> valueLoader) {
            String privateKey = convertKey(key);
            RunnableFuture<?> loader;
            boolean mustRunLoader = false;
            synchronized (runningLoaders) {
                ValueWrapper result = get(key);
                if (result != null) {
                    return (T) result.get();
                }
                loader = runningLoaders.get(privateKey);
                if (loader == null) {
                    loader = new FutureTask<>(valueLoader) {
                        @Override
                        protected void setException(Throwable t) {
                            // everyone who's waiting concurrently will get the same exception
                            super.setException(new ValueRetrievalException(key, valueLoader, t));
                        }
                    };
                    runningLoaders.put(privateKey, loader);
                    mustRunLoader = true;
                }
            }
            if (!mustRunLoader) {
                return (T) awaitLoader(loader);
            }
            try {
                loader.run();
                T value = (T) awaitLoader(loader);
                put(key, value);
                return value;
            } finally {
                // Remove under the same lock used for the lookup above. The value was stored
                // before this point, so a caller that no longer finds the loader is guaranteed to
                // see the cached value instead of starting a second load.
                synchronized (runningLoaders) {
                    runningLoaders.remove(privateKey);
                }
            }
        }

        /**
         * Waits for {@code loader} to finish and returns its result.
         *
         * <p>If the load failed, the exception recorded by the loader is rethrown. Runtime
         * exceptions are rethrown as-is; anything else is wrapped in a {@link RuntimeException}.
         */
        private static Object awaitLoader(RunnableFuture<?> loader) {
            try {
                return loader.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                var cause = e.getCause();
                if (cause instanceof RuntimeException) {
                    throw (RuntimeException) cause;
                }
                throw new RuntimeException(cause);
            }
        }

        /** Converts a cache key to the string used to track running loads (minimal version for testing). */
        protected String convertKey(Object key) {
            return key.toString();
        }
    }

    /** Worker thread that waits on {@link #BARRIER}, then requests its key through a slow loader. */
    static class Worker extends Thread {
        private final String threadId;
        private final String key;
        private final RedisCache cache;

        Worker(int id, String key, RedisCache cache) {
            this.threadId = " thread " + id;
            this.key = key;
            this.cache = cache;
        }

        @Override
        public void run() {
            try {
                BARRIER.await();
                System.out.println(LocalTime.now() + threadId + " acquiring value for " + key);
                cache.get(key, () -> {
                    // long-running value loader
                    System.out.println(LocalTime.now() + " calculating value for " + key);
                    Thread.sleep(LOADER_WAIT_TIME_MS);
                    return key;
                });
                System.out.println(LocalTime.now() + threadId + " got value for " + key);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }

    /** Holder for a cached value; distinguishes "cached null" from "not cached". */
    static class ValueWrapper {
        private final Object value;

        ValueWrapper(Object v) {
            value = v;
        }

        public Object get() {
            return value;
        }
    }

    /** Thrown when a value loader fails; the loader's exception is kept as the cause. */
    static class ValueRetrievalException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        public ValueRetrievalException(Object key, Callable<?> valueLoader, Throwable t) {
            super(String.format("Value for key '%s' could not be loaded using '%s'", key, valueLoader), t);
        }
    }
}
Results:
23:57:24.568315 started 4 threads
23:57:24.569955 thread 1 acquiring value for key 2
23:57:24.569982 thread 0 acquiring value for key 1
23:57:24.569977 thread 3 acquiring value for key 4
23:57:24.569963 thread 2 acquiring value for key 3
23:57:24.570787 calculating value for key 1
23:57:26.571701 calculating value for key 3
23:57:26.571693 thread 0 got value for key 1
23:57:28.571964 thread 2 got value for key 3
23:57:28.571970 calculating value for key 2
23:57:30.572547 thread 1 got value for key 2
23:57:30.572562 calculating value for key 4
23:57:32.573019 thread 3 got value for key 4
23:57:32.573315 all done
Took 8 seconds
Concurrent implementation
23:57:32.574645 started 4 threads
23:57:32.574852 thread 0 acquiring value for key 1
23:57:32.574869 thread 1 acquiring value for key 2
23:57:32.574888 thread 2 acquiring value for key 3
23:57:32.574916 thread 3 acquiring value for key 4
23:57:32.577271 calculating value for key 1
23:57:32.577307 calculating value for key 3
23:57:32.577292 calculating value for key 2
23:57:32.577290 calculating value for key 4
23:57:34.578093 thread 0 got value for key 1
23:57:34.578093 thread 3 got value for key 4
23:57:34.578119 thread 1 got value for key 2
23:57:34.578098 thread 2 got value for key 3
23:57:34.578778 all done
Took 2 seconds