What Nested R-hat teaches us about the classical R-hat

(this post is by Charles)

My colleagues Matt Hoffman, Pavel Sountsov, Lionel Riou-Durand, Aki Vehtari, Andrew Gelman, and I released a preprint titled “Nested R-hat: assessing the convergence of Markov chain Monte Carlo when running many short chains”. This is a revision of an earlier preprint. Here’s the abstract:

The growing availability of hardware accelerators such as GPUs has generated interest in Markov chain Monte Carlo (MCMC) workflows which run a large number of chains in parallel. Each chain still needs to forget its initial state but the subsequent sampling phase can be almost arbitrarily short. To determine if the resulting short chains are reliable, we need to assess how close the Markov chains are to convergence to their stationary distribution. The R-hat statistic is a battle-tested convergence diagnostic but unfortunately can require long chains to work well. We present a nested design to overcome this challenge, and introduce tuning parameters to control the reliability, bias, and variance of convergence diagnostics.

The paper is motivated by the possibility of running many Markov chains in parallel on modern hardware, such as GPUs. Increasing the number of chains allows you to reduce the variance of your Monte Carlo estimator, which is what the sampling phase is for, but not the bias, which is what the warmup phase is for (that’s the short story). So you can trade the length of the sampling phase for the number of chains, but you still need to achieve approximate convergence.

There’s more to be said about the many-short-chains regime, but what I want to focus on is what we’ve learned about the more classical R-hat. The first step is to rewrite the condition R-hat < 1.01 as a tolerance on the variance of the per-chain Monte Carlo estimator. Intuitively, we’re running a stochastic algorithm to estimate an expectation, which is a non-random quantity. Hence different chains should, despite their different initializations and seeds, still come to an “agreement”. This agreement is measured by the variance of the estimator produced by each chain.
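For concreteness, here is a minimal sketch (mine, not from the preprint) of how the classical R-hat compares between-chain and within-chain variance. It is the textbook Gelman–Rubin form, not the rank-normalized, split-chain version implemented in modern software, but it makes explicit what “agreement between chains” means.

```python
import numpy as np

def classic_rhat(draws):
    """Textbook R-hat for draws of shape (n_chains, n_iterations).

    Minimal sketch only: modern implementations also split each chain in half
    and rank-normalize the draws before applying this formula.
    """
    m, n = draws.shape
    chain_means = draws.mean(axis=1)
    B = n * chain_means.var(ddof=1)           # between-chain variance (scaled by n)
    W = draws.var(axis=1, ddof=1).mean()      # average within-chain variance
    var_hat = (n - 1) / n * W + B / n         # pooled estimate of the target variance
    return np.sqrt(var_hat / W)
```

R-hat < 1.01 then says that the disagreement between chains is small relative to the within-chain variability, i.e. the per-chain estimators have a small variance on the scale of the target distribution.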

Now here’s the paradox. The expected squared error of a per-chain Monte Carlo estimator decomposes into a squared bias and a variance. When diagnosing convergence, we’re really interested in making sure the bias has decayed sufficiently (a common phrase is “has become negligible”, but I find it useful to think of MCMC as a biased algorithm). Yet with R-hat we’re monitoring the variance, not the bias! So how can this be a useful diagnostic?

This paradox occurred to us when we rewrote R-hat to monitor the variance of Monte Carlo estimators constructed using groups of chains, or superchains, rather than a single chain. The resulting nested R-hat decays to 1 provided we have enough chains, even if the individual chains are short (think a single iteration). But here’s the issue: regardless of whether the chains are close to convergence or not, nested R-hat can be made arbitrarily close to 1 by increasing the size of each superchain and thereby decreasing the variance of their Monte Carlo estimator. Which goes back to my earlier point: you cannot monitor bias simply by looking at variance.
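To make the nesting concrete, here is a rough sketch of the construction (my own illustrative code, not the preprint’s estimator, whose exact form and within-superchain decomposition differ): group the chains into superchains, treat each superchain’s pooled mean as one Monte Carlo estimator, and compare between-superchain to within-superchain variability.

```python
import numpy as np

def nested_rhat_sketch(draws, n_superchains):
    """Illustrative nested-style diagnostic for draws of shape (n_chains, n_draws).

    Chains are grouped, in order, into `n_superchains` superchains of equal size.
    This is a sketch of the idea only; the estimator in the preprint differs in
    its details.
    """
    m, n = draws.shape
    k = n_superchains
    grouped = draws.reshape(k, m // k, n)       # (superchain, chain within superchain, draw)
    super_means = grouped.mean(axis=(1, 2))     # one pooled estimator per superchain
    B = super_means.var(ddof=1)                 # disagreement between superchains
    W = grouped.reshape(k, -1).var(axis=1, ddof=1).mean()  # variability within a superchain
    return np.sqrt(1.0 + B / W)
```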

Or can you?

Here’s the twist: we now force all the chains within a superchain to start at the same point. I initially had this idea to deal with multimodal distributions. The chains within a group are no longer independent, though eventually they (hopefully) will forget about each other. In the meantime, we have artificially increased the variance. Doing a standard variance decomposition:

total variance = variance of conditional expectation + expected conditional variance

Here we’re conditioning on the initial point. If the expected value of each chain no longer depends on the initialization, then the first term, the variance of the conditional expectation, goes to 0. It measures “how well the chains forget their starting point”, and we call it the violation of stationarity. It is indifferent to the number of chains. The second term, on the other hand, persists even if your chains are stationary, but it decays to 0 as you increase the number of chains. More generally, this persistent variance can be linked to the Effective Sample Size.
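In symbols (my notation, not necessarily the preprint’s), writing θ0 for the shared initial point of a superchain and f̄ for the superchain’s Monte Carlo estimator, the decomposition above reads

```latex
\operatorname{Var}(\bar f)
  = \underbrace{\operatorname{Var}\!\bigl(\mathbb{E}[\bar f \mid \theta_0]\bigr)}_{\text{violation of stationarity}}
  + \underbrace{\mathbb{E}\!\bigl[\operatorname{Var}(\bar f \mid \theta_0)\bigr]}_{\text{persistent variance}}
```

The first term does not shrink as you add chains to a superchain (they all share θ0), while the second term does, because the chains are conditionally independent given θ0.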

We argue that nested R-hat is a (scaled) measure of the violation of stationarity, biased by the persistent variance. How does this link to the squared bias? Well, both the bias and the violation decay as we warm up our chains, so one can be used as a “proxy clock” for the other. I don’t have a fully general theory for this, but if you consider a Gaussian target and are willing to solve an SDE, you can show that the violation and the squared bias decay at the same rate. This also gives us insight into how over-dispersed the initializations should be (or not be) for nested R-hat to be reliable.
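As a toy illustration of the proxy-clock idea (my own construction, far simpler than the SDE analysis in the paper): for a Gaussian AR(1) chain X_{t+1} = ρ X_t + sqrt(1 − ρ²) ε_t with stationary distribution N(0, 1) and X_0 ~ N(μ0, τ²), the conditional mean is E[X_t | X_0] = ρ^t X_0, so the squared bias and the violation of stationarity both decay at the rate ρ^{2t} and their ratio stays constant.

```python
# Toy AR(1) check of the proxy-clock idea (illustrative parameter values):
#   squared bias = (rho^t * mu0)^2     -- bias of X_t relative to E[X] = 0
#   violation    = rho^(2t) * tau^2    -- Var(E[X_t | X_0])
rho, mu0, tau = 0.9, 3.0, 2.0
for t in range(0, 30, 5):
    sq_bias = (rho ** t * mu0) ** 2
    violation = rho ** (2 * t) * tau ** 2
    print(f"t={t:2d}  squared bias={sq_bias:10.6f}  violation={violation:10.6f}  "
          f"ratio={sq_bias / violation:.3f}")
```

The constant ratio (here μ0²/τ² = 2.25) is what lets the violation serve as a clock for the squared bias.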

Now, nested R-hat is a generalization of R-hat, meaning our analysis carries over! Moreover, we have a theory of what R-hat measures which does not assume stationarity. Part of the conceptual leap is to do an asymptotic analysis which considers an infinite number of finite (non-stationary) chains, rather than a single infinitely long (and hence stationary) chain.

Moving forward, I hope this idea of a proxy clock will help us identify cases where R-hat and its nested version are reliable, and how we might revise our MCMC processes to get more reliable diagnostics. Two examples discussed in the preprint: choice of initial variance and how to split a fixed total number of chains into superchains.

7 thoughts on “What Nested R-hat teaches us about the classical R-hat”

  1. Sorry to derail the conversation into general parallel HMC talk, but I don’t have any meaningful thoughts on nested R-hat yet.

    The paper is motivated by the possibility of running many Markov chains in parallel on modern hardware, such as GPUs. Increasing the number of chains allows you to reduce the variance of your Monte Carlo estimator, which is what the sampling phase is for, but not the bias, which is what the warmup phase is for (that’s the short story). So you can trade the length of the sampling phase for the number of chains, but you still need to achieve approximate convergence.

    Wouldn’t HMC on a GPU get you a sublinear speedup with the number of cores? My understanding is that GPUs are efficient for highly SIMD workloads. But with each chain running warmup and adaptation, I’d expect chains to end up with slightly different (epsilon, L) pairs, with the different step counts causing the chains’ instructions to fall slightly out of sync. I’d also expect the different integration path lengths chosen by the no-U-turn criterion to make HMC non-data-parallel across chains.

    • Running many chains in parallel on a GPU is indeed not a trivial task, but there have been a few recent breakthroughs that tackle this problem and address some of the issues you raise for NUTS. In our experiments we do not use NUTS but rather ChEES-HMC (https://proceedings.mlr.press/v130/hoffman21a.html), which uses cross-chain adaptation and the same leapfrog step size and trajectory length for all chains. My co-authors Matt Hoffman and Pavel Sountsov have worked quite a bit on this topic, developing ChEES-HMC and, more recently, the SNAPER and MEADS algorithms. And there is more ongoing work! It should be noted that sharing information between chains during the warmup can lead to faster adaptation and (I conjecture) faster bias decay; this needs to be studied more formally, but it is a way in which many chains could also reduce runtime during the warmup.

      With all this said, let me bring up a few persistent challenges. This cross-chain adaptation works well empirically on a diverse palette of targets, but sharing the tuning parameters early in the warmup phase can be problematic. This is because the behavior of the chains during the warmup may be heterogeneous: the tuning parameters a chain needs to make progress depend on where that chain currently is. I uncovered this problem in a less conventional pharmacometrics model with an ODE-based likelihood.

      On a somewhat related note, we may worry about heterogeneous runtimes when evaluating the gradient of the log likelihood. Again, consider an ODE-based likelihood where the speed at which you solve the ODE depends on the parameter values. We have a simple demo of this in the Bayesian Workflow paper (https://arxiv.org/abs/2011.01808, Section 11). Heterogeneous runtime is a problem because we’re always waiting for the slowest operation, so the runtime is governed by an order statistic (the maximum) rather than an average, although we might not always have to wait for the slowest operation. This is another problem we’re working on.
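      As a rough illustration of the order-statistic point (a simulation of my own with made-up lognormal per-chain costs, not a result from the paper): under synchronous execution, the wall-clock cost of an iteration behaves like the maximum of the per-chain runtimes, which keeps growing with the number of chains even though the average cost does not.

      ```python
      import numpy as np

      # Hypothetical per-chain gradient costs; the lognormal choice is arbitrary.
      rng = np.random.default_rng(0)
      costs = rng.lognormal(mean=0.0, sigma=1.0, size=(10_000, 1024))
      for n_chains in (1, 16, 256, 1024):
          wait = costs[:, :n_chains].max(axis=1)  # a synchronous step waits for the slowest chain
          print(f"{n_chains:4d} chains: mean cost {costs[:, :n_chains].mean():.2f}, "
                f"mean wait {wait.mean():.2f}")
      ```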

  2. This probably isn’t helpful, but you might want to check out the folks who came up with the idea of using Monte Carlo tree search for Go, and then their and others’ later work on how to make it actually work in the context of a game where tactics make a big difference. It _might_ have some similar concerns to what you are doing here. Maybe. (Or maybe the GPU usage in Go is only pattern matching and the tree search is all done on the CPU and I’m wasting your time….)

  3. Nice post, Charles. I have been routinely using R-hat for multiple chains. Haven’t tried your new stuff (can’t wait until I teach again so I can dig into it). But R-hat < 1.01 seems to be applicable only to parametric convergence problems. With nonparametric Bayes, you will almost never see an R-hat that small, so I still use R-hat < 1.1 for that. I don’t know if there is a theory here, but if others have some thoughts, that would be interesting to hear.

    • Hmmm… I’d be intrigued to work out in detail what makes the nonparametric case different and whether it would make sense to change the threshold. A big part of the preprint is providing a new interpretation of, and in some cases a principled choice for, the threshold we put on R-hat. I think the (scaled) tolerance on the violation of stationarity can more or less be the same across problems. On the other hand, the persistent variance changes, so we want to be mindful of that. Our last example in Section 5 suggests 1.01 is not a good idea, while something like 1.004 works better; in other cases, we show something larger than 1.01 is appropriate. Again, it’s not about being more or less conservative, but rather about adjusting for the persistent variance.

      • I’d be intrigued to work out in detail what makes the nonparametric case different

        Not an expert or even a practitioner of nonparametric Bayes, but I think nonparametric models often refer to overparameterized or infinite-parameter (not in Stan) models, and my guess would be that their likelihoods have multiple interior modes and long ridges of very low curvature associated with weak identifiability. Multimodality is typically suppressed by a good choice of priors, but it makes sense that their between-chain estimates of parameters would diverge, even where posterior predictive quantities match.

        • I think we have a winner. Yes, identifiability is in the eye of the beholder. These models are extremely flexible, which is why they are so handy. But the devil is in the details.
