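Memoizer uses ConcurrentMap's atomic putIfAbsent() to close the check-then-act window that could cause Memoizer3 to compute the same value twice. It removes the Future from the cache if the computation is cancelled. On an ExecutionException it prints the stack trace and loops again, hoping the computation will succeed the next time. However, the failed Future is never removed from the cache, so every retry receives the same failure and the loop never terminates.
Sample output from a run in Eclipse: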
average time per run: 139.87222222222223 miliseconds.
import java.math.BigInteger;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
/**
 * Caches the results of a {@link Computable}, keyed by argument.
 *
 * <p>Concurrent callers asking for the same argument share a single {@link Future}.
 * The Future is installed atomically with {@code putIfAbsent}, so only the thread that
 * wins the insert runs the underlying computation; the others block on {@code get()}.
 *
 * <p>Successful results are kept forever (there is no eviction). Cancelled or failed
 * computations are removed from the cache so that a later call recomputes them.
 */
class Memoizer <A, V> implements Computable<A, V> {
    private final ConcurrentMap<A, Future<V>> cache
            = new ConcurrentHashMap<A, Future<V>>();
    private final Computable<A, V> c;

    /**
     * @param c the computation whose results are memoized
     */
    public Memoizer(Computable<A, V> c) {
        this.c = c;
    }

    /**
     * Returns the (possibly cached) result of {@code c.compute(arg)}.
     *
     * @param arg the argument to compute for; used as the cache key
     * @return the computed value
     * @throws InterruptedException if this thread is interrupted while waiting, or if the
     *         underlying computation itself threw {@code InterruptedException}
     * @throws RuntimeException or {@link Error} rethrown unchanged from the underlying
     *         computation; the failed entry is evicted first so a later call can retry
     */
    public V compute(final A arg) throws InterruptedException {
        while (true) {
            Future<V> f = cache.get(arg);
            if (f == null) {
                Callable<V> eval = new Callable<V>() {
                    public V call() throws InterruptedException {
                        return c.compute(arg);
                    }
                };
                FutureTask<V> ft = new FutureTask<V>(eval);
                f = cache.putIfAbsent(arg, ft);
                if (f == null) {
                    // This thread won the race to install the task, so it runs it inline.
                    f = ft;
                    ft.run();
                }
            }
            try {
                return f.get();
            } catch (CancellationException e) {
                // Conditional remove: never evict a newer Future installed by another thread.
                cache.remove(arg, f);
            } catch (ExecutionException e) {
                // A failed Future completes with the same exception on every get(). Leaving
                // it cached and looping would therefore spin forever, so evict it and
                // report the failure to the caller instead.
                cache.remove(arg, f);
                Throwable cause = e.getCause();
                if (cause instanceof InterruptedException) {
                    throw (InterruptedException) cause;
                }
                throw launderThrowable(cause);
            }
        }
    }

    /**
     * Converts the cause of an {@link ExecutionException} into something {@link #compute}
     * may throw. A RuntimeException is returned as-is, an Error is thrown directly, and
     * anything else (unexpected, since the task only declares InterruptedException) is
     * wrapped in an IllegalStateException.
     */
    private static RuntimeException launderThrowable(Throwable t) {
        if (t instanceof RuntimeException) {
            return (RuntimeException) t;
        }
        if (t instanceof Error) {
            throw (Error) t;
        }
        return new IllegalStateException("Unexpected checked exception from computation", t);
    }
}

/**
 * A function from {@code A} to {@code V} whose evaluation may throw
 * {@link InterruptedException}.
 */
interface Computable <A, V> {
    V compute(A arg) throws InterruptedException;
}
/**
 * Minimal stand-in for {@code javax.servlet.Servlet}.
 *
 * <p>Handles one request. Both the request and the response are modeled as
 * {@link StringBuilder}s.
 */
interface Servlet {
    void service(StringBuilder req, StringBuilder resp);
}
/**
 * Mock servlet that "factors" the integer in the request and writes the result to the
 * response. Results are cached per input through a {@link Memoizer}.
 */
class CachedFactorizer implements Servlet {
    /** The underlying, deliberately slow computation that the memoizer wraps. */
    private final Computable<BigInteger, BigInteger[]> c =
            new Computable<BigInteger, BigInteger[]>() {
                public BigInteger[] compute(BigInteger arg) {
                    return factor(arg);
                }
            };

    /** Memoizing front end over {@link #c}. */
    private final Computable<BigInteger, BigInteger[]> cache
            = new Memoizer<BigInteger, BigInteger[]>(c);

    /**
     * Parses the request as an integer, computes (or fetches the cached) factors, and
     * appends the first factor to {@code resp}. If the thread is interrupted while
     * waiting, {@code resp} is left unchanged and the interrupt status is restored.
     *
     * @param req  request buffer holding a decimal integer
     * @param resp response buffer the result is appended to
     * @throws NumberFormatException if {@code req} is not a valid integer
     */
    public void service(StringBuilder req, StringBuilder resp) {
        BigInteger i = extractFromRequest(req);
        try {
            encodeIntoResponse(resp, cache.compute(i));
        } catch (InterruptedException e) {
            // service() cannot rethrow InterruptedException. Re-assert the interrupt
            // so the thread's owner can still observe it instead of it being lost.
            Thread.currentThread().interrupt();
        }
    }

    /** Appends only the first element of {@code factors} to the response. */
    void encodeIntoResponse(StringBuilder resp, BigInteger[] factors) {
        resp.append(factors[0]);
    }

    /** Parses the whole request buffer as a decimal {@link BigInteger}. */
    BigInteger extractFromRequest(StringBuilder req) {
        return new BigInteger(req.toString());
    }

    /**
     * Fake factorization: burns CPU in an empty loop, then returns {@code {i}}.
     * NOTE(review): the JIT may eliminate an empty loop, so the simulated cost is
     * not guaranteed.
     */
    BigInteger[] factor(BigInteger i) {
        for (int j = 0; j < 100000000; j++) {
        }
        return new BigInteger[] {i};
    }
}
/**
 * Crude benchmark and race detector for {@link CachedFactorizer}.
 *
 * <p>Each of {@code TOTALRUN} rounds starts {@code TOTALTHREAD} threads. Every thread
 * sends a random request in {@code [0, 20)` and checks that the response echoes it
 * back. The first {@code DISCARDFIRSTFEW} rounds are excluded from the reported
 * average.
 */
public class CachedFactorizerTest implements Runnable {
    private static final int TOTALRUN = 200;
    private static final int TOTALTHREAD = 2000;
    private static final int DISCARDFIRSTFEW = 20;
    private static final Random rand = new Random();
    // Written by the main thread before each round's threads are started
    // (Thread.start() provides the happens-before edge); read by the worker threads.
    private static CountDownLatch latch;
    private static final CachedFactorizer cachedFactorizer = new CachedFactorizer();

    private void setCountDownLatch(CountDownLatch latch) {
        CachedFactorizerTest.latch = latch;
    }

    /** Runs all rounds and prints the average wall time per counted round. */
    public static void main(String[] args) throws InterruptedException {
        long totalTime = 0;
        for (int i = 0; i < TOTALRUN; i++) {
            long oneRoundTime = runMultiThread();
            // Presumably skips warm-up rounds (JIT, first cache population).
            if (i >= DISCARDFIRSTFEW) {
                totalTime += oneRoundTime;
            }
        }
        System.out.println("average time per run: "
                + (double) totalTime / (double) (TOTALRUN - DISCARDFIRSTFEW)
                + " miliseconds.");
    }

    /**
     * Starts one round of worker threads and waits for all of them to finish.
     *
     * @return elapsed time of the round in milliseconds
     */
    private static long runMultiThread() throws InterruptedException {
        // nanoTime is monotonic; currentTimeMillis can jump with wall-clock changes.
        long start = System.nanoTime();
        CountDownLatch latch = new CountDownLatch(TOTALTHREAD);
        CachedFactorizerTest cachedFactorizerTest = new CachedFactorizerTest();
        cachedFactorizerTest.setCountDownLatch(latch);
        for (int i = 0; i < TOTALTHREAD; i++) {
            Thread thread = new Thread(cachedFactorizerTest);
            thread.start();
        }
        latch.await();
        return (System.nanoTime() - start) / 1000000L;
    }

    /** Sends one random request and reports a mismatched response. */
    @Override
    public void run() {
        try {
            StringBuilder req = new StringBuilder().append(rand.nextInt(20));
            StringBuilder resp = new StringBuilder();
            cachedFactorizer.service(req, resp);
            // The mock factor() returns {i}, so a correct response equals the request
            // exactly. contains() would also accept empty or partial responses.
            if (!req.toString().equals(resp.toString())) {
                System.out.println("multi-threading race condition");
            }
        } finally {
            // Always count down, even if service() throws; otherwise main() would
            // block on await() forever.
            latch.countDown();
        }
    }
}