(this post is by Charles)
My colleagues Matt Hoffman, Pavel Sountsov, Lionel Riou-Durand, Aki Vehtari, Andrew Gelman, and I released a preprint titled “Nested R-hat: assessing the convergence of Markov chain Monte Carlo when running many short chains”. This is a revision of an earlier preprint. Here’s the abstract:
The growing availability of hardware accelerators such as GPUs has generated interest in Markov chain Monte Carlo (MCMC) workflows which run a large number of chains in parallel. Each chain still needs to forget its initial state, but the subsequent sampling phase can be almost arbitrarily short. To determine if the resulting short chains are reliable, we need to assess how close the Markov chains are to convergence to their stationary distribution. The R-hat statistic is a battle-tested convergence diagnostic, but unfortunately it can require long chains to work well. We present a nested design to overcome this challenge, and introduce tuning parameters to control the reliability, bias, and variance of convergence diagnostics.
The paper is motivated by the possibility of running many Markov chains in parallel on modern hardware, such as GPUs. Increasing the number of chains reduces the variance of your Monte Carlo estimator (that's what the sampling phase is for) but not the bias (that's what the warmup phase is for); that's the short story. So you can trade the length of the sampling phase for the number of chains, but you still need to achieve approximate convergence.
There’s more to be said about the many-short-chains regime, but what I want to focus on here is what we’ve learned about the more classic R-hat. The first step is to rewrite the condition R-hat < 1.01 as a tolerance on the variance of the per-chain Monte Carlo estimator. Intuitively, we’re running a stochastic algorithm to estimate an expectation value, which is a non-random quantity. Hence, different chains should, despite their different initializations and seeds, still come to an “agreement”. This agreement is measured by the variance of the estimator produced by each chain.
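To make the connection concrete, here is a minimal sketch of the classic R-hat computation in Python with NumPy. This is a toy version for illustration: production implementations (e.g. in Stan or ArviZ) also split chains in half and rank-normalize, which I omit here.

```python
import numpy as np

def rhat(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws).

    Compares the between-chain variance of the per-chain estimators
    (the chain means) to the within-chain variance.
    """
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()    # W: average within-chain variance
    between = n * chain_means.var(ddof=1)         # B: n * variance of chain means
    var_hat = (n - 1) / n * within + between / n  # pooled variance estimate
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
# Four chains already sampling the target: R-hat should sit near 1.
print(rhat(rng.normal(size=(4, 1000))))
```

When the chains disagree (e.g. their means are offset), B blows up relative to W and R-hat rises well above 1; when they agree, R-hat falls toward 1, which is exactly the "tolerance on the variance of the per-chain estimator" reading above.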
Now here’s the paradox. The expected squared error of a per-chain Monte Carlo estimator decomposes into a squared bias and a variance. When diagnosing convergence, we’re really interested in making sure the bias has decayed sufficiently (a common phrase is “has become negligible”, but I find it useful to think of MCMC as a biased algorithm). But with R-hat, we’re really monitoring the variance, not the bias! So how can this be a useful diagnostic?
This paradox occurred to us when we rewrote R-hat to monitor the variance of Monte Carlo estimators constructed using groups of chains, or superchains, rather than a single chain. The resulting nested R-hat decays to 1 provided we have enough chains, even if the individual chains are short (think a single iteration). But here’s the issue: regardless of whether the chains are close to convergence or not, nested R-hat can be made arbitrarily close to 1 by increasing the size of each superchain and thereby decreasing the variance of their Monte Carlo estimators. Which goes back to my earlier point: you cannot monitor bias simply by looking at variance.
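As an illustration of the nested design, here is a toy sketch that groups chains into superchains and compares between- to within-superchain variability. The preprint's actual estimator differs in its details, so treat this as a cartoon of the idea, not the paper's definition.

```python
import numpy as np

def nested_rhat(chains, n_superchains):
    """Cartoon of a nested R-hat for chains of shape (n_chains, n_draws).

    Chains are grouped into superchains; the diagnostic compares the
    variance of the superchain estimators (between) to the spread of
    draws inside a superchain (within).
    """
    groups = chains.reshape(n_superchains, -1)   # pool each superchain's draws
    super_means = groups.mean(axis=1)            # one estimator per superchain
    between = super_means.var(ddof=1)            # variance across superchains
    within = groups.var(axis=1, ddof=1).mean()   # spread within a superchain
    return np.sqrt(1 + between / within)

rng = np.random.default_rng(1)
# 512 chains of a single draw each, grouped into 8 superchains of 64:
# averaging 64 chains shrinks the between-superchain variance, so the
# diagnostic lands near 1 even though each chain is one iteration long.
print(nested_rhat(rng.normal(size=(512, 1)), 8))
```

This also exposes the catch discussed above: the `between` term shrinks like 1 over the superchain size, so with big enough superchains the statistic approaches 1 whether or not the chains have converged.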
Or can you?
Here’s the twist: we now force all the chains within a superchain to start at the same point. I had this idea initially to deal with multimodal distributions. The chains within a group are no longer independent, though eventually they will (hopefully) forget about each other. In the meantime, we have artificially increased the variance. Doing a standard variance decomposition:
total variance = variance of the conditional expectation + expected conditional variance

or, symbolically, Var(estimator) = Var(E[estimator | init]) + E[Var(estimator | init)].
Here we’re conditioning on the initial point. If the expected value of each chain no longer depends on the initialization, then the first term — variance of the conditional expectation — goes to 0. This is a measurement of “how well the chains forget their starting point”, and we call it the violation of stationarity. It is indifferent to the number of chains. The second term, on the other hand, persists even if your chains are stationary, but it decays to 0 as you increase the number of chains per superchain. More generally, this persistent variance can be linked to the Effective Sample Size.
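To see the decomposition in action, here is a toy Monte Carlo check (not from the preprint): an AR(1) kernel targeting a standard normal, with all chains in a superchain sharing one overdispersed initial point. For AR(1) the conditional expectation is available in closed form, so we can compute both terms directly.

```python
import numpy as np

rng = np.random.default_rng(0)
rho, n_super, n_chains, n_warm = 0.9, 2000, 16, 20

# AR(1) kernel x' = rho * x + sqrt(1 - rho^2) * z targets N(0, 1).
# Every chain within a superchain starts at the same overdispersed point.
x0 = rng.normal(0.0, 3.0, size=n_super)
x = np.repeat(x0[:, None], n_chains, axis=1)
for _ in range(n_warm):
    x = rho * x + np.sqrt(1 - rho**2) * rng.normal(size=x.shape)

# One post-warmup draw per chain; the superchain estimator averages them.
est = x.mean(axis=1)
cond_mean = rho**n_warm * x0          # E[estimator | x0], exact for AR(1)

total = est.var()
violation = cond_mean.var()           # variance of the conditional expectation
persistent = (est - cond_mean).var()  # ~ expected conditional variance
print(total, violation + persistent)  # law of total variance: roughly equal
```

Lengthening the warmup shrinks `violation` (the chains forget `x0`), while adding chains per superchain shrinks `persistent`; only the former tells you anything about forgetting the starting point.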
We argue that nested R-hat is a (scaled) measure of the violation of stationarity, biased by the persistent variance. How does this link to the squared bias? Well, both the bias and the violation decay as we warm up our chains, so one can be used as a “proxy clock” for the other. I don’t have a fully general theory for this, but if you consider a Gaussian target and are willing to solve an SDE, you can show that the violation and the squared bias decay at the same rate. This also gives us insight into how over-dispersed the initializations should (or should not) be for nested R-hat to be reliable.
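A toy AR(1) calculation (a discrete stand-in for the Gaussian/SDE analysis, not taken from the preprint) shows the proxy-clock idea in closed form. With the kernel x' = rho * x + sqrt(1 - rho^2) * z targeting N(0, 1) and initialization x0 ~ N(mu0, sigma0^2), we have E[x_T | x0] = rho^T x0, so both the squared bias and the violation of stationarity decay at the same geometric rate rho^(2T):

```python
import numpy as np

rho, sigma0, mu0 = 0.95, 2.0, 3.0   # AR(1) autocorrelation; init ~ N(mu0, sigma0^2)
T = np.arange(0, 60, 5)             # warmup lengths

# E[x_T | x0] = rho^T x0, and the target mean is 0, hence:
sq_bias = rho**(2 * T) * mu0**2         # squared bias of E[x_T]
violation = rho**(2 * T) * sigma0**2    # Var over x0 of E[x_T | x0]

print(sq_bias / violation)              # constant ratio (mu0 / sigma0)^2
```

The ratio is constant in T, so in this toy model the violation is an exact clock for the squared bias; it also shows the role of the initial over-dispersion, since a tiny sigma0 makes the clock hard to read even though the bias is still there.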
Now, nested R-hat is a generalization of R-hat, meaning our analysis carries over! Moreover, we have a theory of what R-hat measures which does not assume stationarity. Part of the conceptual leap is to do an asymptotic analysis which considers an infinite number of finite (non-stationary) chains, rather than a single infinitely long (and hence stationary) chain.
Moving forward, I hope this idea of a proxy clock will help us identify cases where R-hat and its nested version are reliable, and how we might revise our MCMC workflows to get more reliable diagnostics. Two examples are discussed in the preprint: the choice of the initial variance and how to split a fixed total number of chains into superchains.