SparseNUTS: Preconditioning hierarchical models in HMC with a sparse “Laplace approximation” at the marginal mode

This post is from Bob.

Cole Monnahan (who came up with the idea and did the heavy lifting), two of his colleagues from the Template Model Builder (TMB) project, Kasper Kristensen and James T. Thorson, and I just put a paper up on arXiv.

If you have any feedback on either, we’d love to hear it.

The method

It’s quite simple if you know how the components work. The tricky part is getting them all to play together nicely, efficiently, and robustly.

1. Find the marginal mode of a hierarchical model, of the kind computed by the linear mixed effects package (lme4) or Template Model Builder (TMB). Both packages are implemented in C++ under the hood and distributed with an R interface.

2. Center a second-order Taylor series approximation at the marginal mode, parameterized by a precision matrix (inverse covariance) rather than a covariance matrix. This is like a Laplace approximation, but it’s not centered at a global mode.

3. Take the resulting sparse precision matrix and use it to precondition the target density. This is equivalent to using the precision as a mass matrix in HMC, as Radford Neal showed in his introduction in the MCMC Handbook. This approach is necessitated by Stan’s lack of sparse mass matrix support.

4. Use TMB’s R interface and Andrew Johnson’s StanEstimators package to make the model available to Stan’s samplers and other tooling.
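To make steps 2 and 3 concrete, here is a minimal sketch of preconditioning with a sparse precision matrix. It is not the code from the paper or from TMB. It assumes a toy tridiagonal precision matrix Q (so the Cholesky factor L is bidiagonal and everything runs in linear time) and a hypothetical `logp_and_grad` callback standing in for the model. With Q = L Lᵀ, the change of variables x = mode + L⁻ᵀ z gives z roughly unit covariance whenever x is roughly normal with precision Q, and the gradient maps back as L⁻¹ times the gradient in x. All function names are made up for illustration.

```python
import math

def cholesky_tridiag(diag, off):
    """Cholesky factor of a tridiagonal SPD matrix Q = L L^T.
    L is lower bidiagonal, so it is stored as (diagonal, subdiagonal)."""
    n = len(diag)
    l_diag = [0.0] * n
    l_off = [0.0] * (n - 1)
    l_diag[0] = math.sqrt(diag[0])
    for i in range(n - 1):
        l_off[i] = off[i] / l_diag[i]
        l_diag[i + 1] = math.sqrt(diag[i + 1] - l_off[i] ** 2)
    return l_diag, l_off

def solve_lower(l_diag, l_off, b):
    """Solve L g = b by forward substitution."""
    g = [0.0] * len(b)
    g[0] = b[0] / l_diag[0]
    for i in range(1, len(b)):
        g[i] = (b[i] - l_off[i - 1] * g[i - 1]) / l_diag[i]
    return g

def solve_upper(l_diag, l_off, b):
    """Solve L^T y = b by back substitution."""
    n = len(b)
    y = [0.0] * n
    y[n - 1] = b[n - 1] / l_diag[n - 1]
    for i in range(n - 2, -1, -1):
        y[i] = (b[i] - l_off[i] * y[i + 1]) / l_diag[i]
    return y

def tridiag_matvec(diag, off, v):
    """Multiply the symmetric tridiagonal matrix by the vector v."""
    out = [d * vi for d, vi in zip(diag, v)]
    for i in range(len(v) - 1):
        out[i] += off[i] * v[i + 1]
        out[i + 1] += off[i] * v[i]
    return out

def preconditioned(logp_and_grad, mode, l_diag, l_off):
    """Wrap a target so a sampler works on z, where x = mode + L^{-T} z.
    The Jacobian is constant in z, so it is dropped from the log density."""
    def target_z(z):
        y = solve_upper(l_diag, l_off, z)
        x = [m + yi for m, yi in zip(mode, y)]
        lp, grad_x = logp_and_grad(x)
        grad_z = solve_lower(l_diag, l_off, grad_x)  # chain rule: L^{-1} grad_x
        return lp, grad_z
    return target_z

# Toy check: a Gaussian whose precision is exactly Q becomes standard normal in z.
n = 5
diag = [2.0] * n
off = [-0.9] * (n - 1)
mode = [1.0, -2.0, 0.5, 3.0, 0.0]

def gaussian_logp_and_grad(x):
    r = [xi - mi for xi, mi in zip(x, mode)]
    qr = tridiag_matvec(diag, off, r)
    lp = -0.5 * sum(ri * qi for ri, qi in zip(r, qr))
    return lp, [-q for q in qr]

l_diag, l_off = cholesky_tridiag(diag, off)
target_z = preconditioned(gaussian_logp_and_grad, mode, l_diag, l_off)
z = [0.3, -1.2, 0.7, 0.0, 2.0]
lp_z, grad_z = target_z(z)
```

In the toy check, the log density in z comes out as minus half the squared norm of z and the gradient as minus z, which is the well-conditioned standard normal a sampler with a unit mass matrix handles easily. A real model has a general sparse precision and needs a sparse Cholesky factorization, which this sketch does not attempt.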

There’s nothing Stan-specific about this technique. It could be rolled into PyMC and NumPyro with a lot of work if JAX’s experimental sparse library is up to the task of duplicating what TMB is doing.

Empirical evaluation

The paper evaluates the method on 15+ realistic models, some of which can be scaled in size. The results show that with this preconditioning, Stan can scale to hierarchical models with 10K+ parameters when they have sparse structure and high correlations.

The paper demonstrates how much better SparseNUTS is than Stan’s built-in diagonal or dense mass matrix estimators. The dense mass matrix won’t scale to 10K parameters: it would have 100M entries, which is nearly impossible to estimate or manage computationally in Stan. Stan is very slow to estimate mass matrices in these cases (WALNUTS, which we hope to roll out soon, should be better), but the real problem is that diagonal preconditioning is insufficient for hierarchical models.
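The 100M figure is just the square of the parameter count. The back-of-the-envelope arithmetic below works it out, assuming 8-byte doubles. The sparse comparison uses a made-up average of 10 nonzeros per row purely for illustration (that number is not from the paper) and ignores the index storage a real sparse format needs.

```python
n_params = 10_000
bytes_per_double = 8

# Dense mass matrix: one entry per pair of parameters.
dense_entries = n_params * n_params
dense_megabytes = dense_entries * bytes_per_double / 1e6

# Hypothetical sparsity level, chosen only for illustration.
assumed_nonzeros_per_row = 10
sparse_entries = n_params * assumed_nonzeros_per_row
sparse_megabytes = sparse_entries * bytes_per_double / 1e6

print(f"dense:  {dense_entries:,} entries, {dense_megabytes:,.1f} MB")
print(f"sparse: {sparse_entries:,} entries, {sparse_megabytes:,.1f} MB")
```

Storage is the easy part of the problem. Estimating that many entries from warmup draws is the harder one.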

Should you use it?

Yes! The form of hierarchical model addressed by SparseNUTS is widely used, and Stan by itself simply cannot fit such models. SparseNUTS also provides much more flexibility than INLA in writing custom likelihoods and priors.

Why not code this method in Stan?

We do not have the sparse Hessian tooling within Stan’s automatic differentiation library to implement this directly. Nor are we likely to roll out something for specific model classes, as we try to stick to black-box techniques in the core of Stan. This isn’t because we don’t like model-class specific methods. We are happy to implement them in special packages like brms. It’s just because one has to limit one’s scope somehow to keep a software project manageable with limited developer hours.

Some alternatives on the horizon

We are evaluating how far we’ll be able to scale Charles Margossian’s black-box approach to marginalization, which Steve Bronder and Charles have gotten into shape so it will be in the next Stan release. As INLA demonstrates, marginalization can be an efficient alternative to sampling the whole model if you can get the marginalization right.

We’re also evaluating how far we can get with Adrian Seyboldt’s Nutpie sampler’s approach to estimating low-rank plus diagonal mass matrices based on Fisher divergence, which could also potentially solve this problem as most of the structure in hierarchical models is low rank. Adrian and Eliot Carlsen (both of PyMC Labs) and I are writing that up now and we will release the paper soon. So far, Nutpie’s low-rank plus diagonal adaptation works incredibly well for some problems, but fails spectacularly on others for reasons we do not yet understand.
