Site Search:

Replacing HashMap with ConcurrentHashMap

<Back

Memoizer2 in this version improves on the concurrent behavior of Memoizer1 by replacing the HashMap with a ConcurrentHashMap. Because ConcurrentHashMap is thread-safe, no additional synchronization is needed.

Memoizer2 allows multiple threads to call compute concurrently. However, the if (result == null) { ... } block is still a check-then-act sequence: two threads that miss on the same key at the same time can both enter the block and compute the same value twice, wasting CPU cycles.

Sample output when run in Eclipse:

average time per run: 149.17222222222222 miliseconds.

import java.math.BigInteger;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;

class Memoizer2 <A, V> implements Computable<A, V> {
    private final Map<A, V> cache = new ConcurrentHashMap<A, V>();
    private final Computable<A, V> c;

    public Memoizer2(Computable<A, V> c) {
        this.c = c;
    }

    public V compute(A arg) throws InterruptedException {
        V result = cache.get(arg);
        if (result == null) {
            result = c.compute(arg);
            cache.put(arg, result);
        }
        return result;
    }
}

/**
 * A computation that maps an argument of type {@code A} to a result of type {@code V}.
 *
 * @param <A> argument type
 * @param <V> result type
 */
interface Computable<A, V> {
    /**
     * Computes the result for {@code arg}.
     *
     * @param arg the input value
     * @return the computed result
     * @throws InterruptedException if the implementation is interrupted while computing
     */
    V compute(A arg) throws InterruptedException;
}

/**
 * Minimal stand-in for {@code javax.servlet.Servlet}: a handler that takes a request
 * buffer and a response buffer.
 */
interface Servlet {
    /**
     * Handles one request.
     *
     * @param req  the request payload
     * @param resp buffer for the response (the implementation in this file appends to it)
     */
    void service(StringBuilder req, StringBuilder resp);
}
class CachedFactorizer implements Servlet {
    private final Computable<BigInteger, BigInteger[]> c =
            new Computable<BigInteger, BigInteger[]>() {
                public BigInteger[] compute(BigInteger arg) {
                    return factor(arg);
                }
            };
    private final Computable<BigInteger, BigInteger[]> cache
            = new Memoizer2<BigInteger, BigInteger[]>(c);

    public void service(StringBuilder req, StringBuilder resp) {
        BigInteger i = extractFromRequest(req);
        try {
            encodeIntoResponse(resp, cache.compute(i));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    void encodeIntoResponse(StringBuilder resp, BigInteger[] factors) {
        resp.append(factors[0]);
    }

    BigInteger extractFromRequest(StringBuilder req) {
        return new BigInteger(req.toString());
    }

    BigInteger[] factor(BigInteger i) {
        // Doesn't really factor
        for(int j = 0; j < 100000000; j++) {}
        return new BigInteger[]{i};
    }
}

public class CachedFactorizerTest implements Runnable{
    private static final int TOTALRUN = 200;
    private static final int TOTALTHREAD = 2000;
    private static final int DISCARDFIRSTFEW = 20;
    private static Random rand = new Random();
    private static CountDownLatch latch;
    private static CachedFactorizer cachedFactorizer = new CachedFactorizer();
    
    private void setCountDownLatch(CountDownLatch latch) {
        CachedFactorizerTest.latch = latch;
    }

    public static void main(String[] args) throws InterruptedException {
        long totalTime = 0;
        for (int i = 0; i < TOTALRUN; i++) {
            long oneRoundTime = runMultiThread();
            if (i >= DISCARDFIRSTFEW) {
                totalTime += oneRoundTime;
            }
        }
        System.out.println("average time per run: "
                + (double) totalTime / (double) (TOTALRUN - DISCARDFIRSTFEW)
                + " miliseconds.");

    }
    
    private static long runMultiThread()
            throws InterruptedException {
        long start = System.currentTimeMillis();
        CountDownLatch latch = new CountDownLatch(TOTALTHREAD);
        CachedFactorizerTest cachedFactorizerTest = new CachedFactorizerTest();
        cachedFactorizerTest.setCountDownLatch(latch);
        for (long i = 0; i < TOTALTHREAD; i++) {
            Thread thread = new Thread(cachedFactorizerTest);
            thread.start();
        }
        latch.await();
        return System.currentTimeMillis() - start;
    }

    @Override
    public void run() {
        StringBuilder req = new StringBuilder().append(rand.nextInt(20));
        StringBuilder resp = new StringBuilder();
        //System.out.println(Thread.currentThread().getName() + " req=" + req.toString() + " resp=" + resp.toString());
        cachedFactorizer.service(req, resp);
        if(!req.toString().contains(resp.toString())) {
            System.out.println("multi-threading race condition");
        }
        //System.out.println(Thread.currentThread().getName() + " req=" + req.toString() + " resp=" + resp.toString());
        latch.countDown();
    }

}