Site Search:

Final implementation of Memoizer

<Back

Memoizer uses ConcurrentMap's atomic putIfAbsent() to close the window in which Memoizer3 could calculate the same value twice. It removes the Future from the cache if a computation is cancelled, and it keeps looping on an ExecutionException in the hope that the computation succeeds the next time.

Sample output when run from Eclipse:

average time per run: 139.87222222222223 miliseconds.

import java.math.BigInteger;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;

/**
 * Thread-safe memoizing wrapper around a {@link Computable}.
 *
 * <p>Each argument maps to a {@link Future} for its (possibly still running) computation, so
 * concurrent callers asking for the same argument wait on one computation instead of each
 * starting their own. {@link ConcurrentMap#putIfAbsent} makes the "miss, then insert" step
 * atomic, so two threads can never both install a task for the same argument.
 *
 * <p>Futures whose computation was cancelled or failed are evicted from the cache, so a single
 * failure does not permanently poison the entry for that argument.
 *
 * @param <A> argument type; used as a map key, so it needs consistent equals/hashCode
 * @param <V> result type
 */
class Memoizer <A, V> implements Computable<A, V> {
    private final ConcurrentMap<A, Future<V>> cache
            = new ConcurrentHashMap<A, Future<V>>();
    private final Computable<A, V> c;

    /**
     * @param c the underlying computation whose results are cached
     */
    public Memoizer(Computable<A, V> c) {
        this.c = c;
    }

    /**
     * Returns the cached result for {@code arg}. If no thread has started computing it yet, the
     * calling thread runs the computation itself; otherwise it waits for the existing one.
     *
     * @throws InterruptedException if interrupted while waiting, or if the underlying
     *     computation threw {@code InterruptedException}
     * @throws RuntimeException rethrown unchanged if the underlying computation threw one; the
     *     failed entry is evicted first, so a later call recomputes
     */
    public V compute(final A arg) throws InterruptedException {
        while (true) {
            Future<V> f = cache.get(arg);
            if (f == null) {
                Callable<V> eval = new Callable<V>() {
                    public V call() throws InterruptedException {
                        return c.compute(arg);
                    }
                };
                FutureTask<V> ft = new FutureTask<V>(eval);
                // Atomic insert: only the thread whose task actually lands in the map runs it;
                // everyone else gets the already-present Future back and waits on that.
                f = cache.putIfAbsent(arg, ft);
                if (f == null) {
                    f = ft;
                    ft.run();
                }
            }
            try {
                return f.get();
            } catch (CancellationException e) {
                // Evict only if the map still holds this exact Future, then retry.
                cache.remove(arg, f);
            } catch (ExecutionException e) {
                // Evict the failed Future; if it stayed cached, every subsequent get() (including
                // a retry in this loop) would just rethrow the same failure forever.
                cache.remove(arg, f);
                throw launderThrowable(e.getCause());
            }
        }
    }

    /**
     * Rethrows the cause of a failed computation, unwrapped where the signature allows. Declared
     * to return {@code RuntimeException} only so callers can write {@code throw launder(...)};
     * it never returns normally.
     */
    private static RuntimeException launderThrowable(Throwable cause)
            throws InterruptedException {
        if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        if (cause instanceof InterruptedException) {
            throw (InterruptedException) cause;
        }
        throw new IllegalStateException("Computation failed with a checked exception", cause);
    }
}

/**
 * A computation from {@code A} to {@code V} that may throw {@link InterruptedException}.
 */
interface Computable <A, V> {
    V compute(A arg) throws InterruptedException;
}

/**
 * Minimal mock of {@code javax.servlet.Servlet}: handles one request given as {@code req},
 * writing any output into {@code resp}.
 */
interface Servlet {
    void service(StringBuilder req, StringBuilder resp);
}
class CachedFactorizer implements Servlet {
    private final Computable<BigInteger, BigInteger[]> c =
            new Computable<BigInteger, BigInteger[]>() {
                public BigInteger[] compute(BigInteger arg) {
                    return factor(arg);
                }
            };
    private final Computable<BigInteger, BigInteger[]> cache
            = new Memoizer<BigInteger, BigInteger[]>(c);

    public void service(StringBuilder req, StringBuilder resp) {
        BigInteger i = extractFromRequest(req);
        try {
            encodeIntoResponse(resp, cache.compute(i));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    void encodeIntoResponse(StringBuilder resp, BigInteger[] factors) {
        resp.append(factors[0]);
    }

    BigInteger extractFromRequest(StringBuilder req) {
        return new BigInteger(req.toString());
    }

    BigInteger[] factor(BigInteger i) {
        // Doesn't really factor
        for(int j = 0; j < 100000000; j++) {}
        return new BigInteger[]{i};
    }
}

public class CachedFactorizerTest implements Runnable{
    private static final int TOTALRUN = 200;
    private static final int TOTALTHREAD = 2000;
    private static final int DISCARDFIRSTFEW = 20;
    private static Random rand = new Random();
    private static CountDownLatch latch;
    private static CachedFactorizer cachedFactorizer = new CachedFactorizer();
    
    private void setCountDownLatch(CountDownLatch latch) {
        CachedFactorizerTest.latch = latch;
    }

    public static void main(String[] args) throws InterruptedException {
        long totalTime = 0;
        for (int i = 0; i < TOTALRUN; i++) {
            long oneRoundTime = runMultiThread();
            if (i >= DISCARDFIRSTFEW) {
                totalTime += oneRoundTime;
            }
        }
        System.out.println("average time per run: "
                + (double) totalTime / (double) (TOTALRUN - DISCARDFIRSTFEW)
                + " miliseconds.");

    }
    
    private static long runMultiThread()
            throws InterruptedException {
        long start = System.currentTimeMillis();
        CountDownLatch latch = new CountDownLatch(TOTALTHREAD);
        CachedFactorizerTest cachedFactorizerTest = new CachedFactorizerTest();
        cachedFactorizerTest.setCountDownLatch(latch);
        for (long i = 0; i < TOTALTHREAD; i++) {
            Thread thread = new Thread(cachedFactorizerTest);
            thread.start();
        }
        latch.await();
        return System.currentTimeMillis() - start;
    }

    @Override
    public void run() {
        StringBuilder req = new StringBuilder().append(rand.nextInt(20));
        StringBuilder resp = new StringBuilder();
        //System.out.println(Thread.currentThread().getName() + " req=" + req.toString() + " resp=" + resp.toString());
        cachedFactorizer.service(req, resp);
        if(!req.toString().contains(resp.toString())) {
            System.out.println("multi-threading race condition");
        }
        //System.out.println(Thread.currentThread().getName() + " req=" + req.toString() + " resp=" + resp.toString());
        latch.countDown();
    }

}