Rao-Blackwellization and discrete parameters in Stan

I’m reading a really dense and beautifully written survey of Monte Carlo gradient estimation for machine learning by Shakir Mohamed, Mihaela Rosca, Michael Figurnov, and Andriy Mnih. There are great explanations of everything including variance reduction techniques like coupling, control variates, and Rao-Blackwellization. The latter’s the topic of today’s post, as it relates directly to current Stan practices.

Expectations of interest

In Bayesian inference, parameter estimates, event probabilities, and predictions can all be formulated as expectations of functions of parameters conditioned on observed data. In symbols, that’s

$latex \displaystyle \mathbb{E}[f(\Theta) \mid Y = y] = \int f(\theta) \cdot p(\theta \mid y) \, \textrm{d}\theta$

for a model with parameter vector $latex \Theta$ and data $latex Y = y.$ In this post and most writing about probability theory, random variables are capitalized and bound variables are not.
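
For example, the posterior probability of an event is just the posterior expectation of its indicator. If the event is that a component $latex \Theta_k$ of the parameter vector is positive, then

$latex \displaystyle \mbox{Pr}[\Theta_k > 0 \mid Y = y] = \mathbb{E}[\textrm{I}[\Theta_k > 0] \mid Y = y] = \int \textrm{I}[\theta_k > 0] \cdot p(\theta \mid y) \, \textrm{d}\theta.$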

Partitioning variables

Suppose we have two random variables $latex A, B$ and want to compute an expectation $latex \mathbb{E}[f(A, B)].$ In the Bayesian setting, this means splitting our parameters $latex \Theta = (A, B)$ into two groups and suppressing the conditioning on $latex Y = y$ in the notation.

Full sampling-based estimate of expectations

There are two unbiased approaches to computing the expectation $latex \mathbb{E}[f(A, B)]$ using sampling. The first is traditional, with all of the random variables in the expectation being sampled.

Draw $latex (a^{(m)}, b^{(m)}) \sim p_{A,B}(a, b)$ for $latex m \in 1:M$ and estimate the expectation as

$latex \displaystyle\mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^M f(a^{(m)}, b^{(m)}).$

Marginalized sampling-based estimate of expectations

The so-called Rao-Blackwellized estimator of the expectation involves marginalizing $latex A$ out of $latex p_{A,B}(a, b)$ and sampling only $latex b^{(m)} \sim p_{B}(b)$ for $latex m \in 1:M$. The expectation is then estimated as

$latex \displaystyle \mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^M \mathbb{E}[f(A, b^{(m)})].$

For this estimator to be practical, the nested expectation must be efficiently computable,

$latex \displaystyle \mathbb{E}[f(A, b^{(m)})] = \int f(a, b^{(m)}) \cdot p(a \mid b^{(m)}) \, \textrm{d}a.$
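
The marginalized estimator remains unbiased because of the law of total expectation (iterated expectation),

$latex \displaystyle \mathbb{E}[f(A, B)] = \mathbb{E}\big[\, \mathbb{E}[f(A, B) \mid B] \,\big],$

so averaging the conditional expectations $latex \mathbb{E}[f(A, b^{(m)})]$ over draws $latex b^{(m)} \sim p_B(b)$ recovers the full expectation.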

The Rao-Blackwell theorem

The Rao-Blackwell theorem states that the marginalized approach has variance less than or equal to that of the direct approach. In practice, this difference can be enormous. The size of the gain depends on how hard it would be to estimate $latex \mathbb{E}[f(A, b^{(m)})]$ by sampling $latex a^{(n)} \sim p_{A \mid B}(a \mid b^{(m)}),$

$latex \displaystyle \mathbb{E}[f(A, b^{(m)})] \approx \frac{1}{N} \sum_{n = 1}^N f(a^{(n)}, b^{(m)}).$
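
More precisely, the size of the reduction is given by the law of total variance,

$latex \displaystyle \textrm{Var}[f(A, B)] = \mathbb{E}\big[\textrm{Var}[f(A, B) \mid B]\big] + \textrm{Var}\big[\mathbb{E}[f(A, B) \mid B]\big].$

The per-draw variance of the direct estimator is the left-hand side, while the per-draw variance of the Rao-Blackwellized estimator is only the second term on the right. The first term, the expected conditional variance, is exactly what marginalization saves.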

Discrete variables in Stan

Stan does not have a sampler for discrete variables. Instead, Rao-Blackwellized estimators must be used, which essentially means marginalizing out the discrete parameters. Thus if $latex A$ is the vector of discrete parameters in a model, $latex B$ the vector of continuous parameters, and $latex y$ the vector of observed data, then the model posterior is $latex p_{A, B \mid Y}(a, b \mid y).$

With a sampler that can efficiently make Gibbs draws (e.g., BUGS or PyMC3), it is tempting to try to compute posterior expectations by sampling,

$latex \displaystyle \mathbb{E}[f(A, B) \mid Y = y] \approx \frac{1}{M} \sum_{m=1}^M f(a^{(m)}, b^{(m)})$ where $latex (a^{(m)}, b^{(m)}) \sim p_{A,B \mid Y}(a, b \mid y).$

This is almost always a bad idea if it is possible to efficiently calculate the inner Rao-Blackwellization expectation, $latex \mathbb{E}[f(A, b^{(m)})].$ With discrete variables, the formula is just

$latex \displaystyle \mathbb{E}[f(A, b^{(m)})] = \sum_{a \in A} p(a \mid b^{(m)}) \cdot f(a, b^{(m)}).$

The summation can usually be done efficiently in models like mixtures, where the discrete variables are tied to individual data points, or in state-space models like HMMs, where the discrete parameters can be marginalized using the forward algorithm. It is not so easy with missing count data or variable-selection problems, where the posterior combinatorics are intractable.
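
To make that concrete, here’s a minimal sketch of a marginalized two-component normal mixture in Stan (not the user’s guide version; the names lambda, mu, and sigma are just illustrative). The log_mix statement sums over the two possible values of each observation’s component indicator, which is exactly the summation above.

    data {
      int<lower=0> N;
      vector[N] y;
    }
    parameters {
      real<lower=0, upper=1> lambda;   // mixing proportion
      ordered[2] mu;                   // component locations
      vector<lower=0>[2] sigma;        // component scales
    }
    model {
      lambda ~ beta(2, 2);
      mu ~ normal(0, 5);
      sigma ~ lognormal(0, 1);
      // marginalize the discrete component indicator for each observation:
      // log(lambda * normal(y[n] | mu[1], sigma[1])
      //     + (1 - lambda) * normal(y[n] | mu[2], sigma[2]))
      for (n in 1:N)
        target += log_mix(lambda,
                          normal_lpdf(y[n] | mu[1], sigma[1]),
                          normal_lpdf(y[n] | mu[2], sigma[2]));
    }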

Gains from marginalizing discrete parameters

The gains to be had from marginalizing discrete parameters are enormous. This is true even of models coded in BUGS or PyMC3. Cole Monnahan, James Thorson, and Trevor Branch wrote a nice survey of the advantages of marginalization for some ecology models that compares marginalized HMC in Stan to JAGS with discrete sampling and to JAGS with marginalization. The takeaway here isn’t that HMC is faster than JAGS, but that JAGS with marginalization is a lot faster than JAGS without.

The other place to see the effects of marginalization is in the Stan User’s Guide chapter on latent discrete parameters. The first example, a change-point model, shows how much more efficient the marginalization is by comparing it directly with estimates generated from exact sampling of the discrete parameters conditioned on the continuous ones. This is particularly true of the tail statistics, which can’t be estimated at all by sampling the discrete parameters because too many draws would be required. I had the same experience coding the Dawid-Skene model of noisy data coding, which was my gateway to Bayesian inference. I had coded it with discrete sampling in BUGS, but BUGS took forever (24 hours versus 20 minutes for Stan on my real data) and kept crashing on trivial examples during my tutorials.

Marginalization calculations can be found in the MLE literature

The other place marginalization of discrete parameters comes up is in maximum likelihood estimation. For example, Dawid and Skene’s original approach to their coding model used the expectation maximization (EM) algorithm for maximum marginal likelihood estimation. The E-step does the marginalization, and it’s exactly the same marginalization required in Stan for discrete parameters. You can find the HMM marginalization in the literature on computing maximum likelihood estimates for HMMs (in computer science, electrical engineering, etc.) and in the ecology literature for the Cormack-Jolly-Seber model. And they’re in the Stan User’s Guide.

Nothing’s lost, really

[edit: added last section explaining how to deal with posterior inference for the discrete parameters]

It’s convenient to do posterior inference with samples. Even with a Rao-Blackwellized estimator, it’s possible to sample $latex a^{(m)} \sim p(a \mid b^{(m)})$ in the generated quantities block of a Stan program and then proceed from there with full posterior draws $latex (a^{(m)}, b^{(m)})$ of both the discrete and continuous parameters.
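
Continuing the two-component mixture sketch from above (again, just a sketch), the indicators can be drawn once per iteration in generated quantities; categorical_logit_rng normalizes a vector of unnormalized log probabilities and draws from the resulting categorical distribution.

    generated quantities {
      array[N] int<lower=1, upper=2> z;   // sampled component indicators
      for (n in 1:N) {
        vector[2] lp;                     // unnormalized log Pr[z[n] = k | lambda, mu, sigma, y[n]]
        lp[1] = log(lambda) + normal_lpdf(y[n] | mu[1], sigma[1]);
        lp[2] = log1m(lambda) + normal_lpdf(y[n] | mu[2], sigma[2]);
        z[n] = categorical_logit_rng(lp);
      }
    }

Each posterior draw then carries values of z alongside the continuous parameters, and downstream inference can proceed as if the discrete parameters had been sampled directly.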

As tempting as that is because of its simplicity, the marginalization is worth the coding effort: the gain in efficiency from working in expectation with the Rao-Blackwellized estimator is enormous for discrete parameters. It can often take problems from computationally infeasible to straightforward.

For example, to estimate the posterior distribution of a discrete parameter, we need the expectation

$latex \displaystyle \mbox{Pr}[A_n = k] = \mathbb{E}[\textrm{I}[A_n = k]]$

for all values $latex k$ that $latex A_n$ might take. This is a trivial computation with MCMC (assuming the number of values is not too large) and is carried out in Stan by defining an indicator variable and setting it. In contrast, estimating such a quantity by sampling $latex a^{(m)} \sim p(a \mid b^{(m)})$ is very inefficient, and increasingly so as the probability $latex \mbox{Pr}[A_n = k]$ being estimated gets smaller.
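
For the mixture sketch above, the indicator’s conditional expectation given the continuous parameters is available in closed form, so the generated quantities block can compute that probability directly (shown here as a standalone block for clarity; it could sit alongside the sampled z above, and it’s an illustrative sketch rather than the user’s guide code). The posterior mean of prob_z1[n] across draws is then the Rao-Blackwellized estimate of $latex \mbox{Pr}[z_n = 1 \mid y].$

    generated quantities {
      vector[N] prob_z1;   // Pr[z[n] = 1 | lambda, mu, sigma, y[n]]
      for (n in 1:N) {
        real lp1 = log(lambda) + normal_lpdf(y[n] | mu[1], sigma[1]);
        real lp2 = log1m(lambda) + normal_lpdf(y[n] | mu[2], sigma[2]);
        prob_z1[n] = exp(lp1 - log_sum_exp(lp1, lp2));
      }
    }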

Examples of both forms of inference are shown in the user’s guide chapter on latent discrete parameters.

12 thoughts on “Rao-Blackwellization and discrete parameters in Stan”

  1. You write:

    “Cole Monnahan, James Thorson, and Trevor Branch wrote a nice survey of the advantages for some ecology models that compares marginalized HMC with Stan to JAGS with discrete sampling and JAGS with marginalization.”

    Link or Reference? Otherwise I guess I could write one of the authors. Fellow NOAA people.

  2. Marginalization can be infeasible if the number of categories is large or infinite, whereas direct sampling (not in Stan, but say Gibbs sampling) of the discrete variables can run into the curse of dimensionality. It also seems possible to combine both techniques (marginalize out a few categories and do importance sampling on the rest). See https://arxiv.org/pdf/1810.04777.pdf for a similar application.

    • Right. I mentioned in the post (sorry it was so long) that there were cases that were intractable to marginalize. The place that comes up in practice is usually with variable selection in regression (non-zero probability of true zero value) or with something like Ising models. Or in missing count data or even with large-K categorical models. And while we can marginalize the discrete parameters out of models like general K-means or LDA, we still can’t sample the full posterior (though Andrew says you two are working on a stacking approach to that).

  3. > In Bayesian inference, parameter estimates and event probabilities and predictions can all be formulated as expectations of functions of parameters conditioned on observed data

    What about eg the median?

    • Over a year later: quantiles like the median can be defined as the inverse of expectation of an indicator function. More specifically, the value v of the q-th quantile is defined by the relation

      $latex \displaystyle q = \int \textrm{I}[x < v] \, p(x) \, \textrm{d}x,$ where $latex p(x)$ is the density.

  4. I wanted to reread this post and wanted to call out that the LaTeX on it, and likely many other similar posts, is currently broken. Maybe the new wordpress doesn’t support markdown, or otherwise maybe the $ were accidentally escaped in the migration?
