
RedisCache serializing value loading in get(key, valueLoader)

Open · andruskutt opened this issue 3 years ago · 0 comments

This is a follow-up to https://github.com/spring-projects/spring-data-redis/issues/2079

I have a case where the valueLoader for populating a RedisCache calls an external service with a long response time. All value loads for the same cache are synchronized, so loading N different values in parallel takes N x the external service response time for the last one. What are the possible risks if value loading is done concurrently?
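For context, the call pattern looks roughly like this (a minimal sketch: QuoteLookup, the "quotes" cache name and fetchFromRemote are illustrative placeholders, not code from my application; Cache.get(key, valueLoader) is the Spring Cache API method that RedisCache implements):

import org.springframework.cache.Cache;
import org.springframework.cache.CacheManager;

// Illustrative sketch only. Every cache miss goes through Cache.get(key, valueLoader);
// with the current RedisCache the loader runs while a cache-wide lock is held,
// so misses for different keys execute one after another.
class QuoteLookup {

    private final Cache cache;

    QuoteLookup(CacheManager cacheManager) {
        // assumed to be backed by a RedisCacheManager
        this.cache = cacheManager.getCache("quotes");
    }

    String quoteFor(String symbol) {
        // blocks behind any other in-flight load for this cache, even for a different key
        return cache.get(symbol, () -> fetchFromRemote(symbol));
    }

    private String fetchFromRemote(String symbol) {
        // placeholder for the long-running external service call
        return "quote:" + symbol;
    }
}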

I do not feel ready to submit a PR (concurrent programming is hard), but I have created an example (Java 11) with something which seems to run concurrently :) I'm using the private key as discussed here: https://github.com/spring-projects/spring-data-redis/issues/2079#issuecomment-858730434

package example;

import java.time.Duration;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;

// see also https://github.com/spring-projects/spring-data-redis/issues/2079
public class RedisCacheSimulation {

    private static final String[] KEYS = {"key 1", "key 2", "key 3", "key 4"};
    private static final int THREAD_COUNT = KEYS.length;
    private static final long LOADER_WAIT_TIME_MS = 2000;
    private static final RedisCache CACHE = new RedisCache();
    private static final RedisCache CONCURRENT_CACHE = new ConcurrentRedisCache();
    private static CountDownLatch BARRIER;

    static void runSimulation(RedisCache cache) {
        BARRIER = new CountDownLatch(1);

        var workers = new ArrayList<Worker>();
        for (var i = 0; i < THREAD_COUNT; i++) {
            var worker = new Worker(i, KEYS[i % KEYS.length], cache);
            worker.start();
            workers.add(worker);
        }

        var start = LocalTime.now();
        System.out.println(start + " started " + THREAD_COUNT + " threads");

        BARRIER.countDown();

        for (var worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }

        var end = LocalTime.now();
        System.out.println(end + " all done");
        System.out.println("Took " +  Duration.between(start, end).getSeconds() + " seconds");
    }

    public static void main(String[] args) {
        runSimulation(CACHE);
        System.out.println("\nConcurrent implementation\n");
        runSimulation(CONCURRENT_CACHE);
    }
 
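    // Simplified stand-in for the current RedisCache: on a cache miss,
    // get(key, valueLoader) falls through to a synchronized method, so all
    // value loads for this cache instance are serialized.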
    static class RedisCache {
        private final Map<Object, ValueWrapper> backend = new ConcurrentHashMap<>();

        @SuppressWarnings("unchecked")
        public <T> T get(Object key, Callable<T> valueLoader) {
            ValueWrapper result = get(key);

            if (result != null) {
                return (T) result.get();
            }

            return getSynchronized(key, valueLoader);
        }

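        // Serialization point: a single lock for the whole cache means at most
        // one valueLoader runs at a time, regardless of key.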
        @SuppressWarnings("unchecked")
        private synchronized <T> T getSynchronized(Object key, Callable<T> valueLoader) {
            ValueWrapper result = get(key);

            if (result != null) {
                return (T) result.get();
            }

            T value;
            try {
                value = valueLoader.call();
            } catch (Exception e) {
                throw new ValueRetrievalException(key, valueLoader, e);
            }
            put(key, value);
            return value;
        }

        protected ValueWrapper get(Object key) {
            return backend.get(key);
        }

        protected void put(Object key, Object value) {
            backend.put(key, new ValueWrapper(value));
        }
    }

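    // Variant with a per-key registry of in-flight loaders: callers for the same
    // key share one FutureTask, while loads for different keys run in parallel.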
    static class ConcurrentRedisCache extends RedisCache {
        private final Map<String, RunnableFuture<?>> runningLoaders = new ConcurrentHashMap<>();

        @SuppressWarnings("unchecked")
        @Override
        public <T> T get(Object key, Callable<T> valueLoader) {
            ValueWrapper result = get(key);

            if (result != null) {
                return (T) result.get();
            }

            return getConcurrent(key, valueLoader);
        }

        @SuppressWarnings("unchecked")
        private <T> T getConcurrent(Object key, Callable<T> valueLoader) {
            boolean mustRunLoader = false;
            RunnableFuture<?> loader;
            String privateKey = convertKey(key);

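            // Hold the lock only to re-check the cache and register/look up the
            // in-flight loader for this key; the loader itself runs outside the lock.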
            synchronized (runningLoaders) {
                ValueWrapper result = get(key);

                if (result != null) {
                    return (T) result.get();
                }

                loader = runningLoaders.get(privateKey);
                if (loader == null) {
                    loader = new FutureTask<>(valueLoader) {
                        @Override
                        protected void setException(Throwable t) {
                            // everyone who's waiting concurrently will get same exception
                            super.setException(new ValueRetrievalException(key, valueLoader, t));
                        }
                    };
                    runningLoaders.put(privateKey, loader);
                    mustRunLoader = true;
                }
            }

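            // Only the thread that registered the task executes the loader;
            // all other callers for the same key wait on the FutureTask below.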
            if (mustRunLoader) {
                loader.run();
            }

            try {
                T loaderResult;
                try {
                    loaderResult = (T) loader.get();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                } catch (ExecutionException e) {
                    var cause = e.getCause();
                    if (cause instanceof RuntimeException) {
                        throw (RuntimeException) cause;
                    }
                    throw new RuntimeException(cause);
                }
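                // Only the owning thread writes the loaded value to the backend.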
                if (mustRunLoader) {
                    put(key, loaderResult);
                }
                return loaderResult;
            } finally {
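                // The owner always deregisters the in-flight loader, so a failed
                // load can be retried by a later caller.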
                if (mustRunLoader) {
                    runningLoaders.remove(privateKey);
                }
            }
        }

        protected String convertKey(Object key) {
            // minimal version for testing
            return key.toString();
        }
    }

    static class Worker extends Thread {
        private final String threadId;
        private final String key;
        private final RedisCache cache;

        Worker(int id, String key, RedisCache cache) {
            this.threadId = " thread " + id;
            this.key = key;
            this.cache = cache;
        }

        @Override
        public void run() {
            try {
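                // wait until all workers have been started so that the loads begin together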
                BARRIER.await();

                System.out.println(LocalTime.now() + threadId + " acquiring value for " + key);

                cache.get(key, () -> {
                    // long-running value loader
                    System.out.println(LocalTime.now() + " calculating value for " + key);
                    Thread.sleep(LOADER_WAIT_TIME_MS);
                    return key;
                });

                System.out.println(LocalTime.now() + threadId  + " got value for " + key);
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    }

    static class ValueWrapper {
        private final Object value;

        ValueWrapper(Object v) {
            value = v;
        }

        public Object get() {
            return value;
        }
    }

    // simplified stand-in for org.springframework.cache.Cache.ValueRetrievalException
    static class ValueRetrievalException extends RuntimeException {

        public ValueRetrievalException(Object key, Callable<?> valueLoader, Throwable t) {
            super("Value for key '" + key + "' could not be loaded using '" + valueLoader + "'", t);
        }
    }
}

Results:

23:57:24.568315 started 4 threads
23:57:24.569955 thread 1 acquiring value for key 2
23:57:24.569982 thread 0 acquiring value for key 1
23:57:24.569977 thread 3 acquiring value for key 4
23:57:24.569963 thread 2 acquiring value for key 3
23:57:24.570787 calculating value for key 1
23:57:26.571701 calculating value for key 3
23:57:26.571693 thread 0 got value for key 1
23:57:28.571964 thread 2 got value for key 3
23:57:28.571970 calculating value for key 2
23:57:30.572547 thread 1 got value for key 2
23:57:30.572562 calculating value for key 4
23:57:32.573019 thread 3 got value for key 4
23:57:32.573315 all done
Took 8 seconds

Concurrent implementation

23:57:32.574645 started 4 threads
23:57:32.574852 thread 0 acquiring value for key 1
23:57:32.574869 thread 1 acquiring value for key 2
23:57:32.574888 thread 2 acquiring value for key 3
23:57:32.574916 thread 3 acquiring value for key 4
23:57:32.577271 calculating value for key 1
23:57:32.577307 calculating value for key 3
23:57:32.577292 calculating value for key 2
23:57:32.577290 calculating value for key 4
23:57:34.578093 thread 0 got value for key 1
23:57:34.578093 thread 3 got value for key 4
23:57:34.578119 thread 1 got value for key 2
23:57:34.578098 thread 2 got value for key 3
23:57:34.578778 all done
Took 2 seconds

andruskutt · Sep 03 '22 20:09