Applications of (Bayesian) variational inference?

I’m curious about whether anyone’s using variational inference, and more specifically, using variational approximations to estimate posterior expectations for applied work. And if so, what kinds of reactions have you gotten from readers or reviewers?

I see a lot of talk in papers about how variational inference (VI) scales better than MCMC at the cost of only approximating the posterior. MCMC is often characterized as “approximate,” but it is technically asymptotically exact. Its approximation error comes from not having very many decimal places of accuracy (Monte Carlo error that shrinks as more draws are taken), not from bias, at least in cases where MCMC can sample the posterior.

But I don’t recall ever seeing anyone use VI for inference in applied statistics. In particular, I’m curious whether there are any Bayesian applications of VI, by which I mean applications where the variational approximation is used to estimate Bayesian posterior expectations in the usual way for an applied statistics problem of interest. That is, I’m wondering whether anyone uses a variational approximation q(theta | phi), with the variational parameters phi fixed at their fitted values as usual, to approximate a Bayesian posterior p(theta | y), and then uses it to estimate expectations as follows.

$$\mathbb{E}[f(\theta) \mid y] \approx \int f(\theta)\, q(\theta \mid \phi)\, \mathrm{d}\theta.$$

This integral can be estimated by Monte Carlo whenever it is possible to sample from q(theta | phi): draw theta^(1), …, theta^(M) from q(theta | phi) and average f over the draws.
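Here is a minimal sketch of that Monte Carlo estimate in Python. It assumes a mean-field normal q(theta | phi) with phi = (mu, sigma) already fitted. The parameter values and the helper names `sample_q` and `mc_expectation` are made up for illustration.

```python
import random

def sample_q(mu, sigma, rng):
    """Draw one theta from a mean-field normal q(theta | phi), phi = (mu, sigma)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]

def mc_expectation(f, mu, sigma, num_draws=10_000, seed=1234):
    """Monte Carlo estimate of E_q[f(theta)]: average f over draws theta ~ q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_draws):
        total += f(sample_q(mu, sigma, rng))
    return total / num_draws

# Hypothetical fitted variational parameters phi = (mu, sigma), for illustration only.
mu = [0.5, -1.0]
sigma = [0.2, 0.3]

est_mean = mc_expectation(lambda theta: theta[0], mu, sigma)     # exact under q: 0.5
est_sq = mc_expectation(lambda theta: theta[1] ** 2, mu, sigma)  # exact under q: 1.0 + 0.09
print(est_mean, est_sq)
```

The Monte Carlo error of such an estimate shrinks like 1/sqrt(num_draws). Any error from q(theta | phi) differing from p(theta | y) does not shrink at all; that is the bias that distinguishes VI from MCMC.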

I’m using our Pathfinder variational inference system (now in Stan) to initialize MCMC, but I wouldn’t trust inference based on Pathfinder because of its very restrictive variational family (i.e., multivariate normal with low-rank plus diagonal covariance). Similarly, most of the theoretical results I’ve been seeing around VI are for normal approximating families, particularly of the mean-field (diagonal covariance) variety. Mean-field approximations are easy to manipulate theoretically and computationally, but they seem to make poor candidates for predictive inference, where there is often substantial posterior correlation and non-Gaussianity.
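To make the Pathfinder family concrete, here is a sketch of drawing from a multivariate normal whose covariance is diagonal plus low rank, Sigma = D + B B^T. It uses plain Python lists for vectors and matrices. The sketch shows the family only. It does not show how Pathfinder chooses mu, D, and B, which it derives from an L-BFGS optimization path. All names and numbers here are hypothetical.

```python
import random

def sample_low_rank_plus_diag(mu, diag, factor, rng):
    """Draw theta ~ Normal(mu, diag(diag) + factor @ factor^T).

    mu: length-N mean; diag: length-N positive variances;
    factor: N x K matrix (a list of N rows of length K), usually with K << N.
    """
    n, k = len(mu), len(factor[0])
    z_diag = [rng.gauss(0.0, 1.0) for _ in range(n)]  # N independent standard normals
    z_low = [rng.gauss(0.0, 1.0) for _ in range(k)]   # K shared standard normals
    return [
        mu[i]
        + diag[i] ** 0.5 * z_diag[i]
        + sum(factor[i][j] * z_low[j] for j in range(k))
        for i in range(n)
    ]

def empirical_cov(draws, i, j):
    """Sample covariance between coordinates i and j of a list of draws."""
    m = len(draws)
    mean_i = sum(d[i] for d in draws) / m
    mean_j = sum(d[j] for d in draws) / m
    return sum((d[i] - mean_i) * (d[j] - mean_j) for d in draws) / (m - 1)

# Hypothetical 3-dimensional example with a single (K = 1) low-rank direction.
mu = [0.0, 0.0, 0.0]
diag = [0.1, 0.1, 0.1]
factor = [[1.0], [1.0], [0.0]]  # correlates coordinates 0 and 1 only

rng = random.Random(7)
draws = [sample_low_rank_plus_diag(mu, diag, factor, rng) for _ in range(20_000)]
# Implied covariance: Var(0) = 0.1 + 1.0 = 1.1, Cov(0, 1) = 1.0, Cov(0, 2) = 0.0.
print(empirical_cov(draws, 0, 0), empirical_cov(draws, 0, 1), empirical_cov(draws, 0, 2))
```

Each draw costs O(NK) for N dimensions and K low-rank directions, which is why this family scales. The price is that it captures only K directions of correlation and no non-Gaussianity at all.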

I know that there are ML applications to autoencoding that use variational inference, but I’m specifically asking about applied statistics applications that would be published in an applied journal, not a stats methodology or ML journal. I’ve seen some applications of point estimates from VI to “fit” latent Dirichlet allocation (LDA) models, but the ones I’ve seen don’t compute any expectations other than point estimates of parameters from a local mode among combinatorially many modes.

I’m curious about applications using ML techniques like normalizing flows as the variational family. I would expect those to be of more practical interest to applied statisticians than all the VI that has come before. I’ve seen cases from Abhinav Agrawal and Justin Domke where VI outperforms NUTS. They used a real-valued non-volume-preserving (realNVP) flow, 10 layers deep and about 20 neurons wide, touched up with importance sampling. Their summary paper is still under review and Abhinav’s thesis is being revised. But it requires a lot of compute, which isn’t cheap these days. The cases where realNVP outperforms NUTS include funnels, multimodal targets, bananas, and other models with varying curvature (like an IRT 2PL posterior). I suspect the cost of the equivalent of an NVIDIA H100 GPU will drop to the point where everyone will be able to use these methods within 10 years. It’s what I’m spending at least half my time on these days: JAX is fun, and ChatGPT can (help) translate Stan programs pretty much line for line into JAX.
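For readers who haven’t met realNVP, its building block is the affine coupling layer. Half of the coordinates pass through unchanged and parameterize an elementwise scale and shift of the other half. The layer is therefore trivially invertible, and its Jacobian is triangular, with log-determinant equal to the sum of the log-scales. The sketch below is a single such layer in plain Python with a standard normal base density. The functions `scale_net` and `shift_net` are hypothetical hand-picked stand-ins for the learned neural networks. Nothing here reproduces the 10-layer setup or the importance-sampling correction.

```python
import math

def scale_net(x1):
    """Hypothetical stand-in for a learned network s(x1) giving log-scales for x2."""
    return [0.5 * math.tanh(v) for v in x1]

def shift_net(x1):
    """Hypothetical stand-in for a learned network t(x1) giving shifts for x2."""
    return [0.3 * v + 1.0 for v in x1]

def coupling_forward(x):
    """Affine coupling layer: y1 = x1, y2 = x2 * exp(s(x1)) + t(x1).

    Returns (y, log|det dy/dx|); the Jacobian is triangular, so the
    log-determinant is just sum(s(x1)).
    """
    assert len(x) % 2 == 0, "this sketch assumes an even dimension"
    d = len(x) // 2
    x1, x2 = x[:d], x[d:]
    s, t = scale_net(x1), shift_net(x1)
    y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, s, t)]
    return x1 + y2, sum(s)

def coupling_inverse(y):
    """Exact inverse: y1 = x1, so recompute s and t from y1 and undo the affine map."""
    d = len(y) // 2
    y1, y2 = y[:d], y[d:]
    s, t = scale_net(y1), shift_net(y1)
    x2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(y2, s, t)]
    return y1 + x2

def log_density(y):
    """Change of variables with a standard normal base density:
    log q(y) = log N(x; 0, I) - log|det dy/dx|, where x = coupling_inverse(y)."""
    x = coupling_inverse(y)
    _, log_det = coupling_forward(x)
    log_base = sum(-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in x)
    return log_base - log_det

x = [0.2, -1.3, 0.7, 2.0]
y, log_det = coupling_forward(x)
print(y, log_det, coupling_inverse(y), log_density(y))
```

A full flow stacks many of these layers and permutes or alternates which half is transformed. It fits the network weights by maximizing the ELBO. A function like `log_density` is what makes that objective computable, and it also makes the importance weights p/q used to touch up the fit computable.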

22 thoughts on “Applications of (Bayesian) variational inference?”

  1. I don’t have much of what you’re specifically asking for, but I have an example application where we could really use a less computationally demanding way to do something with a posterior.

    A reader on the blog wrote to me for help with his project on migration within Germany. We formulated a model and have fit it to multiple age groups and now a decade or more of data.

    The data comprise 160,000 data points per year per age group, so the posterior is incredibly tightly constrained. The priors are all O(1) on parameters scaled to be O(1), so the posteriors should have a width of about 1/sqrt(160,000) = 1/400, or standard deviations around 0.0025. There is obviously some variation around that, and some dependency structure among the different parameters (on the order of 80 parameters, I think).

    However, it’s multimodal with lots of irrelevant modes, so it’s quite complicated to find the right mode and then sample in its vicinity in any useful way. It doesn’t help that gradients sometimes seem to work poorly or give NaN in some dimensions (this is in Turing, where there are multiple automatic differentiation backends we can try, etc.).

    What worked really well was sampling from a tempered version of the posterior at really high temperatures: starting at 4000 and then tempering down to 1000 or 400 or whatever it was. This got us into the right vicinity. Tempering a normal distribution multiplies its standard deviation by sqrt(temperature). At temperature 400 the standard deviation is therefore something like 0.0025 × 20 = 0.05 instead of 0.0025.
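    Here is a quick check of this arithmetic. It assumes, as the comment does, a roughly normal posterior with standard deviation near 1/sqrt(n). Dividing a normal log density, -(x - mu)^2 / (2 sigma^2), by a temperature T gives -(x - mu)^2 / (2 sigma^2 T). That is a normal with standard deviation sigma * sqrt(T).

```python
import math

n = 160_000                      # data points per year per age group
posterior_sd = 1 / math.sqrt(n)  # O(1) priors on O(1)-scaled parameters: sd ~ 1/sqrt(n)

def tempered_sd(sd, temperature):
    """sd of a normal whose log density has been divided by the temperature."""
    return sd * math.sqrt(temperature)

for temperature in (4000, 1000, 400):
    print(temperature, round(tempered_sd(posterior_sd, temperature), 4))
```

    At T = 4000 the standard deviation is about 0.16, and at T = 400 it is 0.05. The tempered targets are roughly 20 to 60 times wider than the untempered spike.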

    A lot of my projects have wound up like this: the model fits really well, there’s a lot of data, and so the posterior is virtually impossible to sample because it’s like sampling a spike. You might say, “Well, just do an optimization and take the optimal point; there’s no real uncertainty to care about.” But the correlation structure is sometimes important, and some parameters have quite wide marginal distributions because they are correlated with other parameters. Not only that, but sometimes just finding an optimal point is hard, because the optimizer gets stuck in one of the many local modes in 80 dimensions.

  2. Bob:

    Aki should be able to answer this one fairly quickly, I think. In the meantime I can give a Yes answer for the simple reason that the EM algorithm is a special case of variational Bayes (it’s the special case where the approximating distribution is a delta function), and the EM algorithm has been used zillions of times in real applied problems.

    • “the EM algorithm is a special case of variational Bayes (it’s the special case where the approximating distribution is a delta function)”

      I’m not sure this is right. Missing data/latent variables in EM correspond to parameters of interest in VB (call them z) and in both cases we have a distribution for z rather than a delta function.

      As I understand it the relationship is

      – both methods maximize the same objective (ELBO), a function of an approximating distribution q = q(z) and some parameters
      – parameters θ in EM are for the complete-data distribution p(y, z); parameters ϕ in VB are for the approximating posterior q(z)
      – EM updates θ and q alternately, and at each step q is implicitly updated to exactly match the true θ-conditional posterior p(z | y, θ); VB doesn’t prescribe any specific way to maximize the ELBO
      – in EM typically the interest is in the MLE for θ while in VB the interest is typically in the posterior p(z | y)
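      The shared objective can be written out explicitly. For any distribution q(z) and parameters θ, the log marginal likelihood decomposes into the ELBO plus a KL divergence. This is a standard identity, written here in the notation of the list above:

$$
\log p(y \mid \theta)
= \underbrace{\mathbb{E}_{q(z)}\!\left[\log p(y, z \mid \theta) - \log q(z)\right]}_{\mathrm{ELBO}(q,\,\theta)}
+ \mathrm{KL}\!\left[q(z) \,\|\, p(z \mid y, \theta)\right].
$$

      Since the KL term is nonnegative, the ELBO is a lower bound on log p(y | θ). EM’s E-step sets q(z) = p(z | y, θ), making the KL term zero, and its M-step maximizes the ELBO over θ. VB instead maximizes the ELBO over q within a restricted family q(z | ϕ), where the KL term generally stays positive.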

        • Here is the relevant section from p. 337 (end of 13.7).

          EM as a special case of variational Bayes

          Variational inference proceeds in J steps, each time updating one conditional distribution g_j, averaging over the other factors of g. EM has two steps (the E-step and the M-step), alternately estimating a parameter φ and averaging over the other parameters γ. EM can be seen as a special case of variational Bayes in which (a) the parameters are partitioned into two parts, φ and γ, (b) the approximating distribution for φ is required to be a point mass (thus, updating g(φ) is equivalent to updating the point estimate of φ), and (c) the approximating distribution for γ is unconstrained; thus g(γ) = p(γ|φ, y), conditional on the most recent update of φ.

          I find this confusing. Not the EM algorithm part: I get how (a)–(c) say one can view EM as computing an approximate unnormalized marginal posterior over a subset of the parameters γ by plugging in a point estimate for the remaining parameters φ. What confuses me is the connection to variational inference. For a start, this section cites a very specific VI algorithm, whereas the discussion elsewhere was more inclusive of black-box methods like ADVI.

          I think of VI as fundamentally minimizing a variational objective. That’s general enough that it can be connected to the MLE. Suppose we take the approximating family q to be single point masses. Then KL[q || p], the direction of KL divergence that standard VI minimizes, is minimized (in the limit as q concentrates to a point) by the q that puts its point mass at the maximum of p. If we think of p as the Bayesian posterior for some likelihood with a uniform (perhaps improper) prior, this gives us a roundabout way of deriving a maximum likelihood estimator by treating the parameters as random.
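          One way to make the point-mass argument precise is a standard limiting argument (a sketch, not taken from the book). Let q_{m,ε} = Normal(m, ε² I) in D dimensions, and split the reverse KL divergence into negative entropy minus the expected log target:

$$
\mathrm{KL}\!\left[q_{m,\epsilon} \,\|\, p\right]
= \int q_{m,\epsilon}(\theta)\log q_{m,\epsilon}(\theta)\,\mathrm{d}\theta
- \int q_{m,\epsilon}(\theta)\log p(\theta)\,\mathrm{d}\theta
= -\tfrac{D}{2}\log\!\left(2\pi e\,\epsilon^{2}\right) - \mathbb{E}_{q_{m,\epsilon}}\!\left[\log p(\theta)\right].
$$

          The entropy term does not depend on the location m. As ε → 0, the expectation tends to log p(m) for a continuous p, so the optimal location converges to the mode of p (assuming a unique, well-behaved maximum). The entropy term itself diverges, which is why the point mass only makes sense as a limit.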

          All of this VI-inclusiveness raises the question of whether there is any practical inference algorithm that isn’t an instance of VI by this kind of reasoning. Presumably you’d include the Laplace approximation, because it approximates a density with a second density (a normal) built around a point estimate. And you’d include MCMC, because it approximates a target density with a uniform distribution over a sample.
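          To make the Laplace case concrete, here is a minimal one-dimensional sketch. It finds the mode of the log density by Newton’s method, then uses the negative inverse second derivative at the mode as the variance of the approximating normal. Derivatives are taken by central finite differences only to keep the code self-contained. The function name `laplace_approx` and the Gamma-based target are hypothetical choices for illustration.

```python
import math

def laplace_approx(log_p, x0, h=1e-4, max_iter=50, tol=1e-10):
    """1-D Laplace approximation: returns (mode, sd) of Normal(mode, sd^2).

    Newton's method finds the mode of log_p; the sd is sqrt(-1 / log_p''(mode)).
    Derivatives use central finite differences (assumes log_p is smooth and
    concave near the mode, so the second derivative is negative there).
    """
    def d1(x):
        return (log_p(x + h) - log_p(x - h)) / (2 * h)

    def d2(x):
        return (log_p(x + h) - 2 * log_p(x) + log_p(x - h)) / (h * h)

    x = x0
    for _ in range(max_iter):
        step = d1(x) / d2(x)
        x -= step  # Newton update toward the mode
        if abs(step) < tol:
            break
    return x, math.sqrt(-1.0 / d2(x))

# Hypothetical target: unnormalized log density of x = log(theta) when
# theta ~ Gamma(shape a, rate b). The exact mode is log(a / b) and the
# second derivative there is -a, so the Laplace sd should be 1 / sqrt(a).
a, b = 5.0, 2.0
mode, sd = laplace_approx(lambda x: a * x - b * math.exp(x), x0=0.0)
print(mode, sd)
```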

          Maybe BDA4 will take a VI-forward approach to inference.
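          The excerpt’s recipe (a)–(c) can be sketched for a toy problem: EM for an equal-weight mixture of two unit-variance normals. Here φ is the pair of component means, held as a point estimate (a point-mass approximating distribution). γ is the vector of latent component labels, whose approximating distribution is the exact conditional posterior p(γ | φ, y). The data, starting values, and function names are hypothetical.

```python
import math

def logistic(z):
    """Numerically stable 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def em_two_normals(y, mu, iters=100):
    """EM for an equal-weight mixture of Normal(mu[0], 1) and Normal(mu[1], 1).

    phi = (mu[0], mu[1]) gets a point-mass approximation (a point estimate);
    gamma_i, the component label of y[i], gets its exact conditional posterior.
    """
    mu = list(mu)
    for _ in range(iters):
        # E-step: resp[i] = Pr(gamma_i = 1 | phi, y), from the log-density difference.
        resp = [logistic(0.5 * (yi - mu[0]) ** 2 - 0.5 * (yi - mu[1]) ** 2) for yi in y]
        # M-step: update the point estimate of phi given the expected labels.
        n1 = sum(resp)
        n0 = len(y) - n1
        mu[0] = sum((1.0 - r) * yi for r, yi in zip(resp, y)) / n0
        mu[1] = sum(r * yi for r, yi in zip(resp, y)) / n1
    return mu, resp

y = [-2.2, -1.9, -2.1, -1.8, -2.0, 1.9, 2.1, 2.0, 2.2, 1.8]  # hypothetical data
mu_hat, resp = em_two_normals(y, mu=[-1.0, 1.0])
print(mu_hat, [round(r, 4) for r in resp])
```

          Replacing the point mass for φ with a full distribution, updated in turn with the labels, turns this into the coordinate-ascent variational Bayes that the excerpt describes.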

        • Bob:

          The meaning of “variational inference” has changed over the years. When we were writing BDA3, the algorithm ADVI did not exist. Indeed, I didn’t really understand variational inference at all until writing that chapter of the book. I knew the book had to include variational inference and expectation propagation because they were important approximate Bayesian computation methods (as you note, “approximate” in the sense of not being simulation-consistent), and I didn’t want to do the lazy textbook B.S. approach of reading up on them and then spitting out the description from other textbooks, so I read up on them and then programmed them from scratch on an example.

          As with many statistical procedures, variational inference can be defined in different levels of generality, and it could be that in BDA4 we will talk about the more general definition. The same is true of EM! The EM algorithm can also be considered as fundamentally minimizing a variational objective, just with the restriction that the distribution of φ is required to be a point mass. Indeed, that’s kind of the original insight of variational inference, that the estimated distribution can be framed as a minimum of an integral, as with the calculus of variations.

          Considering it as a mathematical method (rather than just part of VI), I’ve only used the calculus of variations once, but it was very satisfying. This was with my project with Meng on path sampling, where we wanted to figure out the optimal path between two distributions (see Figure 2 in our paper from 1998, although I think we did this particular derivation in 1994 during my stay at the University of Chicago). We knew that there was this method called “calculus of variations” that you could use to optimize over an entire continuous function. So we went to a textbook (I think we used Courant and Hilbert’s Methods of Mathematical Physics), looked up the calculus of variations, and successfully applied it to our problem! Very satisfying.

    • Thanks, Kyurae and Leon: that was exactly the kind of thing I was looking for. Now I’m a bit embarrassed for not having searched harder before posting, because the paper Leon cited has “Variational Inference” right in the title!

      Looking through the Raj et al. paper, I see it makes a mean-field assumption, but it discusses integrating over parameters to get the predictive inference, which is nice. I’ve never seen anyone do that before. I can’t quite parse Figure 1, which seems to be evaluating it against a maximum likelihood alternative.

  3. I am not the author, but I recently came across this very interesting economics paper, published in a very good journal, that uses variational inference.

    https://direct.mit.edu/rest/article-abstract/doi/10.1162/rest_a_01449/120874/Estimating-Nursing-Home-Quality-with-Selection

    abstract:

    We use variational inference (VI), a technique from the machine learning literature, to estimate a mortality-based Bayesian model of nursing home quality accounting for selection. We demonstrate how one can use VI to quickly and flexibly estimate a high-dimensional economic model with large datasets. Using our facility quality estimates, we examine the correlates of quality and find that public report cards have near-zero correlation. We then show that in contrast to prior literature, higher quality nursing homes fared better during the pandemic: a one standard deviation increase in quality corresponds to 2.5% fewer Covid-19 cases.

    • Sam:

      That paper says, “Given the size of our data, conventional approaches, such as an MCMC sampler, are computationally infeasible.” That could be, but I’d like to see what would happen if they were to just try to fit their model in Stan. People have used Stan to fit some pretty big problems!

      • Andrew:

        See the excerpts below from the paper, where they give more detail on computational feasibility.

        Short version: they simulated data sets less than 10% the size of the biggest market in their real data and used GPU acceleration for both methods. In those simulations, VI was at least 10 times faster than NUTS. Their data reside on a secure server where hardware accelerators are not available, so they conclude that NUTS would be prohibitively slow on their data.

        “We show in a simulation exercise that this procedure recovers the true values and offers a substantial speedup against conventional MCMC-based approaches.”

        “We do not report runtimes, but in all of our experiments we observed at least a 10-fold improvement using VI compared with NUTS, and often much larger. In our simulations both algorithms (VI and NUTS) benefited from GPU acceleration (using a single consumer-grade GPU, i.e. Nvidia Geforce GTX 1660 Ti) resulting in approximately 12-fold speedup compared with the deployment on a 16-core CPU (Intel i7-1077H). Unfortunately, hardware accelerators are not available at the secure server where our data reside. Further, despite using the GPU, to keep runtime under two hours, we had to restrict the size of our simulated data set to less than 10% of the size of the biggest market in the real data. We therefore conclude that in our application even the state-of-the-art MCMC approach is prohibitively slow, while VI converges in a reasonable time.”

    I’m an experimental chemist applying statistical modelling (to the best of my ability!) to reasoning about noisy/contaminated analytical data, e.g. mass spectra, using PyMC and more recently NumPyro, and my experience mirrors Daniel’s. With enough raw data there comes a point where the tight posterior constraints make MCMC sampling grind to a halt. I find myself reaching for VI in each case, and so far it has rarely disappointed. I might get the odd NaN with a Dirichlet parameter, but constraining 0.0001 < alpha < 1000 takes care of those. As in Daniel’s case, the outcome tends to be more useful than a point estimate, and sampling from the approximate posterior reveals interesting relationships between latent variables. Multimodal posteriors are interesting in the abstract. For us, though, fully identifying each unique mode and its significance, plus dealing with label switching (which happens in most cases of interest to us), hasn’t been worth the effort.

    I haven't experimented with the MCMC temperature idea, which sounds promising.
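    The comment doesn’t say how the 0.0001 < alpha < 1000 bounds were imposed. One common way to enforce such bounds in a gradient-based fit is to optimize an unconstrained real number u and map it through a scaled logistic function on the log scale. Every reachable value then lands inside the interval, up to floating-point rounding at the extremes. This is shown as an assumption, not the commenter’s actual setup, and the name `to_bounded` is hypothetical.

```python
import math

LOWER, UPPER = 1e-4, 1e3  # bounds like those in the comment, for a Dirichlet alpha

def to_bounded(u, lower=LOWER, upper=UPPER):
    """Map an unconstrained real u into (lower, upper) via a logistic on the log scale.

    Assumed approach for illustration. If u is given a density (for MCMC, or for VI on
    the unconstrained scale), the log-Jacobian of this map must be added to the target.
    """
    # Numerically stable logistic function w = 1 / (1 + exp(-u)).
    if u >= 0:
        w = 1.0 / (1.0 + math.exp(-u))
    else:
        e = math.exp(u)
        w = e / (1.0 + e)
    log_lo, log_hi = math.log(lower), math.log(upper)
    return math.exp(log_lo + (log_hi - log_lo) * w)

for u in (-50.0, -5.0, 0.0, 5.0, 50.0):
    print(u, to_bounded(u))
```

    Working on the log scale spreads the unconstrained range evenly across the seven orders of magnitude between the bounds. A plain logistic on the raw scale would squeeze almost everything toward the upper end.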

    • We have been using Julia + Turing for our model, and I found it pretty trivial to implement tempering. I wrapped the Turing model in a “TemperedModel” wrapper, then used the LogDensityProblems interface to evaluate the Turing log density and divide it by the temperature.

      In Stan, I don’t know how you’d do tempering.
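      Here is a language-agnostic version of the wrapper idea, sketched in Python rather than Julia. The class `TemperedLogDensity` and the toy target are hypothetical; this is not the Turing or LogDensityProblems API. The wrapper divides any log density, and optionally its gradient, by a temperature. Any sampler that needs only log-density (and gradient) evaluations can then be pointed at the tempered target, with the temperature lowered in stages.

```python
class TemperedLogDensity:
    """Hypothetical wrapper: divides a log density (and its gradient) by a temperature."""

    def __init__(self, log_density, temperature, grad_log_density=None):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.log_density = log_density
        self.temperature = temperature
        self.grad_log_density = grad_log_density

    def __call__(self, theta):
        """Tempered log density: log p(theta) / T."""
        return self.log_density(theta) / self.temperature

    def grad(self, theta):
        """Gradient of the tempered log density: grad log p(theta) / T."""
        if self.grad_log_density is None:
            raise NotImplementedError("no gradient function was supplied")
        return [g / self.temperature for g in self.grad_log_density(theta)]


def log_p(theta):
    """Hypothetical very tight target: independent normals with sd 0.0025."""
    return sum(-0.5 * (t / 0.0025) ** 2 for t in theta)


def grad_log_p(theta):
    return [-t / 0.0025 ** 2 for t in theta]


# A staged schedule in the spirit of the comment; T = 1 is the original posterior.
for temperature in (4000, 1000, 400, 1):
    target = TemperedLogDensity(log_p, temperature, grad_log_p)
    # ... a sampler would be run on `target` here, started from the previous stage's draws
    print(temperature, target([0.01, -0.01]), target.grad([0.01, -0.01]))
```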

  5. Here is another line of work using VI for biological problems from my past lab. I’m not the author, but I participated in discussions for these. VI is especially useful for such mixture models because it breaks the labelling symmetry that MCMC algorithms exhibit in this setting. If one cares about the cluster assignments, then the standard advice of marginalising over clusters for MCMC is inadequate.

    Kapourani, CA., Sanguinetti, G. Melissa: Bayesian clustering and imputation of single-cell methylomes. Genome Biol 20, 61 (2019). https://doi.org/10.1186/s13059-019-1665-8

    Kapourani, CA., Argelaguet, R., Sanguinetti, G. et al. scMET: Bayesian modeling of DNA methylation heterogeneity at single-cell resolution. Genome Biol 22, 114 (2021). https://doi.org/10.1186/s13059-021-02329-8

    And the underlying BPRMeth R package for clustered (mixture) generalised regression using VI:

    Kapourani CA, Sanguinetti G (2018). “BPRMeth: a flexible Bioconductor package for modelling methylation profiles.” Bioinformatics, 34, 2485-2486. doi:10.1093/bioinformatics/bty129.

    ~~MM
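    Here is a rough illustration of how VI yields a single labelling with explicit cluster-assignment probabilities. It runs coordinate-ascent mean-field VI (CAVI) on the standard textbook example: a one-dimensional mixture of K normals with unit variance, equal weights, and Normal(0, prior_sd²) priors on the means. This is a toy sketch, not the model or code in the papers and package cited above. The data and the function name `cavi_mixture` are hypothetical.

```python
import math

def cavi_mixture(x, m_init, prior_sd=10.0, iters=100):
    """Mean-field CAVI for a K-component, unit-variance, equal-weight normal mixture.

    Model: mu_k ~ Normal(0, prior_sd^2); c_i ~ Categorical(1/K); x_i ~ Normal(mu_{c_i}, 1).
    Variational family: q(mu_k) = Normal(m[k], s2[k]), q(c_i) = Categorical(q_c[i]).
    """
    k = len(m_init)
    m = list(m_init)
    s2 = [1.0] * k
    q_c = []
    for _ in range(iters):
        # Update q(c_i): q_c[i][j] is proportional to exp(E[mu_j] x_i - E[mu_j^2] / 2).
        q_c = []
        for xi in x:
            logits = [m[j] * xi - 0.5 * (s2[j] + m[j] ** 2) for j in range(k)]
            top = max(logits)  # subtract the max before exponentiating, for stability
            w = [math.exp(lg - top) for lg in logits]
            total = sum(w)
            q_c.append([wj / total for wj in w])
        # Update q(mu_j): conjugate normal update given the soft assignments.
        for j in range(k):
            n_j = sum(row[j] for row in q_c)
            precision = 1.0 / prior_sd ** 2 + n_j
            m[j] = sum(row[j] * xi for row, xi in zip(q_c, x)) / precision
            s2[j] = 1.0 / precision
    return m, s2, q_c

x = [-2.2, -1.9, -2.1, -1.8, -2.0, 1.9, 2.1, 2.0, 2.2, 1.8]  # hypothetical data
m, s2, q_c = cavi_mixture(x, m_init=[-1.0, 1.0])
print([round(v, 3) for v in m], [round(v, 3) for v in s2])
print([round(row[1], 3) for row in q_c])  # q(c_i = component 1) for each point
```

    The initialization settles on one of the K! equivalent labellings, and CAVI stays there; this is the symmetry breaking the comment describes. A well-mixing MCMC sampler would instead visit every labelling, so its per-component summaries would need relabelling. Mean-field VI is also known to understate posterior variance, so read the s2 values with that in mind.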

  6. Hey, I have the flip side of the question: does anyone use Stan for larger data sets, something like hundreds of thousands of observations and thousands of columns? Throw in a few levels of hierarchy and I usually give up and use tree-based methods.

    • For 100,000 observations, yes, I use it for that. But with only something like 2–3 predictors it would probably take hours or days, despite all the tricks like rescaling variables, tighter priors, and non-centered parameterizations.

      But thousands of columns? No way. Does that mean thousands of predictors, or am I missing something here?

      • Yes, thousands of predictors/explanatory variables. At the past few companies I’ve worked for, there were thousands to hundreds of thousands of individuals with lots of different variables over time. Once you include aggregates over different horizons and lags, the number of explanatory variables can grow significantly.

  7. Here is an example of when I use variational Bayes in an applied setting.

    I conduct online and field experiments with one or more treatment groups and a control. Often my key outcome is a series of Likert-scale items, usually around 6–7 questions. My goal is to use a hierarchical model to estimate, with some potential shrinkage, the treatment effect for each of the Likert questions: something like (treatment | question) in R formula syntax. So I have treatment slopes and intercepts that both vary by question.

    If I run this in rstanarm or brms, it will likely take many hours, sometimes overnight, and when it is finished I get warnings, usually about divergent transitions. This necessitates running the model again, often many times, over the span of 1–3 days. I’ve even spun up the models on a powerful server, but it doesn’t really make much of a difference. I’ve tried running some of this stuff in parallel, especially when creating predictions with posterior_epred, with little success. Similar to what others have mentioned, I convert all continuous variables to z-scores, set adapt_delta to 0.999, tighten some of the priors, etc.

    Using a frequentist approach (i.e., lme4) increases speed, but it is even more likely to throw an error, probably because all the varying intercepts get pulled towards 0.

    In these cases, a new-ish package has been a lifesaver: vglmer, which uses a few algorithms for variational Bayes. It runs in less than 10 minutes and rarely has any issues/errors. It doesn’t give the full posterior distribution, however, only a variational approximation to it. But if I’m working on multiple experiments under comfortable but fixed timelines, this is the only feasible approach.
