I’m reading a really dense and beautifully written survey of Monte Carlo gradient estimation for machine learning by Shakir Mohamed, Mihaela Rosca, Michael Figurnov, and Andriy Mnih. There are great explanations of everything including variance reduction techniques like coupling, control variates, and Rao-Blackwellization. The latter’s the topic of today’s post, as it relates directly to current Stan practices.
Expectations of interest
In Bayesian inference, parameter estimates, event probabilities, and predictions can all be formulated as expectations of functions of parameters conditioned on observed data. In symbols, that's

$$\mathbb{E}[f(\Theta) \mid Y = y] = \int f(\theta) \cdot p(\theta \mid y) \, \textrm{d}\theta$$

for a model with parameter vector $\Theta$ and data $Y = y$. In this post, as in most writing about probability theory, random variables are capitalized and bound variables are not.
Partitioning variables
Suppose we have two random variables $A$ and $B$ and want to compute an expectation

$$\mathbb{E}[f(A, B)].$$

In the Bayesian setting, this means splitting our parameters $\Theta = (A, B)$ into two groups and suppressing the conditioning on the data $Y = y$ in the notation.
Full sampling-based estimate of expectations
There are two unbiased approaches to computing the expectation $\mathbb{E}[f(A, B)]$ using sampling. The first one is traditional, with all random variables in the expectation being sampled. Draw

$$(a^{(m)}, b^{(m)}) \sim p(a, b)$$

for $m \in 1{:}M$ and estimate the expectation as

$$\mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^M f(a^{(m)}, b^{(m)}).$$
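To make the traditional estimator concrete, here's a small Python sketch. The joint distribution and test function are made up for illustration (not from this post): $B$ is standard normal, $A \mid B = b$ is Bernoulli with chance $\textrm{logit}^{-1}(b)$, and $f(a, b) = a \cdot b$.

```python
import numpy as np

rng = np.random.default_rng(42)

def inv_logit(x):
    return 1.0 / (1.0 + np.exp(-x))

def f(a, b):
    # arbitrary test function chosen for illustration
    return a * b

M = 10_000
b = rng.normal(size=M)              # b^(m) ~ p(b)
a = rng.random(M) < inv_logit(b)    # a^(m) ~ p(a | b^(m)), so (a^(m), b^(m)) ~ p(a, b)

# (1/M) * sum_m f(a^(m), b^(m)) -- every random variable is sampled
estimate = np.mean(f(a, b))
```

Each draw of the pair $(a^{(m)}, b^{(m)})$ contributes one term; nothing is computed in closed form.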
Marginalized sampling-based estimate of expectations
The so-called Rao-Blackwellized estimator of the expectation involves marginalizing out $A$ and sampling

$$b^{(m)} \sim p(b)$$

for $m \in 1{:}M$. The expectation is then estimated as

$$\mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^M \mathbb{E}[f(A, b^{(m)})].$$
For this estimator to be practical, the nested expectation must be efficiently computable,

$$\mathbb{E}[f(A, b^{(m)})] = \int f(a, b^{(m)}) \cdot p(a \mid b^{(m)}) \, \textrm{d}a.$$
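Here's the same toy model as above (my invention, not from the post), now Rao-Blackwellized. Because $A$ is binary given $b$, the nested expectation is a two-term sum computed in closed form, and $A$ is never sampled:

```python
import numpy as np

rng = np.random.default_rng(42)

def inv_logit(x):
    return 1.0 / (1.0 + np.exp(-x))

def f(a, b):
    # arbitrary test function chosen for illustration
    return a * b

M = 10_000
b = rng.normal(size=M)   # b^(m) ~ p(b); A is marginalized out, never sampled

# E[f(A, b)] = p(A=1 | b) * f(1, b) + p(A=0 | b) * f(0, b), in closed form
inner = inv_logit(b) * f(1, b) + (1.0 - inv_logit(b)) * f(0, b)

rb_estimate = np.mean(inner)
```

The only Monte Carlo error left is from the draws of $b$; the conditional uncertainty in $A$ has been integrated away exactly.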
The Rao-Blackwell theorem
The Rao-Blackwell theorem states that the marginalization approach has variance less than or equal to that of the direct approach. In practice, the difference can be enormous. The size of the gain depends on how hard it would be to estimate the inner expectation $\mathbb{E}[f(A, b^{(m)})]$ by sampling $a^{(n)} \sim p(a \mid b^{(m)})$ for $n \in 1{:}N$,

$$\mathbb{E}[f(A, b^{(m)})] \approx \frac{1}{N} \sum_{n=1}^N f(a^{(n)}, b^{(m)}).$$
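The variance reduction can be checked empirically. This sketch (same made-up toy model as above, not from the post) computes the per-draw terms of both estimators from a common set of draws; the law of total variance guarantees the Rao-Blackwellized terms are less variable:

```python
import numpy as np

rng = np.random.default_rng(42)

def inv_logit(x):
    return 1.0 / (1.0 + np.exp(-x))

def f(a, b):
    return a * b

M = 100_000
b = rng.normal(size=M)              # b^(m) ~ p(b)
a = rng.random(M) < inv_logit(b)    # a^(m) ~ p(a | b^(m))

# per-draw terms of the direct (full sampling) estimator
full_terms = f(a, b)
# per-draw terms of the Rao-Blackwellized estimator: E[f(A, b)] in closed form
rb_terms = inv_logit(b) * f(1, b) + (1.0 - inv_logit(b)) * f(0, b)

# Both estimators are unbiased, so the means agree up to Monte Carlo error,
# but the Rao-Blackwellized terms have strictly lower variance here.
var_full = np.var(full_terms)
var_rb = np.var(rb_terms)
```

The ratio `var_full / var_rb` is the factor by which marginalization shrinks the per-draw variance in this toy example.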
Discrete variables in Stan
Stan does not have a sampler for discrete parameters. Instead, Rao-Blackwellized estimators must be used, which essentially means marginalizing out the discrete parameters. Thus if $A$ is the vector of discrete parameters in a model, $B$ the vector of continuous parameters, and $y$ the vector of observed data, then the model posterior is

$$p(a, b \mid y).$$
With a sampler that can efficiently make Gibbs draws (e.g., BUGS or PyMC3), it is tempting to try to compute posterior expectations by sampling,

$$\mathbb{E}[f(A, B) \mid Y = y] \approx \frac{1}{M} \sum_{m=1}^M f(a^{(m)}, b^{(m)}),$$

where

$$(a^{(m)}, b^{(m)}) \sim p(a, b \mid y).$$
This is almost always a bad idea if it is possible to efficiently calculate the inner Rao-Blackwellization expectation, $\mathbb{E}[f(A, b^{(m)})]$. With discrete variables, the formula is just

$$\mathbb{E}[f(A, b^{(m)})] = \sum_{a \in A} p(a \mid b^{(m)}) \cdot f(a, b^{(m)}).$$
Usually the summation can be done efficiently, as in mixture models, where the discrete variables are tied to individual data points, or in state-space models like HMMs, where the discrete parameters can be marginalized using the forward algorithm. Where this is not so easy is with missing count data or variable selection problems, where the posterior combinatorics are intractable.
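As a sketch of how per-observation marginalization works in a mixture model (toy numbers of my choosing, not from the post), each observation's discrete component indicator is summed out on the log scale with a log-sum-exp, exactly as a Stan program would do it:

```python
import numpy as np

def normal_lpdf(y, mu, sigma):
    # log density of Normal(y | mu, sigma)
    return -0.5 * np.log(2.0 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2

def log_sum_exp(x, axis=0):
    # numerically stable log(sum(exp(x))) along an axis
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

# toy two-component normal mixture (weights, locations, scales are made up)
lam = 0.3
mu = np.array([-2.0, 2.0])
sigma = np.array([1.0, 1.5])
y = np.array([-1.8, 0.2, 2.5, 3.1])   # toy observations

# log p(y_n) = log_sum_exp_z [ log p(z) + log p(y_n | z) ]:
# the per-observation indicator z_n is marginalized out, never sampled
lp = np.stack([
    np.log(lam) + normal_lpdf(y, mu[0], sigma[0]),
    np.log1p(-lam) + normal_lpdf(y, mu[1], sigma[1]),
])
log_lik = np.sum(log_sum_exp(lp, axis=0))
```

Because each indicator is tied to a single data point, the sum over its values costs only the number of components per observation.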
Gains from marginalizing discrete parameters
The gains to be had from marginalizing discrete parameters are enormous. This is true even of models coded in BUGS or PyMC3. Cole Monnahan, James Thorson, and Trevor Branch wrote a nice survey of the advantages of marginalization for some ecology models, comparing marginalized HMC in Stan to JAGS with discrete sampling and JAGS with marginalization. The takeaway here isn't that HMC is faster than JAGS, but that JAGS with marginalization is a lot faster than JAGS without.
The other place to see the effects of marginalization is in the Stan User's Guide chapter on latent discrete parameters. The first change-point example shows how much more efficient the marginalization is by comparing it directly with estimates generated from exact sampling of the discrete parameters conditioned on the continuous ones. This is particularly true of the tail statistics, which can't be estimated at all with MCMC because too many draws would be required. I had the same experience coding the Dawid-Skene model of noisy data coding, which was my gateway to Bayesian inference: I had coded it with discrete sampling in BUGS, but BUGS took forever (24 hours compared to 20 minutes for Stan on my real data) and kept crashing on trivial examples during my tutorials.
Marginalization calculations can be found in the MLE literature
The other place marginalization of discrete parameters comes up is in maximum likelihood estimation. For example, Dawid and Skene's original approach to their coding model used the expectation maximization (EM) algorithm for maximum marginal likelihood estimation. The E-step does the marginalization, and it's exactly the same marginalization as required in Stan for discrete parameters. You can find the marginalization for HMMs in the literature on calculating maximum likelihood estimates of HMMs (in computer science, electrical engineering, etc.) and in the ecology literature for the Cormack-Jolly-Seber model. And they're in the Stan User's Guide.
Nothing’s lost, really
[edit: added last section explaining how to deal with posterior inference for the discrete parameters]
It’s convenient to do posterior inference with samples. Even with a Rao-Blackwellized estimator, it’s possible to sample

$$a^{(m)} \sim p(a \mid b^{(m)}, y)$$

in the generated quantities block of a Stan program and then proceed from there with full posterior draws $(a^{(m)}, b^{(m)})$ of both the discrete and continuous parameters.
As tempting as that is because of its simplicity, the marginalization is worth the coding effort: the gain in efficiency from working in expectation with the Rao-Blackwellized estimator is enormous for discrete parameters. It can often take problems from computationally infeasible to straightforward.
For example, to estimate the posterior distribution of a discrete parameter, we need the expectation

$$\Pr[A = a \mid Y = y] = \mathbb{E}[\textrm{I}[A = a] \mid Y = y]$$

for all values $a$ that $A$ might take. This is a trivial computation with MCMC (assuming the number of values is not too large), and it is carried out in Stan by defining an indicator variable and setting it. In contrast, estimating such a quantity by sampling

$$a^{(m)} \sim p(a \mid b^{(m)}, y)$$

is very inefficient, and increasingly so as the probability $\Pr[A = a \mid Y = y]$ being estimated gets smaller.
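Here's a sketch of that contrast for a rare event, with made-up numbers (suppose we have posterior draws $b^{(m)}$, and $\Pr[A = 1 \mid b] = \textrm{logit}^{-1}(b)$; none of this is from the post). The Rao-Blackwellized estimate averages the conditional probabilities, while the sampling estimate averages noisy 0/1 indicators:

```python
import numpy as np

rng = np.random.default_rng(7)

def inv_logit(x):
    return 1.0 / (1.0 + np.exp(-x))

M = 10_000
b = rng.normal(-4.0, 1.0, size=M)   # stand-in posterior draws b^(m); A = 1 is rare

p1 = inv_logit(b)                   # Pr[A = 1 | b^(m)]: Rao-Blackwellized per-draw terms
a = rng.random(M) < p1              # sampled a^(m), reduced to indicators I[a^(m) = 1]

rb_estimate = np.mean(p1)           # averages conditional probabilities
mc_estimate = np.mean(a)            # same mean in expectation, but far noisier terms

var_rb = np.var(p1)
var_mc = np.var(a.astype(float))    # indicator variance is roughly p * (1 - p)
```

For a probability near zero, the indicator terms are almost all zeros with a few ones, so the relative error of `mc_estimate` blows up; the conditional probabilities in `p1` carry far more information per draw.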
Examples of both forms of inference are shown in the user’s guide chapter on latent discrete parameters.