I’m not a great programmer. The story that follows is not intended to represent best programming practice or even good programming practice. It’s just something that happened to me, and it’s the kind of thing that’s happened to me many times before, so I’m sharing it with you here.
The problem
For a research project I needed to fit a regression model with an error term that is a mixture of three normals. I googled *mixture model Stan* and came to the mixture-modeling page of the Stan User’s Guide, which has this code:
data {
  int<lower=1> K;            // number of mixture components
  int<lower=1> N;            // number of data points
  array[N] real y;           // observations
}
parameters {
  simplex[K] theta;          // mixing proportions
  ordered[K] mu;             // locations of mixture components
  vector<lower=0>[K] sigma;  // scales of mixture components
}
model {
  vector[K] log_theta = log(theta);  // cache log calculation
  sigma ~ lognormal(0, 2);
  mu ~ normal(0, 10);
  for (n in 1:N) {
    vector[K] lps = log_theta;
    for (k in 1:K) {
      lps[k] += normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(lps);
  }
}
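To make the model block concrete, here is a minimal Python sketch (my own translation, not code from the User’s Guide) of the log likelihood it adds to target: for each observation, the log of a theta-weighted sum of normal densities, computed with log-sum-exp so that small densities don’t underflow. The helper names are hypothetical, and the priors on sigma and mu are left out.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def normal_lpdf(y, loc, scale):
    """Log density of normal(loc, scale) evaluated at y."""
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((y - loc) / scale) ** 2

def mixture_log_lik(y, theta, mu, sigma):
    """Sum over data points of log(sum_k theta[k] * normal(y[n] | mu[k], sigma[k]))."""
    log_theta = [math.log(t) for t in theta]  # computed once, like the cached log in Stan
    total = 0.0
    for y_n in y:
        lps = [log_theta[k] + normal_lpdf(y_n, mu[k], sigma[k]) for k in range(len(theta))]
        total += log_sum_exp(lps)
    return total

print(mixture_log_lik([-2.1, 0.3, 1.9], [0.5, 0.3, 0.2], [-2.0, 0.0, 2.0], [1.0, 1.0, 1.0]))
```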
I should’ve just started by using this code as is, but instead I altered it in a couple of ways:
data {
  int M;
  int N;
  int K;
  vector[N] v;
  matrix[N,K] X;
}
parameters {
  vector[K] beta;
  simplex[M] lambda;
  ordered[M] mu;
  vector<lower=0>[M] sigma;
}
model {
  lambda ~ lognormal(0, 2);
  mu ~ normal(0, 10);
  sigma ~ lognormal(0, 2);
  for (n in 1:N){
    vector[M] lps = log(lambda);
    for (m in 1:M){
      lps[m] += normal_lpdf(v[n] | X*beta + mu[m], sigma[m]);
    }
    target += log_sum_exp(lps);
  }
}
The main change was adding the linear predictor, X*beta. I also renamed a couple of variables to match the notation in the paper I was writing, added a prior on the mixture component sizes, and removed some of the code from the Stan User’s Guide that improved computational efficiency but seemed to me to make the code harder for newcomers to read.
I fit this model setting M=3, and . . . it was really slow! I mean, stupendously slow. My example had about 1000 data points and it was taking, oh, I dunno, close to an hour to run?
This just made no sense. It wasn’t just that it was slow; the slowness itself didn’t seem right, which made me concerned that something else was going wrong.
Also there was poor convergence. Some of this was understandable, as I was fitting the model to data that had been simulated from a linear regression with normal errors, so there weren’t actually three components. But, still, the model has priors, and I don’t think the no-U-turn sampler (NUTS) algorithm used by Stan should have so much trouble traversing this space.
Playing with priors
It was time to debug. My first thought was that the slowness was caused by poor geometry: if NUTS is moving poorly, it can take up to 1024 steps per iteration, and then each iteration will take a long time. Poor mixing and slow convergence go together.
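A rough way to see how poor geometry turns into wall-clock time, assuming a NUTS trajectory that starts with one leapfrog step and doubles at each tree depth: under that assumption a full tree of depth d costs 2^d - 1 leapfrog steps, each needing one gradient evaluation, so Stan’s default maximum tree depth of 10 gives roughly the 1024 steps mentioned above. This is a back-of-the-envelope sketch, not Stan’s implementation.

```python
def max_leapfrog_steps(max_treedepth: int) -> int:
    """Upper bound on leapfrog steps in one NUTS iteration.

    Assumes the trajectory starts with one step and doubles at each depth,
    so a fully built tree of depth d uses 1 + 2 + ... + 2**(d-1) = 2**d - 1 steps.
    """
    return 2 ** max_treedepth - 1

# Stan's default max_treedepth is 10; smaller depths shown for comparison.
for depth in (5, 8, 10):
    print(depth, max_leapfrog_steps(depth))
```

If most iterations hit the limit, each one costs on the order of a thousand gradient evaluations, which is why poor mixing and slow runs tend to show up together.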
A natural way to fix this problem is to make the priors stronger. With strong priors, the parameters are restricted to fall in a small, controlled zone, and the geometry should be better.
I tried a few things with priors, the most interesting of which was to set up a hierarchical model for the scales of the mixture components. I added this line to the parameters block:
real log_sigma_0;
And then, in the model block, I replaced “sigma ~ lognormal(0, 2);” with:
sigma ~ lognormal(log_sigma_0, 1);
One thing that made the priors relatively easy to set up here was that the data are on unit scale: they’re the logarithms of sampling weights, and the sampling weights were normalized to have a sample mean of 1. Also, we’re not gonna have weights of 1000, so they have a limited range on the log scale.
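Here’s a sketch of that preprocessing, using made-up weights (the real sampling weights aren’t shown in this post). Rescaling to a sample mean of 1 and taking logs puts the values near 0; by Jensen’s inequality their mean is at most 0, and even a weight 1000 times the average would only be about 6.9 on the log scale.

```python
import math

def log_normalized_weights(weights):
    """Rescale weights to have sample mean 1, then take logs."""
    mean_w = sum(weights) / len(weights)
    return [math.log(w / mean_w) for w in weights]

# Hypothetical sampling weights, for illustration only.
weights = [0.2, 0.5, 1.0, 2.5, 40.0]
log_w = log_normalized_weights(weights)
print([round(x, 2) for x in log_w])

# Mean of the logs is <= log(mean of the weights) = 0 (Jensen's inequality).
print(sum(log_w) / len(log_w) <= 0)

# Even a weight 1000 times the mean stays under 7 on the log scale.
print(round(math.log(1000), 2))  # 6.91
```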
In any case, these steps of tinkering with the prior weren’t helping. The model was still running ridiculously slowly.
Starting from scratch
OK, what’s going on? I decided to start from the other direction: Instead of starting with my desired model and trying to clean it up, I started with something simple and built up.
The starting point is the code in the Stan User’s Guide, given above. I simulated some fake data and fit the model to it:
library("cmdstanr")
set.seed(123)
# Simulate data
N <- 100
K <- 3
lambda <- c(0.5, 0.3, 0.2)
mu <- c(-2, 0, 2)
sigma <- c(1, 1, 1)
z <- sample(1:K, N, replace=TRUE, prob=lambda)  # mixture component for each data point
v <- rnorm(N, mu[z], sigma[z])
# Fit model
mixture <- cmdstan_model("mixture_2.stan")
mixture_data <- list(y=v, N=N, K=K)
v_mixture_fit <- mixture$sample(data=mixture_data, seed=123, chains=4, parallel_chains=4)
print(v_mixture_fit)
And it worked just fine, ran fast, no convergence problems. It also worked fine with N=1000; it just made sense to try N=100 first in case any issues came up.
Then I changed the Stan code to my notation, using M rather than K and a couple of other things, and still no problems.
Then I decided to make it a bit harder by setting the true means of the mixture components to be identical, changing "mu <- c(-2, 0, 2)" in the above code to
mu <- c(0, 0, 0)
Convergence was a bit worse, but it was still basically ok. So the problem didn't seem to be the geometry.
Next step was to add predictors to the model. I added them to the R code and the Stan code . . . and then the problem returned.
So I'd isolated the problem. It was with the regression predictors. But what was going on? One possibility was nonidentification of the constant term in the regression with the location parameters mu in the mixture model, but I'd been careful not to include a constant term in my regression, so it wasn't that.
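Here's a quick numerical check of that nonidentification, as a hypothetical Python sketch with one predictor, an explicit intercept alpha, and made-up data. Shifting alpha up by any constant c and every mu[m] down by c leaves the likelihood unchanged, so the data only inform alpha + mu[m], and any separation between them would have to come from the priors.

```python
import math

def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def normal_lpdf(y, loc, scale):
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((y - loc) / scale) ** 2

def log_lik(v, x, alpha, beta, lam, mu, sigma):
    """Mixture-of-normals regression log likelihood with an explicit intercept alpha."""
    total = 0.0
    for v_n, x_n in zip(v, x):
        eta = alpha + beta * x_n  # linear predictor for this one observation
        total += log_sum_exp([math.log(l) + normal_lpdf(v_n, eta + m, s)
                              for l, m, s in zip(lam, mu, sigma)])
    return total

# Hypothetical data and parameter values.
v = [0.4, -1.2, 2.0, 0.9]
x = [1.0, -0.5, 3.0, 2.2]
lam, mu, sigma = [0.5, 0.3, 0.2], [-1.0, 0.0, 1.5], [1.0, 0.8, 1.2]
c = 0.7  # any shift gives the same result
a = log_lik(v, x, 0.5, -0.16, lam, mu, sigma)
b = log_lik(v, x, 0.5 + c, -0.16, lam, [m - c for m in mu], sigma)
print(abs(a - b) < 1e-12)  # True
```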
Finding the bug
I stared at the code harder and found the problem! It was in this line of the Stan program:
lps[m] += normal_lpdf(v[n] | X*beta + mu[m], sigma[m]);
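The problem is that the code is doing one data point at a time, but "X*beta" has all the data together! Because normal_lpdf is vectorized, normal_lpdf(v[n] | X*beta + mu[m], sigma[m]) compares the single value v[n] with all N linear predictors and adds up the N log densities. That gave the wrong likelihood, and it made every log-density evaluation N times as expensive as it should have been. So I fixed it. I changed the above line to: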
lps[m] += normal_lpdf(v[n] | Xbeta[n] + mu[m], sigma[m]);
and added the following line to the beginning of the model block:
vector[N] Xbeta = X*beta;
Now it all works. I found the bug.
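To see why the original line was both wrong and slow, here's a Python sketch, with simulated stand-in data, that computes the mixture-regression log likelihood both ways and counts normal density evaluations. The buggy=True branch mimics what the vectorized normal_lpdf call did with the full X*beta vector; the function and variable names are mine.

```python
import math
import random

def normal_lpdf(y, loc, scale):
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((y - loc) / scale) ** 2

def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_lik(v, xbeta, lam, mu, sigma, buggy=False):
    """Mixture-regression log likelihood plus a count of normal densities evaluated.

    buggy=True mimics normal_lpdf(v[n] | X*beta + mu[m], sigma[m]): the scalar v[n]
    is compared against all N linear predictors and the N log densities are summed.
    """
    total, n_evals = 0.0, 0
    for n, v_n in enumerate(v):
        lps = []
        for l, m, s in zip(lam, mu, sigma):
            if buggy:
                lp = sum(normal_lpdf(v_n, xb + m, s) for xb in xbeta)
                n_evals += len(xbeta)
            else:
                lp = normal_lpdf(v_n, xbeta[n] + m, s)  # only this observation's predictor
                n_evals += 1
            lps.append(math.log(l) + lp)
        total += log_sum_exp(lps)
    return total, n_evals

random.seed(1)
N = 200
xbeta = [random.gauss(0, 1) for _ in range(N)]
v = [xb + random.gauss(0, 1) for xb in xbeta]
lam, mu, sigma = [0.4, 0.3, 0.3], [-0.5, 0.0, 0.5], [1.0, 1.0, 1.0]
good, good_evals = log_lik(v, xbeta, lam, mu, sigma)
bad, bad_evals = log_lik(v, xbeta, lam, mu, sigma, buggy=True)
print(good_evals, bad_evals)  # 600 120000
```

For N observations and M components, the fixed version evaluates N*M densities and the buggy one N*N*M, so with about 1000 data points each log-density evaluation was roughly 1000 times as expensive, on top of targeting the wrong likelihood. That seems consistent with both the slowness and the convergence problems.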
What next?
The code runs and does what it is supposed to do. Great. Now I have to go back to the larger analysis and see whether everything makes sense.
Here was the output of the simple normal linear regression fit to the data I'd simulated:
            Median MAD_SD
(Intercept)  0.57   0.10
x           -0.16   0.01
Auxiliary parameter(s):
      Median MAD_SD
sigma  1.02   0.02
And here was the result of fitting the regression model in Stan with error term being a mixture of 3 normals:
variable         mean   median   sd  mad       q5      q95 rhat ess_bulk ess_tail
lp__         -1340.21 -1339.91 2.85 2.77 -1345.21 -1335.96 1.00      681      458
beta[1]         -0.16    -0.16 0.01 0.01    -0.18    -0.14 1.00     2139     2357
lambda[1]        0.32     0.14 0.34 0.19     0.01     0.94 1.02      419      977
lambda[2]        0.39     0.27 0.35 0.36     0.01     0.95 1.01      506      880
lambda[3]        0.28     0.11 0.32 0.15     0.01     0.92 1.01      543      879
mu[1]           -0.14    -0.02 0.61 0.57    -1.43     0.57 1.02      385      281
mu[2]            0.58     0.56 0.51 0.33    -0.25     1.56 1.01      474      536
mu[3]            1.53     1.37 1.56 0.67     0.60     2.42 1.01      562     1025
sigma[1]         0.77     0.85 0.30 0.22     0.21     1.09 1.01      522      481
sigma[2]         0.85     0.93 0.32 0.16     0.27     1.16 1.00      644     1028
sigma[3]         0.77     0.79 0.50 0.29     0.25     1.11 1.00      910     1056
log_sigma_0     -0.35    -0.34 0.64 0.65    -1.42     0.68 1.00     1474     1605
The estimate for the slope parameter, beta, seems fine, but it's hard to judge mu and sigma. Ummm, we can take a weighted average for mu, 0.32*(-0.14) + 0.39*0.58 + 0.28*1.53 = 0.61, which seems kind of ok, although a bit off from the 0.57 we got from the linear regression. What about sigma? It's harder to tell. We can compute the weighted variance of the mu's plus the weighted average of the sigma^2's, but it's getting kinda messy. Also, we really want to account for the posterior uncertainty: as you can see from above, these uncertainty intervals are really wide, which makes sense given that we're fitting a mixture of 3 normals to data that were simulated from a single normal distribution . . . Ok, this is getting messy. Let's just do it right.
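Here's the quick-and-dirty version of that arithmetic in Python, plugging in the posterior means from the table. The rounded lambda means sum to 0.99, so this sketch renormalizes them, which nudges the mean from 0.61 to 0.62. It's only a plug-in estimate: a function of posterior means is not the posterior mean of the function, and it ignores the posterior uncertainty and the correlation among lambda, mu, and sigma.

```python
import math

# Posterior means copied from the table above (rounded, so they sum to 0.99, not 1).
lam = [0.32, 0.39, 0.28]
mu = [-0.14, 0.58, 1.53]
sigma = [0.77, 0.85, 0.77]

w = [l / sum(lam) for l in lam]  # renormalize the rounded weights
mean = sum(wi * mi for wi, mi in zip(w, mu))
# Total variance = weighted average of sigma^2 + weighted spread of the mu's around the mean.
var = sum(wi * (si ** 2 + (mi - mean) ** 2) for wi, mi, si in zip(w, mu, sigma))
print(round(mean, 2), round(math.sqrt(var), 2))  # 0.62 1.03
```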
I added a generated quantities block to the Stan program to compute the total mean and standard deviation of the error term:
generated quantities {
  real mu_total = sum(lambda.*mu)/sum(lambda);
  real sigma_total = sqrt(sum(lambda.*((mu - mu_total)^2 + sigma^2))/sum(lambda));
}
And here's what came out:
variable     mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
mu_total     0.56   0.56 0.11 0.11 0.39 0.74 1.00     2172     2399
sigma_total  1.02   1.02 0.03 0.02 0.98 1.06 1.00     3116     2052
Check. Very satisfying.
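One way to double-check the formula in the generated quantities block is to simulate from a mixture and compare. This Python sketch, using the same lambda, mu, and sigma as the fake data earlier, computes the mean as the lambda-weighted average of the mu's and the variance as the weighted average of sigma^2 plus the weighted spread of the mu's around that mean (the law of total variance), then compares both to Monte Carlo estimates from 200,000 draws.

```python
import math
import random

def mixture_mean_sd(lam, mu, sigma):
    """Mean and sd of a normal mixture, same formula as the generated quantities block."""
    total = sum(lam)
    mean = sum(l * m for l, m in zip(lam, mu)) / total
    var = sum(l * ((m - mean) ** 2 + s ** 2) for l, m, s in zip(lam, mu, sigma)) / total
    return mean, math.sqrt(var)

def simulate_mixture(lam, mu, sigma, n, rng):
    """Draw n values: pick a component with probability lam, then draw from that normal."""
    comps = rng.choices(range(len(lam)), weights=lam, k=n)
    return [rng.gauss(mu[k], sigma[k]) for k in comps]

rng = random.Random(123)
lam, mu, sigma = [0.5, 0.3, 0.2], [-2.0, 0.0, 2.0], [1.0, 1.0, 1.0]
mean, sd = mixture_mean_sd(lam, mu, sigma)
draws = simulate_mixture(lam, mu, sigma, 200_000, rng)
emp_mean = sum(draws) / len(draws)
emp_sd = math.sqrt(sum((d - emp_mean) ** 2 for d in draws) / (len(draws) - 1))
print(round(mean, 3), round(sd, 3), round(emp_mean, 3), round(emp_sd, 3))
```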
The next step is to continue with the research on the problem that motivated all this. Fitting a regression with a mixture model for the errors was just a little technical thing I needed to do. It's annoying that it took many hours (not even counting the hour it took to write this post!), and even more annoying that I can't do much with it: it's just a stupid little bug, nothing that we can even put in our workflow book, I'm sorry to say. But now it's time to move on.
Sometimes, cleaning the code, or getting the model to converge, or finding a statistically-significant result, or whatever, takes so much effort that when we reach that ledge, we just want to stop and declare victory right there. But we can't! Or, at least, we shouldn't! Getting the code to do what we want is just a means to an end, not an end in itself.
P.S. I cleaned the code some more, soft-constraining mu_total to 0 so I could put the intercept back into the model. Here's the updated Stan program:
data {
  int M;
  int N;
  int K;
  vector[N] v;
  matrix[N,K] X;
}
parameters {
  vector[K] beta;
  simplex[M] lambda;
  ordered[M] mu;
  vector<lower=0>[M] sigma;
  real log_sigma_0;
}
model {
  vector[N] Xbeta = X*beta;  // linear predictor for all data points, computed once
  lambda ~ lognormal(log(1./M), 1);
  mu ~ normal(0, 10);
  sigma ~ lognormal(log_sigma_0, 1);
  sum(lambda.*mu) ~ normal(0, 0.1);  // soft constraint: mixture mean of the errors near 0
  for (n in 1:N){
    vector[M] lps = log(lambda);
    for (m in 1:M){
      lps[m] += normal_lpdf(v[n] | Xbeta[n] + mu[m], sigma[m]);
    }
    target += log_sum_exp(lps);
  }
}
generated quantities {
  real mu_total = sum(lambda.*mu);
  real sigma_total = sqrt(sum(lambda.*((mu - mu_total)^2 + sigma^2)));
}
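A small sketch of why the soft constraint lets the intercept back in, assuming (hypothetically, since the post doesn't show it) that X now includes a column of ones. Moving the intercept up by c and every mu[m] down by c leaves the likelihood unchanged, but because lambda sums to 1, sum(lambda.*mu) moves by exactly -c, and the normal(0, 0.1) term penalizes that by 0.5*(c/0.1)^2 on the log scale, which is 12.5 for c = 0.5.

```python
import math

def normal_lpdf(y, loc, scale):
    """Log density of normal(loc, scale) at y."""
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((y - loc) / scale) ** 2

def mixture_mean(lam, mu):
    """sum(lambda .* mu), the quantity the updated model soft-constrains."""
    return sum(l * m for l, m in zip(lam, mu))

def soft_constraint_lp(lam, mu, scale=0.1):
    """Log of the normal(0, scale) soft-constraint term on sum(lambda .* mu)."""
    return normal_lpdf(mixture_mean(lam, mu), 0.0, scale)

# Hypothetical parameter values; lam sums to 1, like a simplex.
lam = [0.5, 0.3, 0.2]
mu = [-0.4, 0.1, 0.85]  # chosen so that sum(lam .* mu) is 0
c = 0.5                 # move the intercept up by c and every mu down by c
mu_shifted = [m - c for m in mu]

shift_in_mean = mixture_mean(lam, mu_shifted) - mixture_mean(lam, mu)
penalty = soft_constraint_lp(lam, mu) - soft_constraint_lp(lam, mu_shifted)
print(round(shift_in_mean, 6), round(penalty, 6))  # -0.5 12.5
```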