Stacking for Non-mixing Bayesian Computations: The Curse and Blessing of Multimodal Posteriors

Yuling, Aki, and I write:

When working with multimodal Bayesian posterior distributions, Markov chain Monte Carlo (MCMC) algorithms can have difficulty moving between modes, and default variational or mode-based approximate inferences will understate posterior uncertainty. And, even if the most important modes can be found, it is difficult to evaluate their relative weights in the posterior.

Here we propose an alternative approach, using parallel runs of MCMC, variational, or mode-based inference to hit as many modes or separated regions as possible, and then combining these using importance sampling based Bayesian stacking, a scalable method for constructing a weighted average of distributions so as to maximize cross-validated prediction utility. The result from stacking is not necessarily equivalent, even asymptotically, to fully Bayesian inference, but it serves many of the same goals. Under misspecified models, stacking can give better predictive performance than full Bayesian inference, hence the multimodality can be considered a blessing rather than a curse.

We explore an example in which the stacked inference approximates the true data generating process despite the model being misspecified, an example of inconsistent inference, and examples of non-mixing samplers. We elaborate on the practical implementation in the context of latent Dirichlet allocation, Gaussian process regression, hierarchical models, variational inference in horseshoe regression, and neural networks.

Poor mixing of MCMC (and other algorithms such as variational inference) is inevitable, either because of fundamental discreteness (multimodality) in the posterior distribution, or diverse geometry arising from the desire to fit a model that represents a multitude of explanations for data, or just because you want to work fast and in parallel. What, then, to do with all these snips from the posterior distribution? It turns out that Bayesian model averaging (giving each snip a weight corresponding to its estimated mass in the posterior distribution) doesn’t always work so well, in part for the same mathematical reasons that Bayes factors don’t work in an M-open world. We find that cross-validated model averaging (Bayesian stacking) works better.
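To make the stacking step concrete, here is a minimal Python sketch of the weight optimization: given leave-one-out log predictive densities from each chain (or each "snip"), find simplex weights that maximize the held-out log score. The function name and the softmax parameterization are my own illustrative choices, not from the paper; in practice you would use the `loo` R package.

```python
import numpy as np
from scipy.optimize import minimize

def stacking_weights(lpd):
    """Stacking weights for K chains.

    lpd: (n, K) matrix of leave-one-out log predictive densities,
         lpd[i, k] = log p(y_i | y_{-i}, chain k).
    Returns simplex weights w maximizing sum_i log(sum_k w_k exp(lpd[i, k])).
    """
    n, K = lpd.shape

    def neg_score(a):
        # softmax parameterization keeps w on the simplex
        w = np.exp(a - a.max())
        w = w / w.sum()
        # negative held-out log score of the weighted mixture
        return -np.sum(np.log(np.exp(lpd) @ w + 1e-300))

    res = minimize(neg_score, np.zeros(K), method="Nelder-Mead")
    w = np.exp(res.x - res.x.max())
    return w / w.sum()
```

For example, if one chain predicts well on half the data points and another chain on the other half, the optimizer should split the weight roughly evenly between them, whereas Bayesian model averaging can collapse onto a single chain.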

Stacking of parallel chains can even be superefficient, outperforming full Bayes because it can catch model failures, in a way similar to the mixture model formulation of Kamary, Mengersen, Robert, and Rousseau. And you can flip the idea around and use the stacking average to check model fit.

We can implement stacking in Stan by computing the vector of the log posterior density in the generated quantities block, and then using Pareto smoothed importance sampling to compute leave-one-out cross validation without having to re-fit the model n times.
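For intuition about that last step, here is a stripped-down Python sketch of importance-sampling LOO from pointwise log-likelihoods, without the Pareto smoothing of the weights that makes PSIS stable in practice. This simplification and the function name are mine; real use should go through the `loo` package or another PSIS implementation.

```python
import numpy as np
from scipy.special import logsumexp

def is_loo(log_lik):
    """Plain importance-sampling LOO (no Pareto smoothing).

    log_lik: (S, n) matrix of pointwise log-likelihoods,
             log_lik[s, i] = log p(y_i | theta_s) over S posterior draws.
    Returns the LOO log predictive density for each of the n points.
    """
    # raw importance ratios for leaving out point i: 1 / p(y_i | theta_s)
    log_w = -log_lik
    # self-normalize the weights per data point
    log_w = log_w - logsumexp(log_w, axis=0, keepdims=True)
    # weighted average of the predictive density, on the log scale
    return logsumexp(log_w + log_lik, axis=0)
```

The raw ratios can have heavy right tails, which is exactly what the Pareto smoothing fixes, and the estimated Pareto shape k-hat is the diagnostic the commenters discuss below.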

I think this is a big idea, both for throwing at difficult problems in Bayesian computation and for facilitating a faster workflow using parallel simulation.

7 thoughts on “Stacking for Non-mixing Bayesian Computations: The Curse and Blessing of Multimodal Posteriors”

  1. > Pareto smoothed importance sampling to compute leave-one-out cross validation without having to re-fit the model n times

    When PSIS fails, you do need to re-fit the model right? Or is there justification to ignore these things in this context?

    I’d expect a chain stuck in a bad place in the model to have bad PSIS diagnostics.

    Something that strikes me looking at these plots is I’ve read through this stuff a bunch, and I never remember what the BMA weights are, and I don’t think I ever understood what pseudo-BMA weights were at all.

    Not to say the comparison isn’t worthwhile if there’s some historical context there, but the little example that starts “One of the benefits of stacking is that it manages well if there are many similar models” on this page: https://cran.r-project.org/web/packages/loo/vignettes/loo2-weights.html#example-oceanic-tool-complexity is the thing I always come back to in my head when I think BMA and it’s the reason I’m always happy to dismiss it without caring about it. -10 vs. 0 doesn’t mean as much.

    • Ben:

      For PSIS, the hope is that even when it has problems, it’s getting us in the right direction. When it fails, I guess a key question is how often it fails. For example if you have a million data points and the PSIS diagnostics show problems 10,000 times, what do you do about it? You’re not gonna re-fit the model 10,000 times. If it’s only messing up on 1% of data points, maybe it’s no big deal? Sometimes when PSIS has problems, Aki recommends K-fold cross validation with K = 5 or 10.

      To get BMA weights, you compute or approximate the posterior mass corresponding to each chain. The usual problem with BMA is strong dependence on aspects of the prior that have essentially no impact on the posterior distribution of the parameters in the model. One problem is that people with traditional Bayesian training often think that BMA is the right thing to do. In the setting of this new paper, we’re kind of off the hook on that one because we’re just approximating. One of the goals of the new paper is to understand how stacking can outperform approximate BMA.

  2. This is pretty close to what Christian Robert used to call population Monte Carlo. The resampling step imposes some fairly severe dimension limitations.

    • Steven, I think the starting point for population Monte Carlo methods is unbiasedness at every iteration, with the goal of computing the integral with respect to the exact posterior density. But the chain-stacking approach is not the same, even asymptotically, as the exact posterior.

  3. I’m far from completely understanding the manuscript but I had a question about stacking of the complete-pooling and boundary-avoiding models. Would it be possible to stack the complete pooling model, the centered model, and non-centered model all together? Then if the data comes from the regime where centered is more efficient than non-centered it would give higher weights to those chains and otherwise give higher weights to the non-centered chains? Or am I misunderstanding the use?

    Finally, is the first sum $\sum_{n=1}^{n_{val}}$ supposed to be $\sum_{i=1}^{n_{val}}$, using the index $i$ rather than $n$?

      • Hi Eric, yes, it should be $i$. It makes sense to stack the centered and non-centered parameterizations together, but in most experiments I have run I didn't find much improvement from the additional inclusions. It also takes extra effort in practice to code both the centered and non-centered parameterizations.
