Using Reference Maps for Caches and Listeners

A while ago I wrote a blog post about the WeakHashMap. It then turned out that the WeakHashMap was not the optimal choice for that particular use case, and I proposed a different solution. To make this a bit more tangible, I decided to post a full code example. Let me describe the use case once more.

Let's say you have a class wrapping some sort of event. Let's give the event a name, write a Java interface for it and call it Auditable. Every Auditable implementation provides two methods: validate and process. There is an invoker class called AuditableInvoker which receives a number of Auditables and invokes validate and process on each one of them. So far so good.

package javasplitter;
public interface Auditable {
    void process();
    void validate();
}

package javasplitter;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
public class AuditableInvoker {
    private final Auditable[] invokeables;
    // An ArrayDeque keeps the listeners in registration order and offers descendingIterator().
    // A TreeSet would require the listeners to implement Comparable and throws a
    // ClassCastException for listeners that don't.
    private final Deque<AuditableLifecycleListener> listeners;
    public AuditableInvoker(final Auditable... invokeables) {
        this.invokeables = invokeables;
        this.listeners = new ArrayDeque<AuditableLifecycleListener>();
    }
    public void addListener(final AuditableLifecycleListener listener) {
        this.listeners.add(listener);
    }
    public void invoke() {
        for (final Auditable invokeable : invokeables) {
            validate(invokeable);
            process(invokeable);
        }
    }
    private void validate(final Auditable auditable) {
        // start callbacks run backwards, finish callbacks forwards,
        // so the listeners nest around the actual call
        for (final Iterator<AuditableLifecycleListener> iterator = this.listeners.descendingIterator();
                iterator.hasNext(); ) {
            iterator.next().onValidationStart(auditable);
        }
        auditable.validate();
        for (final AuditableLifecycleListener listener : this.listeners) {
            listener.onValidationFinish(auditable);
        }
    }
    private void process(final Auditable auditable) {
        for (final Iterator<AuditableLifecycleListener> iterator = this.listeners.descendingIterator();
                iterator.hasNext(); ) {
            iterator.next().onProcessStart(auditable);
        }
        auditable.process();
        for (final AuditableLifecycleListener listener : this.listeners) {
            listener.onProcessFinish(auditable);
        }
    }
}
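
In both validate and process, AuditableInvoker walks its listener collection backwards (descendingIterator) for the start callbacks and forwards for the finish callbacks. The listeners therefore nest around the actual call like try/finally blocks: whichever listener starts last finishes first. The standalone sketch below reproduces only that ordering, using two listeners represented as strings in a java.util.ArrayDeque. The class name ListenerOrderDemo and its event log are hypothetical and serve only as an illustration.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;

// Hypothetical demo: mirrors the iteration order used by AuditableInvoker.validate().
public class ListenerOrderDemo {

    // Returns the callback sequence for listeners registered as "A", then "B".
    public static List<String> run() {
        final Deque<String> listeners = new ArrayDeque<String>();
        listeners.add("A");
        listeners.add("B");
        final List<String> events = new ArrayList<String>();
        // start callbacks: walk the collection backwards
        for (final Iterator<String> it = listeners.descendingIterator(); it.hasNext(); ) {
            events.add("start " + it.next());
        }
        events.add("validate");
        // finish callbacks: walk the collection forwards
        for (final String listener : listeners) {
            events.add("finish " + listener);
        }
        return events;
    }

    public static void main(String[] args) {
        // prints [start B, start A, validate, finish A, finish B]
        System.out.println(run());
    }
}
```

B brackets A in this sketch, so A's start and finish bookkeeping runs closest to the validate call itself.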

As an example, let's implement two Auditable classes which are pretty stupid. SleepingAuditable just puts the current Thread to sleep for a few milliseconds. IteratingAuditable runs a small empty loop in its validate and process methods.

package javasplitter;
public class IteratingAuditable implements Auditable {
    @Override
    public void process() {
        for (int i = 0; i < 5; i++) {
            // do nothing
        }
    }
    @Override
    public void validate() {
        for (int i = 0; i < 50; i++) {
            // do nothing
        }
    }
}

package javasplitter;
public class SleepingAuditable implements Auditable {
    @Override
    public void process() {
        sleep(20L);
    }
    @Override
    public void validate() {
        sleep(10L);
    }
    private static void sleep(final long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            // restore the interrupt flag so callers can still see it
            Thread.currentThread().interrupt();
            System.out.println("I got interrupted");
        }
    }
}

In addition, there is a requirement to know the execution time of the validate and process methods of each Auditable implementation. Fortunately you can add listeners to AuditableInvoker, so all you have to do is write a listener that measures the execution times. The listener needs to start a stop watch before validate or process is invoked and stop this very stop watch once the method has finished. The execution time can then be computed and handed to a helper class that we call the StatsCollector. To keep things simple, our UnboundedStatsCollector only increments a counter per Auditable class and ignores the execution times completely.

package javasplitter;
public interface AuditableLifecycleListener {
    void onValidationStart(final Auditable auditable);
    void onValidationFinish(final Auditable auditable);
    void onProcessStart(final Auditable auditable);
    void onProcessFinish(final Auditable auditable);
}

package javasplitter;
public interface StatsCollector {
    void collectValidationStats(Class<? extends Auditable> clazz, long executionTime);
    void collectProcessStats(Class<? extends Auditable> clazz, long executionTime);
    long timesValidated(Class<? extends Auditable> clazz);
    long timesProcessed(Class<? extends Auditable> clazz);
}

package javasplitter;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
public class UnboundedStatsCollector implements StatsCollector {
    // One counter per Auditable class, created lazily the first time the class is seen.
    private final ConcurrentMap<Class<? extends Auditable>, AtomicLong> validationStats =
            new ConcurrentHashMap<Class<? extends Auditable>, AtomicLong>();
    private final ConcurrentMap<Class<? extends Auditable>, AtomicLong> processStats =
            new ConcurrentHashMap<Class<? extends Auditable>, AtomicLong>();
    @Override
    public void collectValidationStats(final Class<? extends Auditable> clazz, final long executionTime) {
        counter(this.validationStats, clazz).incrementAndGet(); // executionTime is ignored for now
    }
    @Override
    public void collectProcessStats(final Class<? extends Auditable> clazz, final long executionTime) {
        counter(this.processStats, clazz).incrementAndGet();
    }
    @Override
    public long timesValidated(final Class<? extends Auditable> clazz) {
        final AtomicLong count = this.validationStats.get(clazz);
        return count == null ? 0L : count.get();
    }
    @Override
    public long timesProcessed(final Class<? extends Auditable> clazz) {
        final AtomicLong count = this.processStats.get(clazz);
        return count == null ? 0L : count.get();
    }
    // Returns the counter for clazz, atomically creating it if it does not exist yet.
    private static AtomicLong counter(final ConcurrentMap<Class<? extends Auditable>, AtomicLong> stats,
            final Class<? extends Auditable> clazz) {
        final AtomicLong existing = stats.get(clazz);
        if (existing != null) {
            return existing;
        }
        final AtomicLong created = new AtomicLong();
        final AtomicLong raced = stats.putIfAbsent(clazz, created);
        return raced != null ? raced : created;
    }
}

The tricky part here is that you need to use the same stop watch before and after the invocation of an Auditable, even though a single listener instance may be shared by many threads. This is a good use case for a map that holds its keys through weak references and compares them by object identity (==) rather than equals(). Once an Auditable instance has finished its lifecycle and is no longer referenced anywhere else in the code, the garbage collector can reclaim the Auditable as well as its associated start time. This prevents the map from growing indefinitely. So here is an implementation using a ReferenceIdentityMap from the commons-collections project. The imports come from its generics-enabled port (package org.apache.commons.collections15).

package javasplitter;
import org.apache.commons.collections15.map.AbstractReferenceMap;
import org.apache.commons.collections15.map.ReferenceIdentityMap;
import java.util.Collections;
import java.util.Map;
public class ExecutionTimingAuditableLifecycleListener implements AuditableLifecycleListener {
    private final StatsCollector statsCollector;
    // Weak keys compared by identity (==), hard (strong) values holding the start time.
    // ReferenceIdentityMap is not thread-safe on its own, hence the synchronized wrapper.
    private final Map<Auditable, Long> timedExecutions = Collections.synchronizedMap(
            new ReferenceIdentityMap<Auditable, Long>(
                    AbstractReferenceMap.WEAK,
                    AbstractReferenceMap.HARD
            )
    );
    public ExecutionTimingAuditableLifecycleListener(final StatsCollector statsCollector) {
        this.statsCollector = statsCollector;
    }
    @Override
    public void onValidationStart(final Auditable auditable) {
        this.timedExecutions.put(auditable, System.currentTimeMillis());
    }
    @Override
    public void onValidationFinish(final Auditable auditable) {
        final Long startTime = this.timedExecutions.get(auditable);
        if (startTime != null) {
            final long validationTime = System.currentTimeMillis() - startTime;
            this.statsCollector.collectValidationStats(auditable.getClass(), validationTime);
        } else {
            System.out.println(String.format(
                    "Unable to find validation start time for %s",
                    auditable.getClass().getSimpleName()));
        }
    }
    @Override
    public void onProcessStart(final Auditable auditable) {
        this.timedExecutions.put(auditable, System.currentTimeMillis());
    }
    @Override
    public void onProcessFinish(final Auditable auditable) {
        final Long startTime = this.timedExecutions.get(auditable);
        if (startTime != null) {
            final long processTime = System.currentTimeMillis() - startTime;
            this.statsCollector.collectProcessStats(auditable.getClass(), processTime);
        } else {
            System.out.println(String.format(
                    "Unable to find process start time for %s",
                    auditable.getClass().getSimpleName()));
        }
    }
}
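
The identity requirement matters because Auditable implementations could override equals() and hashCode(). If two distinct but equal instances were timed concurrently by the shared listener, they would collapse into a single entry of an equals-based map such as java.util.WeakHashMap, and one thread would overwrite the other's start time. The JDK has no weak identity map, so the sketch below isolates only the comparison aspect, using the strongly referencing java.util.HashMap and java.util.IdentityHashMap. The ValueEvent class and the "order-42" id are hypothetical.

```java
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;

public class IdentityVsEqualsDemo {

    // Hypothetical event with value semantics: events with the same id are equal.
    static final class ValueEvent {
        private final String id;
        ValueEvent(final String id) { this.id = id; }
        @Override public boolean equals(final Object o) {
            return o instanceof ValueEvent && ((ValueEvent) o).id.equals(id);
        }
        @Override public int hashCode() { return id.hashCode(); }
    }

    // Stores start times for two distinct-but-equal events and returns the resulting map size.
    static int entriesAfterTwoStarts(final Map<ValueEvent, Long> startTimes) {
        final ValueEvent first = new ValueEvent("order-42");
        final ValueEvent second = new ValueEvent("order-42");
        startTimes.put(first, 1000L);
        startTimes.put(second, 2000L); // overwrites first's start time if keys are compared with equals()
        return startTimes.size();
    }

    public static void main(String[] args) {
        System.out.println(entriesAfterTwoStarts(new HashMap<ValueEvent, Long>()));         // 1
        System.out.println(entriesAfterTwoStarts(new IdentityHashMap<ValueEvent, Long>())); // 2
    }
}
```

With the HashMap, the second put replaces the first start time and the map ends up with a single entry. The IdentityHashMap keeps one entry per instance, which is the behavior the ReferenceIdentityMap provides, with weak keys on top.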

To verify that we really see the expected behavior, I have written a unit test that stresses the ExecutionTimingAuditableLifecycleListener using multiple Threads. In this unit test I am re-using a class called MultithreadedStressTester, which I stole from Nat Pryce's book "Growing Object-Oriented Software, Guided by Tests".

package javasplitter;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public final class MultithreadedStressTester {
    /**
     * The default number of threads to run concurrently.
     */
    public static final int DEFAULT_THREAD_COUNT = 2;
    private final ExecutorService executor;
    private final int threadCount;
    private final int iterationCount;
    public MultithreadedStressTester(final int threadCount, final int iterationCount) {
        this.threadCount = threadCount;
        this.iterationCount = iterationCount;
        this.executor = Executors.newCachedThreadPool();
    }
    // Runs the action iterationCount times on each of threadCount threads and blocks until all are done.
    public void stress(final Runnable action) throws InterruptedException {
        spawnThreads(action).await();
    }
    private CountDownLatch spawnThreads(final Runnable action) {
        final CountDownLatch finished = new CountDownLatch(threadCount);
        for (int i = 0; i < threadCount; i++) {
            executor.execute(new Runnable() {
                public void run() {
                    try {
                        repeat(action);
                    } finally {
                        finished.countDown();
                    }
                }
            });
        }
        return finished;
    }
    private void repeat(final Runnable action) {
        for (int i = 0; i < iterationCount; i++) {
            action.run();
        }
    }
}
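
Each of the threadCount threads runs the action iterationCount times, and stress() only returns once the latch has counted down to zero. That is why the test can assert exact totals. In addition, a successful await() guarantees that everything the worker threads did before countDown() is visible to the asserting thread. The following is a condensed, standalone sketch of that latch pattern using only the JDK. The StressCountDemo name, the fixed thread pool and the 4 × 1000 sizes are arbitrary choices and not part of the original class.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;

public class StressCountDemo {

    // Runs action `iterations` times on each of `threads` threads; blocks until all are done.
    static void stress(final int threads, final int iterations, final Runnable action)
            throws InterruptedException {
        final ExecutorService executor = Executors.newFixedThreadPool(threads);
        final CountDownLatch finished = new CountDownLatch(threads);
        try {
            for (int t = 0; t < threads; t++) {
                executor.execute(new Runnable() {
                    public void run() {
                        try {
                            for (int i = 0; i < iterations; i++) {
                                action.run();
                            }
                        } finally {
                            finished.countDown(); // always count down, even if action throws
                        }
                    }
                });
            }
            finished.await();
        } finally {
            executor.shutdown();
        }
    }

    // Counts how often the action ran in total.
    static long countInvocations(final int threads, final int iterations) throws InterruptedException {
        final AtomicLong counter = new AtomicLong();
        stress(threads, iterations, new Runnable() {
            public void run() {
                counter.incrementAndGet(); // thread-safe increment
            }
        });
        return counter.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(countInvocations(4, 1000)); // 4000
    }
}
```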

The ExecutionTimingAuditableLifecycleListenerTest uses the MultithreadedStressTester to run many threads against a single ExecutionTimingAuditableLifecycleListener. It verifies that every validate and process call was counted, which means that for each finish event the listener found the matching start time in the ReferenceIdentityMap under the hood.

package javasplitter;
import org.junit.Test;
import static org.junit.Assert.*;
public class ExecutionTimingAuditableLifecycleListenerTest {
    @Test
    public void testTimingExecutions() throws InterruptedException {
        final int threads = 500;
        final int iterations = 5000;
        final int total = threads * iterations;
        final StatsCollector statsCollector = new UnboundedStatsCollector();
        // a single listener instance is shared by all threads
        final AuditableLifecycleListener listener =
                new ExecutionTimingAuditableLifecycleListener(statsCollector);
        final MultithreadedStressTester stressTester =
                new MultithreadedStressTester(threads, iterations);
        stressTester.stress(
                new Runnable() {
                    @Override
                    public void run() {
                        final Auditable sleeping = new SleepingAuditable();
                        final Auditable iterating = new IteratingAuditable();
                        final AuditableInvoker invoker = new AuditableInvoker(sleeping, iterating);
                        invoker.addListener(listener);
                        invoker.invoke();
                    }
                }
        );
        assertEquals(total, statsCollector.timesValidated(SleepingAuditable.class));
        assertEquals(total, statsCollector.timesValidated(IteratingAuditable.class));
        assertEquals(total, statsCollector.timesProcessed(SleepingAuditable.class));
        assertEquals(total, statsCollector.timesProcessed(IteratingAuditable.class));
    }
}

Finally, if you want to use google-guava instead of commons-collections, you can use a LoadingCache with weak keys in place of the ReferenceIdentityMap. CacheBuilder.weakKeys() also makes the cache compare keys by identity (==), so it meets the same requirements. Here is a version of the ExecutionTimingAuditableLifecycleListener using google-guava.

package javasplitter;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
public class ExecutionTimingAuditableLifecycleListener implements AuditableLifecycleListener {
    private final StatsCollector statsCollector;
    // weakKeys() compares keys by identity and lets entries be collected together with their Auditable.
    // The loader records the current time the first time an Auditable is looked up.
    private final LoadingCache<Auditable, Long> timedExecutionsCache =
            CacheBuilder.newBuilder()
                    .weakKeys()
                    .build(
                            new CacheLoader<Auditable, Long>() {
                                @Override
                                public Long load(final Auditable key) {
                                    return System.currentTimeMillis();
                                }
                            }
                    );
    public ExecutionTimingAuditableLifecycleListener(final StatsCollector statsCollector) {
        this.statsCollector = statsCollector;
    }
    @Override
    public void onValidationStart(final Auditable auditable) {
        // the lookup triggers the loader, which stores the start time
        this.timedExecutionsCache.getUnchecked(auditable);
    }
    @Override
    public void onValidationFinish(final Auditable auditable) {
        try {
            this.statsCollector.collectValidationStats(
                    auditable.getClass(),
                    System.currentTimeMillis() - this.timedExecutionsCache.getUnchecked(auditable));
        } finally {
            // drop the entry so the next start event loads a fresh start time
            this.timedExecutionsCache.invalidate(auditable);
        }
    }
    @Override
    public void onProcessStart(final Auditable auditable) {
        this.timedExecutionsCache.getUnchecked(auditable);
    }
    @Override
    public void onProcessFinish(final Auditable auditable) {
        try {
            this.statsCollector.collectProcessStats(
                    auditable.getClass(),
                    System.currentTimeMillis() - this.timedExecutionsCache.getUnchecked(auditable));
        } finally {
            this.timedExecutionsCache.invalidate(auditable);
        }
    }
}