Log Sum of Exponentials for Robust Sums on the Log Scale

This is a public service announcement in the interest of more robust numerical calculations.

Like matrix inversion, exponentiation is bad news. It’s prone to overflow or underflow. Just try this in R:

> exp(-800)
> exp(800)

That’s not rounding error you see. The first one evaluates to zero (underflows) and the second to infinity (overflows).

A log density of -800 is not unusual with the log likelihood of even a modestly sized data set. So what do we do? Work on the log scale, of course. It turns products into sums, and sums are much less prone to overflow or underflow.

log(a * b) = log(a) + log(b)
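
As a quick illustration in R (with hypothetical simulated data), the product of a thousand normal densities underflows to zero, but the sum of their logs is perfectly well behaved:

y <- rnorm(1000)                # simulated standard-normal data
prod(dnorm(y))                  # underflows to 0
sum(dnorm(y, log = TRUE))       # finite log likelihood, around -1400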

But what do we do if we need the log of a sum, not a product? We turn to a mainstay of statistical computing, the log sum of exponentials function.

log(a + b) = log(exp(log(a)) + exp(log(b)))

           = log_sum_exp(log(a), log(b))

We use a little algebraic trick to prevent overflow and underflow while preserving as many accurate leading digits in the result as possible:

log_sum_exp(u, v) = max(u, v) + log(exp(u - max(u, v)) + exp(v - max(u, v)))

The leading digits are preserved by pulling the maximum outside. The arithmetic is robust because subtracting the maximum on the inside ensures that only zero or negative numbers are ever exponentiated, so those calculations cannot overflow. If one of them underflows, the leading digits have already been returned as part of the max term on the outside.
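
Here’s a minimal two-argument sketch in R (log_sum_exp isn’t part of base R; this just transcribes the formula above):

log_sum_exp <- function(u, v) {
  m <- max(u, v)                       # pull the leading digits out front
  m + log(exp(u - m) + exp(v - m))     # exponents are <= 0, so no overflow
}

log_sum_exp(-800, -801)       # about -799.69
log(exp(-800) + exp(-801))    # naive version: log(0) = -Inf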

Mixtures

We use log-sum-of-exponentials extensively in the internal C++ code for Stan, and it also pops up in user programs when there is a need to marginalize out discrete parameters (as in mixture models or state-space models). For instance, if we have a normal log density function, we can compute the mixture density with mixing proportion lambda as

log_sum_exp(log(lambda) + normal_log(y, mu[1], sigma[1]),
            log1m(lambda) + normal_log(y, mu[2], sigma[2]));
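
Here is the same marginalization sketched in R, reusing the log_sum_exp function above with dnorm on the log scale; lambda, mu, sigma, and y are made-up values for illustration, and log1p(-lambda) plays the role of log1m(lambda), as explained below:

lambda <- 0.3
mu <- c(0, 5)
sigma <- c(1, 2)
y <- 1.2
log_sum_exp(log(lambda) + dnorm(y, mu[1], sigma[1], log = TRUE),
            log1p(-lambda) + dnorm(y, mu[2], sigma[2], log = TRUE))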

The function log1m is used for robustness; its value is defined algebraically by

log1m(u) = log(1 - u)

But unlike the naive calculation, it won’t evaluate to 0 when u is close to 0 and 1 - u rounds to 1. Try this in R:

log(1 - 10e-20)
log1p(-10e-20)

log1m isn’t built into R, but log1p is, and negating the argument doesn’t lose us any digits. The subtraction in the first expression rounds to 1, so the log returns 0 and the answer is lost. But the second case returns the correct non-zero result.
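
In R, then, a robust log1m is a one-liner in terms of log1p (a sketch following the definition above; log1m itself is a Stan function, not part of base R):

log1m <- function(u) log1p(-u)   # log(1 - u) without first rounding 1 - u

log1m(1e-20)     # -1e-20, as it should be
log(1 - 1e-20)   # 0, because 1 - 1e-20 rounds to 1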

What goes on under the hood is that different approximations to the log function are used depending on the value of the argument, typically low-order series expansions in the regions where the standard algorithm loses precision.
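
For instance, for small arguments the implementation can fall back on the leading terms of the series log(1 + x) = x - x^2/2 + x^3/3 - ..., which is where the extra accuracy comes from. A rough illustration in R (not the exact algorithm any particular library uses):

x <- 1e-10
log1p(x)      # essentially exact, about 1e-10
log(1 + x)    # 1 + x rounds before the log, costing several significant digits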