Davendra Seunarine Maharaj writes:
We are seeking to characterize the performance and potential bottlenecks of the latest fast MCMC samplers. I see that Stan is currently using Intel TBB to parallelize the no-U-turn sampler (NUTS) across multiple chains. Do you know of any research that has attempted to parallelize the sampler itself within a single chain?
Matt Hoffman (the inventor of the NUTS algorithm) responded:
Our group at Google has been very interested in using parallel compute in HMC variants (including NUTS), particularly on accelerators (e.g., GPUs). We’ve been working in the deep-learning-oriented autodiff+accelerator software frameworks TensorFlow and JAX, both of which are supported by our TensorFlow Probability library. It turns out that building on top of these frameworks gives you parallelism both within-chain (mostly from data parallelism) and across chains (because multiple chains can be run in parallel) “for free” without thinking too much about it—we just let the substrate (i.e., TF or JAX) decide how best to use the available resources.
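To make the across-chain point concrete, here is a minimal sketch of plain HMC in Python with NumPy (our choice for illustration, not TFP's API) in which every array carries a leading chain dimension, so a single leapfrog update advances all chains at once. The standard-normal target, step size, number of leapfrog steps, and function names are illustrative assumptions. Frameworks such as TF and JAX can run this kind of batched array code on accelerators, which is roughly the sense in which across-chain parallelism comes "for free."

```python
import numpy as np

def log_prob(x):
    # Hypothetical target: standard normal in d dimensions.
    # x has shape [n_chains, d]; the result has shape [n_chains].
    return -0.5 * np.sum(x**2, axis=-1)

def grad_log_prob(x):
    # Gradient of the standard-normal log density, same shape as x.
    return -x

def hmc_step(x, rng, step_size=0.2, n_leapfrog=10):
    """One HMC transition applied to every chain at once.

    x has shape [n_chains, d]. Every operation acts on the whole batch,
    so adding chains widens the arrays instead of adding Python loops.
    """
    p0 = rng.standard_normal(x.shape)
    q, p = x.copy(), p0.copy()
    # Leapfrog integration, vectorized over chains.
    p = p + 0.5 * step_size * grad_log_prob(q)
    for i in range(n_leapfrog):
        q = q + step_size * p
        if i < n_leapfrog - 1:
            p = p + step_size * grad_log_prob(q)
    p = p + 0.5 * step_size * grad_log_prob(q)
    # Hamiltonian = potential (-log_prob) + kinetic energy, one per chain.
    h0 = -log_prob(x) + 0.5 * np.sum(p0**2, axis=-1)
    h1 = -log_prob(q) + 0.5 * np.sum(p**2, axis=-1)
    # Per-chain Metropolis accept/reject, still vectorized.
    accept = np.log(rng.uniform(size=x.shape[0])) < h0 - h1
    return np.where(accept[:, None], q, x), accept

def run_chains(n_chains=512, d=5, n_steps=200, seed=0):
    rng = np.random.default_rng(seed)
    x = 3.0 * rng.standard_normal((n_chains, d))  # overdispersed starts
    draws = []
    for _ in range(n_steps):
        x, _ = hmc_step(x, rng)
        draws.append(x)
    return np.stack(draws, axis=1)  # shape [n_chains, n_steps, d]
```

Calling `run_chains()` returns draws of shape `[n_chains, n_steps, d]`; going from 512 to 4096 chains changes only the array sizes, not the control flow.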
Some of our systems papers you may (or may not) find interesting:
https://arxiv.org/abs/2002.01184
https://arxiv.org/abs/2001.05035
Some papers arguing theoretically and empirically that running many chains in parallel is a good thing to do if you have the resources:
https://proceedings.mlr.press/v130/hoffman21a.html
https://proceedings.mlr.press/v119/hoffman20a.html
And one about convergence diagnostics if you’re going to run many chains in parallel and don’t want to wait for them all to generate large effective sample sizes (with Andrew, among others):
https://arxiv.org/abs/2110.13017
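As a reference point for diagnostics across many chains, the sketch below computes the classic (non-split) Gelman-Rubin potential scale reduction factor, R-hat, for one scalar quantity. It is not the diagnostic developed in the linked paper, which is aimed more directly at the many-short-chains regime; it only shows the between-chain versus within-chain variance comparison that R-hat-style diagnostics build on. The function name `rhat` is ours.

```python
import numpy as np

def rhat(draws):
    """Classic potential scale reduction factor for one scalar quantity.

    draws has shape [n_chains, n_draws]; many chains with few draws each
    is allowed, which is the setting the paragraph above describes.
    """
    m, n = draws.shape
    chain_means = draws.mean(axis=1)
    chain_vars = draws.var(axis=1, ddof=1)
    b = n * chain_means.var(ddof=1)  # between-chain variance
    w = chain_vars.mean()            # within-chain variance
    var_plus = (n - 1) / n * w + b / n
    return np.sqrt(var_plus / w)
```

For 1000 independent chains of 10 draws each from the same distribution the value is close to 1; shifting half of the chains by a constant pushes it well above 1.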
Mark Jeffrey followed up:
Our research group comprises computer architecture and PL folks, and we are always keen to understand a diverse range of application domains so that we can improve computer-system support for them. It is encouraging to hear that Google is interested enough in this area that TensorFlow has a team developing it.
Two more questions:
1. Are you aware of any HMC/NUTS/MCMC competitions akin to those organized at DIMACS (e.g., does StanCon run a competition)? This would give us pointers to potentially useful input models to work with.
2. In your opinion, does TensorFlow’s support for HMC methods supersede Stan, or will the two continue to coexist with different strengths? I expect the students in my group would be more productive hacking on the internals of Stan than of TensorFlow, but I am open to suggestion.
I passed those questions over to Bob Carpenter, who replied:
Stan can use multiple threads to evaluate the log density and gradients within a single chain and it can use multiple threads to run multiple chains in parallel. We can also send some computations to GPU to evaluate the log density.
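Within-chain parallelism of this kind usually means splitting the log density into a sum of independent pieces (for example, one per slice of the data) and evaluating the pieces concurrently; in Stan this is what the `reduce_sum` function is for. The Python below only imitates that structure and is not Stan's interface: the model (i.i.d. normal data), the shard count, and the function names are hypothetical, and in CPython the global interpreter lock means pure-Python threads illustrate the pattern rather than deliver a real speedup.

```python
import math
from concurrent.futures import ThreadPoolExecutor

def normal_loglik(ys, mu, sigma):
    # Sum of Normal(mu, sigma) log densities over one slice of the data.
    c = -0.5 * math.log(2 * math.pi) - math.log(sigma)
    return sum(c - 0.5 * ((y - mu) / sigma) ** 2 for y in ys)

def parallel_loglik(ys, mu, sigma, n_shards=4):
    # Split the data into contiguous shards, evaluate each shard's partial
    # sum on a worker thread, then add the partial sums together.
    size = max(1, math.ceil(len(ys) / n_shards))
    shards = [ys[i:i + size] for i in range(0, len(ys), size)]
    with ThreadPoolExecutor(max_workers=n_shards) as pool:
        partials = pool.map(lambda s: normal_loglik(s, mu, sigma), shards)
        return sum(partials)
```

Because the log density is a sum, the sharded result matches the serial one up to floating-point reordering; the same decomposition applies to the gradient.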
Google did some comparisons of Stan and TensorFlow and the answer is not surprising: Stan was faster on CPU and TensorFlow faster on GPU. Most of the models used in Stan aren’t easy to push onto GPU because they don’t involve SIMD calculations within a single log density eval. This is in contrast to deep neural nets, which involve a lot of dense matrix operations, which are the bread and butter of GPUs.
TensorFlow usually runs in 32-bit floating point and Stan always runs in 64-bit.
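To see why 32-bit versus 64-bit arithmetic can matter for MCMC, note that acceptance decisions depend on differences of log densities, which can be tiny relative to the log densities themselves. The sketch below uses NumPy scalars with hypothetical magnitudes: a log density around -1e6 (as one might get from summing over many data points) and a change of 0.01. In float32 the spacing between representable numbers near 1e6 is 0.0625, so the change is rounded away entirely.

```python
import numpy as np

def log_density_change(dtype, base=-1.0e6, delta=0.01):
    """Return (base + delta) - base computed entirely in `dtype`.

    `base` stands in for a large-magnitude log density and `delta` for the
    small change a proposal makes; both values are illustrative.
    """
    proposed = dtype(base) + dtype(delta)
    return float(proposed - dtype(base))

for dtype in (np.float32, np.float64):
    # Machine epsilon and the recovered change for each precision.
    print(dtype.__name__, np.finfo(dtype).eps, log_density_change(dtype))
```

With these values the float32 line reports a change of exactly 0.0, while the float64 line recovers approximately 0.01.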
HMC and NUTS are both instances of MCMC, so I don’t see the contrast.
No, we’re not running any bakeoffs, nor do I know of any plans to do such. They are notoriously difficult to organize.
Finally, I’d suggest looking at some of Matt Hoffman and Pavel Sountsov’s new algorithms, which are designed to take advantage of the SIMD capabilities of GPUs and TPUs. I’m personally very excited about their recent MEADS algorithm.
Matt also sent along his responses:
1. In HMC variants, the overhead (at least in FLOPs) of the model-independent code is generally quite small compared to the cost of computing gradients. So as long as you’re FLOPS-bound, there’s not much point in aggressively optimizing the implementation. (But in https://proceedings.mlr.press/v130/hoffman21a.html we find that, when running on a GPU, NUTS in particular has a lot of control-flow overhead when run on top of TF or JAX; cf. also https://github.com/tensorflow/probability/blob/main/discussion/technical_note_on_unrolled_nuts.md)
So I don’t think there’s been a lot of work put into DIMACS-style implementation challenges. (Or I might just not know about it.)
But there are some “model zoos” out there that are more oriented towards comparing algorithmic improvements. Stan’s example models repo has lots of great stuff, there’s posteriordb, and our own inference gym.
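One way to see where the control-flow overhead of NUTS on accelerators comes from: when a batch of chains runs in lockstep on SIMD hardware, each chain's trajectory can stop at a different tree depth, but the batch keeps going until the longest trajectory finishes. The toy calculation below is our simplification, not the analysis in the linked paper or technical note; it assumes a depth-k tree costs 2**k gradient evaluations and reports the fraction of lockstep work that is useful.

```python
def lockstep_utilization(tree_depths):
    """Fraction of useful gradient evaluations when chains run in lockstep.

    tree_depths holds each chain's NUTS tree depth for one iteration; a
    depth-k tree is assumed to use 2**k leapfrog steps (a simplification
    that ignores early termination within a doubling).
    """
    steps = [2 ** k for k in tree_depths]
    # Under lockstep execution every chain occupies a lane until the
    # longest trajectory in the batch finishes.
    lockstep_cost = len(steps) * max(steps)
    return sum(steps) / lockstep_cost
```

For per-chain depths `[3, 3, 4, 6]` the chains need 8 + 8 + 16 + 64 = 96 gradient evaluations, but lockstep execution pays for 4 * 64 = 256, a utilization of 0.375.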
2. I certainly don’t think that TFP is going to make Stan obsolete anytime soon. For one thing, I’d say that Stan’s documentation, community support, and general ease of use are in most respects well ahead of TFP. Also, I think it’s safe to say that Stan has better support for some modeling features—ODEs are a big example that springs to mind. Also, if you need double precision then it can be kind of a pain to get in TF/JAX.
Some of TFP’s big advantages over Stan to my mind are:
A. We get accelerator (i.e., GPU and TPU) support “for free” from building on top of TF and JAX. That includes support for sharding computation across multiple accelerators, which lets us do useful things like run 8x as many chains relatively easily, as well as letting us do kind-of-absurd things like apply HMC to a large Bayesian neural network using 512 TPU cores.
B. It’s easier to use TFP’s HMC as part of a larger system. If I want to, say, use HMC as an inner loop in training a deep generative model, or use an invertible neural network to precondition my sampler, I think it’d be harder to do that using Stan than using TFP.
C. I strongly prefer doing algorithm development in Python/JAX (using, e.g., FunMC) to C++. Having a model zoo that’s defined in the same autodiff framework as the one I’m writing my algorithm in is very convenient.
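As a small illustration of preconditioning a sampler with an invertible map (point B), the sketch below samples in a transformed space z with x = L z, where L is the Cholesky factor of a hypothetical target covariance. A linear map stands in for the invertible neural network; with a learned network, the log-determinant term would come from that network's Jacobian instead. The change-of-variables formula gives log p_z(z) = log p_x(L z) + log|det L|, and for this Gaussian target the transformed density is an isotropic standard normal up to a constant, which is much easier geometry for HMC.

```python
import numpy as np

# Hypothetical strongly correlated Gaussian target with covariance Sigma.
Sigma = np.array([[4.0, 1.9], [1.9, 1.0]])
L = np.linalg.cholesky(Sigma)  # lower triangular, Sigma = L @ L.T
Sigma_inv = np.linalg.inv(Sigma)

def target_log_prob(x):
    # Unnormalized log density of N(0, Sigma).
    return -0.5 * x @ Sigma_inv @ x

def preconditioned_log_prob(z):
    # Log density of z when x = L z:
    # log p_z(z) = log p_x(L z) + log|det L|.
    log_det = np.sum(np.log(np.abs(np.diag(L))))
    return target_log_prob(L @ z) + log_det
```

Running HMC on `preconditioned_log_prob` and mapping each draw back through `x = L @ z` gives draws from the original target.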
So compared to Stan, I think TFP is more geared towards problems where either you want the increased flexibility (and lower user-friendliness) of its lower-level API, or you’ve really hit a computational bottleneck that can be solved with access to more FLOPs.
Whereas for a lot of data-analysis problems, I think TFP’s advantages over Stan aren’t that relevant—if you can run 6 chains for 2000 iterations on a laptop in under five minutes, using a well-supported, well-documented, easy-to-use system, and everything converges with no tuning, why fix what’s not broken? Even in situations where things don’t converge properly, often the solution is to fix the model specification (e.g., by noncentering) rather than to throw more compute or customized algorithms at it.
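Noncentering changes which variables the sampler sees. In a hierarchical normal model with group effects theta_j ~ Normal(mu, tau), the centered form samples theta directly, so the width of its conditional density shrinks with tau (the well-known funnel); the noncentered form samples eta_j ~ Normal(0, 1) and sets theta_j = mu + tau * eta_j. The sketch below writes both prior terms in plain Python (function names are ours, and the likelihood is omitted) and checks that they differ only by the Jacobian term J * log(tau) for J groups.

```python
import math

def normal_logpdf(x, loc, scale):
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((x - loc) / scale) ** 2

def centered_prior(theta, mu, tau):
    # theta_j ~ Normal(mu, tau): the scale of theta depends on tau,
    # which creates the funnel when tau is small.
    return sum(normal_logpdf(t, mu, tau) for t in theta)

def noncentered_prior(eta):
    # eta_j ~ Normal(0, 1): the sampler's variables no longer depend on tau.
    return sum(normal_logpdf(e, 0.0, 1.0) for e in eta)

def theta_from_eta(eta, mu, tau):
    # Deterministic transform back to the group effects.
    return [mu + tau * e for e in eta]
```

The two forms define the same model; only the coordinates HMC moves in differ, and in the noncentered coordinates the prior geometry is the same whatever the value of tau.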
Interesting discussion, and it’s good to see all this work being done on practical algorithms. No method will solve all problems, so it makes sense that multiple systems are being developed.