Site Search:

Memoizing wrapper using FutureTask

<Back

Memoizer3 in this version redefines the backing Map for the value cache as a ConcurrentHashMap<A, Future<V>> instead of a ConcurrentHashMap<A, V>. 

Memoizer3 first calls cache.get(arg) to check whether the appropriate calculation has been started (as opposed to finished, as in Memoizer2). If not, it creates a FutureTask, calls cache.put(arg, ft), and starts the computation; otherwise it calls f.get() to wait for the result of the existing computation.

The if block is still a check-then-act sequence, so there is a small window in which two threads can call compute with the same value at the same time, both see that the cache does not contain the desired value, and both start the computation. Also, we should remove the Future instance from the cache in case f.get() throws CancellationException.


Sample output from Eclipse:

average time per run: 135.82777777777778 miliseconds.

import java.math.BigInteger;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;

class Memoizer3 <A, V> implements Computable<A, V> {
    private final Map<A, Future<V>> cache = new ConcurrentHashMap<A, Future<V>>();
    private final Computable<A, V> c;

    public Memoizer3(Computable<A, V> c) {
        this.c = c;
    }

    public V compute(A arg) throws InterruptedException {
        Future<V> f = cache.get(arg);
        if(f == null) {
            Callable<V> eval = new Callable<V>() {
                public V call() throws InterruptedException {
                    return c.compute(arg);
                }
            };
            FutureTask<V> ft = new FutureTask<V>(eval);
            f = ft;
            cache.put(arg, ft);
            ft.run();
        }
        
        try {
            return f.get();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
        return null;
    }
}

/**
 * A computation that maps one argument to one result.
 *
 * @param <A> argument type
 * @param <V> result type
 */
@FunctionalInterface
interface Computable<A, V> {

    /**
     * Computes the result for the given argument.
     *
     * @param arg the input to the computation
     * @return the computed value
     * @throws InterruptedException if the computation is interrupted
     */
    V compute(A arg) throws InterruptedException;
}

/**
 * Minimal stand-in for {@code javax.servlet.Servlet}: request and response
 * are plain {@link StringBuilder}s.
 */
interface Servlet {

    /**
     * Handles one request.
     *
     * @param req  the request
     * @param resp the response
     */
    void service(StringBuilder req, StringBuilder resp);
}
class CachedFactorizer implements Servlet {
    private final Computable<BigInteger, BigInteger[]> c =
            new Computable<BigInteger, BigInteger[]>() {
                public BigInteger[] compute(BigInteger arg) {
                    return factor(arg);
                }
            };
    private final Computable<BigInteger, BigInteger[]> cache
            = new Memoizer3<BigInteger, BigInteger[]>(c);

    public void service(StringBuilder req, StringBuilder resp) {
        BigInteger i = extractFromRequest(req);
        try {
            encodeIntoResponse(resp, cache.compute(i));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    void encodeIntoResponse(StringBuilder resp, BigInteger[] factors) {
        resp.append(factors[0]);
    }

    BigInteger extractFromRequest(StringBuilder req) {
        return new BigInteger(req.toString());
    }

    BigInteger[] factor(BigInteger i) {
        // Doesn't really factor
        for(int j = 0; j < 100000000; j++) {}
        return new BigInteger[]{i};
    }
}

public class CachedFactorizerTest implements Runnable{
    private static final int TOTALRUN = 200;
    private static final int TOTALTHREAD = 2000;
    private static final int DISCARDFIRSTFEW = 20;
    private static Random rand = new Random();
    private static CountDownLatch latch;
    private static CachedFactorizer cachedFactorizer = new CachedFactorizer();
    
    private void setCountDownLatch(CountDownLatch latch) {
        CachedFactorizerTest.latch = latch;
    }

    public static void main(String[] args) throws InterruptedException {
        long totalTime = 0;
        for (int i = 0; i < TOTALRUN; i++) {
            long oneRoundTime = runMultiThread();
            if (i >= DISCARDFIRSTFEW) {
                totalTime += oneRoundTime;
            }
        }
        System.out.println("average time per run: "
                + (double) totalTime / (double) (TOTALRUN - DISCARDFIRSTFEW)
                + " miliseconds.");

    }
    
    private static long runMultiThread()
            throws InterruptedException {
        long start = System.currentTimeMillis();
        CountDownLatch latch = new CountDownLatch(TOTALTHREAD);
        CachedFactorizerTest cachedFactorizerTest = new CachedFactorizerTest();
        cachedFactorizerTest.setCountDownLatch(latch);
        for (long i = 0; i < TOTALTHREAD; i++) {
            Thread thread = new Thread(cachedFactorizerTest);
            thread.start();
        }
        latch.await();
        return System.currentTimeMillis() - start;
    }

    @Override
    public void run() {
        StringBuilder req = new StringBuilder().append(rand.nextInt(20));
        StringBuilder resp = new StringBuilder();
        //System.out.println(Thread.currentThread().getName() + " req=" + req.toString() + " resp=" + resp.toString());
        cachedFactorizer.service(req, resp);
        if(!req.toString().contains(resp.toString())) {
            System.out.println("multi-threading race condition");
        }
        //System.out.println(Thread.currentThread().getName() + " req=" + req.toString() + " resp=" + resp.toString());
        latch.countDown();
    }

}