This post is by Bob.
The title is based on the similarly named classic film.
“Big” models moving from Stan to JAX
Ever since the big ML frameworks PyTorch and TensorFlow were released, the Stan developers have been worried they’re going to put Stan out of business (we built Stan’s autodiff before those packages existed, but after Theano). While that hasn’t quite happened yet, I now believe our days are numbered. For high-end applications, Stan is slowly, but surely, being replaced by JAX. Many places I go (I don’t want Andrew to jump on a hyperbolic use of “everywhere”), I hear about people switching from Stan to JAX.
Here are four examples:
1. At StanCon in Oxford in 2024, Elizaveta Semenova started her talk by saying something to the effect of, “I’m sorry to say this here, but I don’t use Stan any more—I switched to JAX through NumPyro for scalability.”
2. Mitzi Morris just started working as a contractor for the U.S. Centers for Disease Control and Prevention (CDC) (!? as they say in chess). Their public GitHub repositories contain older Stan code that has since been replaced by JAX, and they are building up a library of reusable JAX code. It’s very hard to build reusable code in Stan given its blocked structure and the limited form of includes; Sean Pinkney has gone further than I thought possible with his helpful Stan functions project. The CDC models are for wastewater-informed forecasting; here’s the project overview.
3. Andrew posted a job announcement from the L.A. Dodgers baseball team a week ago that said, “We have a soft spot for jax and numpyro but Stan and PyMC folks are obviously always of interest.” Like Andrew, they apparently don’t like using their shift key.
4. Matt Hoffman’s been saying this for years and backing it up with adaptive ensemble samplers, convergence monitoring, etc. He, Pavel Sountsov, and Colin Carroll wrote a draft chapter for the second edition of the MCMC Handbook, Running Markov Chain Monte Carlo on Modern Hardware and Software. It contains complete instructions for massively parallelizing HMC on a GPU using JAX.
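The chapter covers HMC in detail; the core trick of running many chains at once on an accelerator can be sketched with `jax.vmap`. This is a minimal sketch, assuming a toy standard-normal target and using random-walk Metropolis as a stand-in for HMC to keep it short. The names `log_density`, `rwm_step`, and `run_chains` are hypothetical and not taken from the chapter.

```python
from functools import partial

import jax
import jax.numpy as jnp


def log_density(theta):
    # Toy target (an assumption): standard normal in len(theta) dimensions.
    return -0.5 * jnp.sum(theta ** 2)


def rwm_step(key, theta, scale=0.5):
    # One random-walk Metropolis step for a single chain (a stand-in for HMC).
    key_prop, key_acc = jax.random.split(key)
    proposal = theta + scale * jax.random.normal(key_prop, theta.shape)
    log_ratio = log_density(proposal) - log_density(theta)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, theta)


@partial(jax.jit, static_argnums=2)
def run_chains(key, init, num_steps):
    # init has shape (num_chains, dim); vmap advances every chain in lockstep.
    num_chains = init.shape[0]

    def one_iteration(thetas, step_key):
        chain_keys = jax.random.split(step_key, num_chains)
        thetas = jax.vmap(rwm_step)(chain_keys, thetas)
        return thetas, thetas

    step_keys = jax.random.split(key, num_steps)
    _, draws = jax.lax.scan(one_iteration, init, step_keys)
    return draws  # shape (num_steps, num_chains, dim)


# 1,000 chains in 2 dimensions, 500 steps each.
draws = run_chains(jax.random.PRNGKey(0), jnp.zeros((1000, 2)), 500)
```

Swapping `rwm_step` for an HMC step doesn’t change the vectorization: `vmap` still batches the per-chain log density (and, for HMC, gradient) evaluations into one call.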
But what about the hardware?
The biggest obstacle for people moving is finding the hardware on which to run JAX most efficiently. It’s really tailored to multiprocessing and GPU processing, and I don’t believe most Stan users have access to this kind of hardware to fit their models. But I believe this is going to change over the next ten years. That, and I believe we’re going to get better and better Macs: the ARM chips are way faster than the Intel chips for the kind of random memory access needed in Stan programs.
New samplers moving to JAX
New samplers like the micro-canonical HMC of Jakob Robnik and Uroš Seljak (and more recently Reuben Cohn-Gordon) are being coded only in JAX. Like many others, they added their package (see the previous link) to the Blackjax package. They even have a competitor for posteriordb in the form of Inference Gym.
A very nice feature of putting things up on Blackjax is that you can use them with any Python-defined log density function—it doesn’t even need to come from JAX. Brian Ward managed to plug Stan models into JAX (by which I mean having JAX call Stan’s C++, not generating JAX code from Stan).
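For jitted JAX code to call a log density that isn’t written in JAX, the function has to be wrapped. One way to do that (not necessarily how Brian Ward’s Stan integration works) is `jax.pure_callback`, sketched below with a plain NumPy function standing in, hypothetically, for compiled Stan code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_log_density(theta):
    # Plain NumPy code (hypothetical stand-in for external C++); JAX can't trace it.
    theta = np.asarray(theta)
    return np.asarray(-0.5 * np.sum(theta ** 2), dtype=np.float32)


def jax_log_density(theta):
    # Declare the result's shape and dtype so JAX can compile around the callback.
    result_shape = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(numpy_log_density, result_shape, theta)


value = jax.jit(jax_log_density)(jnp.array([1.0, 2.0]))
```

`pure_callback` has no derivative rule, so gradient-based samplers would additionally need a `jax.custom_vjp` whose backward pass calls an external gradient function the same way.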
Static vs. dynamic automatic differentiation
We built Stan with automatic differentiation before PyTorch, TensorFlow, or JAX existed. We went with the same dynamic design PyTorch eventually chose, despite Matt Hoffman and me knowing that the static TensorFlow/JAX approach could be more performant. The problem was that we didn’t have the people to implement enough derivatives to do it that way. Instead, we just started autodiffing through functions in the Eigen matrix library (like matrix multiplication and division) and in Boost (like the Runge-Kutta 4/5 ODE solver and many of the special functions). The static approach of XLA (the infrastructure under JAX and TensorFlow) does limit the expressiveness of loops and conditionals, which cannot condition on parameters, making it challenging, if not impossible, to write iterative algorithms in JAX.
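One concrete face of this restriction: under `jax.jit`, ordinary Python `if` statements can’t branch on a traced (parameter-dependent) value, so conditionals have to go through structured primitives such as `jnp.where` or `jax.lax.cond`, whose branches must both be traceable. A minimal sketch using absolute value as a toy example:

```python
import jax
import jax.numpy as jnp


def abs_python_if(x):
    # Python control flow needs a concrete bool, which a traced value can't supply.
    if x > 0:
        return x
    return -x


def abs_where(x):
    # Computes both branches and selects elementwise; traceable and differentiable.
    return jnp.where(x > 0, x, -x)


def abs_cond(x):
    # Structured conditional: both branches are traced, one is chosen at runtime.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


try:
    jax.jit(abs_python_if)(-3.0)
except jax.errors.ConcretizationTypeError:
    print("Python `if` on a traced value fails under jit")

print(jax.jit(abs_where)(-3.0), jax.grad(abs_cond)(-3.0))
```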
Graphical modeling
Tools like BUGS, PyMC, and NumPyro are all fundamentally based on the notion of a directed acyclic graphical model. That is, you have nodes representing random variables, with each variable conditionally independent of its non-descendants given its parents (the nodes that point to it). You specify the distribution of each node given the nodes on which it depends. Transforms are represented by deterministic nodes. The upside to constraining oneself to graphical models is that everything has to remain clearly generative (assuming you avoid improper flat priors, that is). This lets tools like PyMC automate a lot of workflow in the same way that we can with brms in Stan. When you go outside that paradigm, as you can in PyMC by adding density statements, the built-in automation of workflow breaks. So while it’s possible, they generally don’t recommend it. This came up in an earlier blog post I wrote, What’s a generative model? PyMC and Stan edition.
Differentiable programming
Stan does not work on a graphical modeling base. You can write graphical models in Stan, but we just treat them as defining a log density (that was the leap that led to Stan—I thought about how to code JAGS to generate log densities rather than conditional samplers as they do in BUGS/JAGS). In Stan, we just declare constrained parameters and define a log density over them. That’s it (the Jacobian adjustment for the change of variables is kept under the hood). There are generated quantities, but that’s conceptually after sampling.
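To make “declare constrained parameters, define a log density” concrete, here is a minimal sketch of what Stan does under the hood for a positive scale parameter, written out by hand in JAX. The model (normal data, a normal(0, 10) prior on mu, an exponential(1) prior on sigma) and the data are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import expon, norm


def log_density(unc, y):
    # unc = [mu, log_sigma]: every parameter lives on the unconstrained scale.
    mu, log_sigma = unc[0], unc[1]
    sigma = jnp.exp(log_sigma)                 # constraining transform, sigma > 0
    log_jacobian = log_sigma                   # log |d sigma / d log_sigma|
    lp = norm.logpdf(mu, 0.0, 10.0)            # prior: mu ~ normal(0, 10)
    lp += expon.logpdf(sigma)                  # prior: sigma ~ exponential(1)
    lp += jnp.sum(norm.logpdf(y, mu, sigma))   # likelihood
    return lp + log_jacobian


y = jnp.array([1.2, 0.4, -0.3, 2.1])  # made-up data
grad_fn = jax.jit(jax.grad(log_density))
print(grad_fn(jnp.array([0.5, 0.0]), y))
```

In Stan, declaring `real<lower=0> sigma;` generates this same exp transform and Jacobian term automatically; here both are written by hand.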
Like Stan, JAX is also a differentiable programming language. Unlike Stan, it’s wonderfully compositional and general.
Writing JAX models like Stan models
As much as people like to use NumPyro and sometimes even PyMC to generate JAX code, I think it may be easier in the end to just write JAX directly. That way, nothing gets between you and JAX and you don’t have to figure out how to filter JAX through middleware. When you do that, the models can be organized very much like in Stan.
Brian Ward and I took some time to work through what a simple linear regression would look like coded this way in JAX. I went over it a couple weeks ago with Andrew and he didn’t think it was too bad. Here’s the example.
GitHub Gist: linear regression in JAX.
In this example, we first do the constraining parameter transforms and extract the Jacobian, then define the model directly. Although we didn’t need it for this simple example, the Oryx library in JAX provides an extensive library of constraining transforms with Jacobians. It’s using the really cool PyTree features of JAX to move between structured log densities and array-based serialized log densities. This is sooo cool and the fact that it can all be compiled away is even cooler.
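The gist itself isn’t reproduced here. As a minimal sketch in the same spirit, assuming a hypothetical regression y ~ normal(alpha + x * beta, sigma) with made-up data and priors, the code below uses `jax.flatten_util.ravel_pytree` (a core JAX utility, swapped in for Oryx) to move between a named dict of parameters and the flat array a sampler works with.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.stats import norm

# Made-up data: 4 observations, 2 predictors.
x = jnp.array([[0.0, 1.0], [1.0, 0.5], [2.0, -1.0], [3.0, 0.0]])
y = jnp.array([0.9, 2.1, 2.8, 4.2])

# Template of unconstrained parameters; its structure defines the flat layout.
template = {"alpha": jnp.array(0.0), "beta": jnp.zeros(2), "log_sigma": jnp.array(0.0)}
flat_init, unravel = ravel_pytree(template)


def log_density_structured(params):
    sigma = jnp.exp(params["log_sigma"])    # constrain sigma > 0
    log_jacobian = params["log_sigma"]      # log |d sigma / d log_sigma|
    lp = norm.logpdf(params["alpha"], 0.0, 5.0)
    lp += jnp.sum(norm.logpdf(params["beta"], 0.0, 5.0))
    lp += norm.logpdf(sigma, 0.0, 5.0)      # half-normal prior, up to a constant
    mu = params["alpha"] + x @ params["beta"]
    lp += jnp.sum(norm.logpdf(y, mu, sigma))
    return lp + log_jacobian


@jax.jit
def log_density_flat(theta):
    # Samplers see a flat vector; unravel restores the named structure.
    return log_density_structured(unravel(theta))


value, grad_flat = jax.value_and_grad(log_density_flat)(flat_init)
```

Because `unravel` is just slicing and reshaping, it is traced along with the model and compiles away under `jax.jit`.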
In JAX, there’s no distribution statement syntactic sugar, but then even Andrew thinks those were a mistake in Stan. I still like them, though I admit they’ve caused a lot of confusion about how Stan works. It’s odd to find myself on the more permissive side of a language design discussion for once.
Generated quantities of the form used in Stan are trivial to code directly in JAX with vmap. Removing all these special constructs is super helpful for learnability, as is having the language embedded in Python (as much as Python is terrible for this kind of thing, much like R, because of its lack of static typing, its global interpreter lock, and its R-like scoping, I believe it’s well on its way to becoming the lingua franca of numerical analysis).
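Here is a sketch of a Stan-style generated quantities block done with `jax.vmap`, assuming a hypothetical simple regression and placeholder draws (constant values, not real sampler output). The per-draw function is written once and mapped over all draws, each with its own random key.

```python
import jax
import jax.numpy as jnp


def generated_quantities(key, draw, x_new):
    # Posterior predictive for one draw: y_new ~ normal(alpha + beta * x_new, sigma).
    mu = draw["alpha"] + draw["beta"] * x_new
    return mu + draw["sigma"] * jax.random.normal(key, x_new.shape)


# Placeholder posterior draws; in practice these come from the sampler.
num_draws = 1000
draws = {
    "alpha": jnp.full(num_draws, 1.0),
    "beta": jnp.full(num_draws, 2.0),
    "sigma": jnp.full(num_draws, 0.5),
}
x_new = jnp.array([0.0, 1.0, 2.0])
keys = jax.random.split(jax.random.PRNGKey(1), num_draws)

# Map over keys and draws (axis 0); x_new is shared by every draw.
y_new = jax.vmap(generated_quantities, in_axes=(0, 0, None))(keys, draws, x_new)
```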
Generating JAX from Stan?
People have asked if we were going to work on generating JAX code from Stan programs. I doubt it, given how easy it is to just define models directly in JAX and given how few dedicated developers we now have. The whole point of Stan was to provide a structured way to do derivatives for statistical models. We can just do that directly in JAX, as the above gist shows.
Giving up working on Stan?
No, we’re not giving up on Stan. People still use BUGS! Stan’s going to keep being used for a long time if history is any indication. We have lots of strategies for making it faster, adding samplers that will work well on CPU but not GPU, etc. I don’t plan to be involved in coding for Stan any more, though; it’s just too complicated for me. My plan is to write standalone samplers like WALNUTS, following Adrian Seyboldt’s lead with Nutpie. If you’re OK with Python but haven’t tried Nutpie, I’d highly recommend it. It’s twice as fast as Stan and more robust due to its adaptation. I’m rolling that adaptation into the new WALNUTS code, and maybe we’ll find the cycles to roll it into Stan itself after more testing.
Do you have recent or even inside information about the status of the JAX Metal project, which aims to bring JAX to Mac GPUs? I tried it about a year ago and it was still buggy and lacked fp64 support. Last week I took a glance (no reinstall, just the web page) and it didn’t seem like it had gone much further.
Elsewhere I’ve been using JAX on an old Nvidia GPU and have been happy.
Nope, I had no idea there even was a JAX project to accelerate for ARM (recent Mac CPU/GPU) architectures. We have good NVIDIA GPU clusters on which to run this kind of thing, so I haven’t been worried about server-level Macs. Having said that, I just got a new Mac Studio desktop to replace my 6-year-old iMac Pro. The new ARM processors are great for multi-CPU jobs like running multiple Markov chains in Stan. They’re three or four times faster than similarly priced Intel machines for fitting Stan models and way faster than my old iMac.
I’ve just been assuming costs of GPU cycles will keep going down by a factor of 2 every eighteen months, as compute has been going ever since the 1980s. That should make them very affordable in ten years. I don’t actually know what the GPU compute horizon looks like and would love to hear more from someone who understands hardware evolution better than me.
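Taking the assumed halving rate at face value, ten years is 120 months, or about 6.7 halvings, which works out to roughly a hundredfold drop in cost:

```latex
\frac{\text{cost}(t + 120\ \text{months})}{\text{cost}(t)} = 2^{-120/18} \approx 2^{-6.67} \approx \frac{1}{100}
```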
I thought it was widely accepted in hardware circles that Moore’s law scaling of transistor counts ended a few years back (somewhere between 2015 and 2020) and that counts have kept growing since, but subexponentially.
You can find all kinds of articles about it. Some would argue it hasn’t quite ended yet, but I don’t think many people in hardware think it’s going to power its way through 2035. That’s even more true if we’re talking about worldwide political destabilization.
Costs per unit of compute can keep going down even if hardware performance improves only subexponentially.
At least in the consumer peecee world, GPU improvements have been uninspired of late. I’m using a 3080, and the 5080, two generations later, is only 50% faster (for Go). (These are the second-best GPUs in the corresponding generation/line.) While the 5090 (the high-end model) is about twice as fast as the 3080, it’s insanely expensive.
Ouch. And this is what the world looks like running on a now-ancient peecee purchased in January 2022. Ouch.
I was thinking about getting a new machine January 2026, a full four years later, but I doubt things will improve much by then.
But I am enthused about the 16th gen Intel chips: the i9 (high end) will have a ridiculous number of processors, and it turns out that two efficiency cores are better than one performance core running two threads with hyperthreading, so in terms of generic CPU processing, things are looking up*. But consumer peecees with an i9 won’t be generally available until late 2026 or so, I’d guess. Shut up and write better code, is the story for the nonce…
*: Nowadays CPU performance is largely limited by how much cooling you can apply to your processor. So it’s unlikely that I’ll be able to run all of those cores full tilt for extended periods. Sigh.
For most of what I do, Stan is so easy to write (and to read) that I don’t think I’ll switch away from it just for a 2x speedup in computation. Maybe I would do it for 4x if I’m doing something that is really slow to run. But… can’t even a current LLM do a good job converting Stan code to whatever other language I might want? If that’s not already the case then I expect it will be within months, not years. Maybe I can treat an LLM as a sort of compiler: I write in Stan, then that gets translated to JAX for me (or whatever).
Bottom line is that I expect to write Stan for the next ten years or the rest of my career, whichever comes first.
For that matter, we may not be too far from the day I can write some equations on a note pad, show it to my computer’s camera, and say (in natural language) how I want the model implemented, much as one might do today with a grad student, and after half an hour of back and forth I’ll have an implementation. I might not even have to know (or care) what language is being run under the hood. At the moment I wouldn’t trust the results without a lot of checking, but on the other hand I also wouldn’t trust a grad student’s results without a lot of checking either.
Hi, Phil. It’s not just scaling up, it’s scaling out. So it’s partly just for dealing with more data in parallel. As such, we’re looking at much bigger speedups in computation than 2x; otherwise, I agree this wouldn’t be worth the effort. Check out Pavel, Colin, and Matt’s paper (cited above).
I’ve been amazed at how good GPT5 is at math, stats, Stan, C++, and Python. My officemate, Jeremy Magland, just built a RAG system for Stan (on top of CLINE, I believe) that ingests our user’s guide and a compressed form of the reference manual, and it’s quite good at Stan. He’ll roll it out to the public as soon as we can figure out how to get the foundation to pay for it; it costs a few cents per conversation about Stan with all the context, even using GPT4.1, which is good in this controlled environment. I was frustrated coding JAX with the 4-series models: they just couldn’t sort out the distinction between what needed to be bound vs. passed in as an argument. I haven’t tried with GPT5. I’ve also heard Claude’s Sonnet is good for this. I find that these systems are great if you can understand the output.
If you can’t understand the code or math an LLM produces, it can be a very good tutor. I spent yesterday and this morning being educated by GPT on lock-free ring buffers implemented in C++20. Steve recommended them as the right way to handle convergence monitoring in parallel, but they are very challenging for someone who hasn’t ever tried to write high-performance asynchronous concurrent code. If you look at the documentation, it can feel like they’re pranking you with a logic puzzle. For example, Steve sent me a ring buffer tutorial by Eric Rigtorp, which sent me to the C++ doc for std::atomic, which in turn sent me to std::memory_order. At that point, I needed both Steve’s and GPT’s help to understand all the constructs.
P.S. I’m in the same boat as a sexagenarian: I don’t feel much different than I did 30 years ago, but I don’t know how much I’ll be working 10 years from now! For now, I still love learning math, learning code, and building open source tools and am still amazed someone’s willing to pay me to do it every day.
GPT 5 is the first one from OpenAI that I have started to have some confidence in. I still find it making mistakes, but significantly fewer than before.
Scaling out matters more for big data or bigger models. For a lot of problems one CPU is sufficient.
> lock-free ring buffers implemented in C++20—Steve recommended them as
That made me laugh. I think I’ll coast for another decade on the C++ stuff I had to learn to understand what Steve and Tadej were doing.
I was coasting until I had to think about parallelism! The ring buffer stuff is super cool—it’s lock free, which takes some standing on your head to understand. So if you change your mind, there’s a very short and to the point tutorial on ring buffers by Eric Rigtorp.
I stopped coding for the core of Stan because of the complexity added by Steve and Tadej. I tried to unify the approach of simplex and log_simplex, and just couldn’t get anything to compile. The type traits templates and their lack of doc and the lack of good error messages from C++ were my undoing. Now I’m just building standalone things off to the side of Stan because it’s too hard for me to code Stan itself. It’s too bad given that I wrote nearly half of the first version including the autodiff and parsing and now I can’t understand either of those in the current version (Daniel Lee also wrote nearly half of it and Matt wrote NUTS).
One of the examples you gave a few posts back, of how an LLM explained the math behind how to efficiently parameterize a model with a funnel, was suspiciously like an explanation and code in McElreath’s book, which I believe is copyrighted. So maybe the explanation was good because it was illegally used to train the LLM. I wonder how many of the other good explanations are similarly obtained.
I think you’ll find all of the explanations of how to efficiently parameterize a hierarchical model to be similar to one another. Our early descriptions used the term “Matt trick” and our later ones “non-centered parameterization” and they closely followed the papers on the subject. There are only so many different ways to talk about standardizing.
It turns out people are really good at plagiarizing without AI. When Mitzi joined the Stan project, the first project she worked on was spatial smoothing for epidemiology. She went through several decades of people’s papers on the ICAR model. Many, if not most, of them cited Besag and then copied his math and often text almost line for line. Of course, I’m not saying it’s OK because people do it without AI.
Legally, at least in the U.S., it is still not clear what is fair use for training and what is illegal. I imagine it’s going to take years to sort this out in court. Like VCRs being used to illegally copy movies, the liability may go to the user of the tool rather than the manufacturer of the tool. In other words, it may turn out not to be illegal to train your AI with a bunch of books and articles.
In the U.S., copyright is automatically assigned to the author of text or code. The author may then assign the copyright to an employer. For example, most U.S. universities require their professors to assign software copyright to the university, but do not require this for books or papers. If you own the copyright, you can release your material under a license for others to use.
Bob:
Regarding the first line of your above post: It’s a Mad, Mad, Mad, Mad World is a film with a memorable title that was aired from time to time when we were kids, but I’d hardly call it a “classic.” Unwatchable is more like it!
Agree, it’s a terrible movie.
I also agree it’s a terrible movie. Nevertheless, I’m going to stand by my assertion that it’s a classic. Nobody said classics had to be good. The largely favorable reviews summarized in the Wikipedia link are surprising.
I suspect that few people need the extra performance that JAX can give. Most applied scientists I’ve met avoid coding and are happy using interfaces like, say, brms or lme4 in R.
Faster code is nice, but my workflow depends more on being able to diagnose and visualize the results of my models. Stan has many available diagnostics (numerical and graphical), and makes it easy to manipulate and visualize its outputs. This is especially simple in R using packages like posterior, bayesplot, tidybayes, and others. Are these advantages available in JAX, or will they be available soon?
Maybe Stan and JAX will split the pie, with Stan being the default for comfort and JAX being used for speed and parallelism.
I don’t see the future of Stan as tied to a specific implementation of autodiff or HMC or whatever. I see Stan as a strongly-typed PPL that can be transpiled, with an active community and strong documentation and testing.
So if it’s a JAX world, let’s have stan2jax :)
JAX isn’t expressive enough for a full stan2jax, but we could map a limited subset of Stan to JAX. We’d have to give up any conditionals on parameters (including iterative algorithms), recursive functions, and probably some things I’m not considering yet. We’d probably want to give up on NUTS for performance on GPU. I think we could map all of Stan to PyTorch—they also use dynamic autodiff.
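The iterative-algorithm restriction shows up concretely in JAX’s loop primitives. A minimal sketch, using Newton’s method for a square root as a hypothetical example: a `jax.lax.while_loop` whose trip count depends on the parameter runs fine under `jit` (and supports forward-mode derivatives), but reverse-mode `jax.grad` through it raises an error, while a `jax.lax.fori_loop` with a fixed, static trip count can be differentiated in reverse mode.

```python
import jax
import jax.numpy as jnp


def sqrt_while(a, tol=1e-6):
    # Newton's method for sqrt(a); the number of iterations depends on a.
    def keep_going(state):
        x, x_prev = state
        return jnp.abs(x - x_prev) > tol

    def newton_step(state):
        x, _ = state
        return 0.5 * (x + a / x), x

    x, _ = jax.lax.while_loop(keep_going, newton_step, (a, a + 1.0))
    return x


def sqrt_fori(a, num_iters=20):
    # Same update with a fixed, static trip count, so reverse-mode AD works.
    def newton_step(_, x):
        return 0.5 * (x + a / x)

    return jax.lax.fori_loop(0, num_iters, newton_step, a)


a = jnp.float32(2.0)
print(jax.jit(sqrt_while)(a))  # forward evaluation under jit works
print(jax.grad(sqrt_fori)(a))  # close to 1 / (2 * sqrt(2)), about 0.3536
try:
    jax.grad(sqrt_while)(a)
except ValueError:
    print("reverse-mode AD through a dynamic while_loop is not supported")
```

The fixed-iteration workaround trades adaptivity for differentiability: you have to pick an iteration count large enough for every parameter value the sampler might visit.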
The bottleneck is developer time and project management/design time. That could perhaps be solved by applying for funding, but that’s hard for software (it’s largely why I moved to Flatiron; it was exhausting raising funds), and I don’t think any of our senior people are applying for funding that would support something like a stan2jax transpiler.
If someone wanted to do this, I’d be happy to help, and I’m sure Brian Ward would be happy to help on the language side.
As Andrew (the blog host) says, everyone wants extra performance because of workflow—we rarely know which model we want to fit. He’s always asking us to somehow make Stan go faster. I don’t know that JAX is going to help with making a 2 minute run take 20 seconds, but it might help a 10 minute run take 2 minutes.
Diagnosing and visualizing results is just a matter of grabbing the draws and importing them. You can either stay in Python to do that or save the results and import them into R. Working across languages is always rough, though, starting with indexing from 1 in R and from 0 in Python (and every other programming language I know other than R and MATLAB).
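A small sketch of the off-by-one hazard, assuming group indices exported from R or Stan (which count from 1) and a hypothetical per-group intercept vector in JAX. JAX doesn’t raise an error on out-of-bounds indices, so forgetting the shift can fail silently.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical group indices as exported from R or Stan (1-based).
group_from_r = np.array([1, 1, 2, 3, 3, 3])

# Shift once, at the language boundary, to 0-based indices for JAX.
group = jnp.asarray(group_from_r - 1)

# Hypothetical per-group intercepts; gather one per observation.
alpha = jnp.array([0.1, -0.2, 0.3])
mu = alpha[group]
```

Going the other way, column names like `alpha.1`, `alpha.2` in draws exported for R should use 1-based labels again.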
The ArviZ toolkit in Python replicates much of what is in posterior and loo (I don’t even know if it supplies plotting). Aki’s been involved in making sure all the definitions are right. plotnine is a call-for-call duplicate of ggplot2. pandas provides data frames and some of the reshaping. I write all my own plotting code (GPT writes most of it now), so I have never used bayesplot, tidybayes, etc.
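For intuition about what these diagnostics compute, here is the classic split R-hat from Bayesian Data Analysis (3rd edition) in a few lines of JAX. It’s a sketch only: ArviZ and posterior implement the newer rank-normalized version, which should be preferred in practice. The function name `split_rhat` is hypothetical.

```python
import jax
import jax.numpy as jnp


def split_rhat(draws):
    # draws: (num_chains, num_draws) for one scalar quantity.
    num_chains, num_draws = draws.shape
    half = num_draws // 2
    # Split each chain into two halves (dropping the middle draw if odd).
    split = jnp.concatenate([draws[:, :half], draws[:, num_draws - half:]], axis=0)
    n = half
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)          # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate
    return jnp.sqrt(var_plus / within)


key = jax.random.PRNGKey(0)
good = jax.random.normal(key, (4, 1000))          # well-mixed chains
bad = good + 2.0 * jnp.arange(4.0)[:, None]       # chains stuck at different means
print(split_rhat(good), split_rhat(bad))
```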
The comfort issue is super important. Most of the statisticians who work in R would rather not move to Python. That makes sense—there’s a steep learning curve. I liked John Cook’s post on what he calls “bicycle skills.” Writing models in JAX means learning some Python. This is one of its great affordances—it just puns numpy and scipy (meaning it uses the same API but with an autodiffable and jittable implementation). As it turns out, the folks who work in Python would rather not move to R or Stan, which is why NumPyro and PyMC are so popular. That means PyMC and NumPyro are more comfortable for all those folks getting their degrees in computer science, data science, engineering, and the physical sciences, all of which are very Python-heavy these days (zero of our 20 postdocs in math and ML and stats use R and even the biologists here have switched to Python). Ditto the Julia folks and Turing.jl, though there are many more Julia users (and even Julia developers) here at Flatiron than R users.
The comfort and speed analogy reminds me of the Howlin’ Wolf song!
I like the metaphor (analogy?) of “bicycle skills”. Do you think that Numpyro and JAX are bicycle skills for Stan coders? Or are they still closer to “just-in-time skills” to learn if the model is too slow (or incompatible) with Stan?
Also, that Howlin’ Wolf song is a banger and got a good chuckle out of me.
Yes, I think all these programming languages are like bicycle skills. Although I think that is a bit dismissive—some of these skills are worth leveling up to near Tour de France levels, depending on your day job.
I finally got to use Stan again at work a couple weeks back (been years!).
Someone was working with a regression model and had some questions. I think a lot of what they were talking about was sort of tied up more in technical contingencies of how the model was originally developed, rather than things that’ll push them forward. I’m hoping Stan gives them some new perspective on the problem.
jax can be really fun to play with, just as a programming exercise. I think I’d choose jax for something if the problem definition is really stable and I can get all my python lined up to support the iteration. Other than that I get lazy trying to wire up whatever optimizer, figure out how to separate data/params, how to monitor progress, blah blah — even with AI!
I agree that JAX isn’t as seamless for MCMC as Stan. The seams are also different. For Stan, it’s two languages (R and Stan, Python and Stan, etc.), whereas for Python, it’s a heterogeneous set of tools. JAX uses the numpy and scipy libraries and base Python data structures, which is a huge affordance for Python users over having to learn a new language like Stan. That’s what most of the PyMC and NumPyro users have told me.
What we really need is a good MCMC environment in Python. That is something I’d like to work on over the next year as we build normalizing flow and diffusion approaches to VI. I’m not even convinced MCMC is going to be state of the art for sampling when GPU compute gets cheaper—it’s not state of the art as far as I can tell when compute is no object—we (Justin Domke and Abhinav Agrawal, really—I just pitched problems) could fit normalizing flows to things we could not sample with HMC, but it took a bajillion GPU cycles.
The genius of Stan is in the combination of coding and estimation. I haven’t done a lot of work with JAX, but being able to translate a mathematical model into valid Bayesian inference is not easy. So it seems like JAX/Stan translation would be great to have if JAX is simply going to be more performant.
At the same time, GPU acceleration is of limited importance in some models because (at least from what I’ve seen) GPU acceleration is about matrix algebra–and some models are difficult (though not impossible, of course) to write in matrix form. The feature of Stan that has been life-changing for my work is within-chain parallelization — if JAX implements something similar, then it would be worth taking a look, but if not, the speedups aren’t going to matter much.
Within-chain parallelization is much much easier with JAX. The whole language is largely compositional around parallelization. Even though GPU acceleration is most beneficial when there are massive SIMD operations like matrix multiply, it’s also amazing just for evaluating log densities in parallel if you can keep everything on kernel as JAX can. Then you can mix the largely heterogeneous log density evaluations and the largely neural network-based normalizing flow or diffusion parameterizations. With an H100 GPU, we could do 50K log density evaluations to calculate the ELBO in variational inference in parallel. This kind of parallel back end compute behind log density evals is also what you see in Matt Hoffman, Pavel Sountsov, and Colin Carroll’s tutorial linked above. It’s also another reason people like NumPyro—it lets you build model components that are neural networks.
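A minimal sketch of the ELBO computation described here, assuming a mean-field Gaussian variational family and a toy normal target in place of a real model. `jax.vmap` evaluates the log density at all 50,000 reparameterized draws in one batched call, and `jax.value_and_grad` gives the stochastic gradient for optimization. A normalizing flow or diffusion model would replace the simple Gaussian.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_density(theta):
    # Toy target (an assumption): independent normals with mean 1 and scale 2.
    return jnp.sum(norm.logpdf(theta, 1.0, 2.0))


def elbo(var_params, key, num_draws=50_000):
    m, log_s = var_params
    s = jnp.exp(log_s)
    eps = jax.random.normal(key, (num_draws, m.shape[0]))
    thetas = m + s * eps                                 # reparameterized draws from q
    log_p = jax.vmap(log_density)(thetas)                # all evaluations in parallel
    log_q = jnp.sum(norm.logpdf(thetas, m, s), axis=1)   # variational log density
    return jnp.mean(log_p - log_q)


var_params = (jnp.zeros(3), jnp.zeros(3))
key = jax.random.PRNGKey(0)
value, grads = jax.jit(jax.value_and_grad(elbo))(var_params, key)
```

Since this target is normalized, the ELBO equals minus the KL divergence from q to the target, so it is at most zero and reaches zero when q matches the target exactly.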
Bob, great post and thanks for all the detail. You nailed where JAX beats Stan. The schism (and I really do feel like this is a reformation of some kind) comes down to what Stan is good at that JAX and others still lack. An even better question, however, is what we would want in a next-generation PPL.
I think you laid out a nice starting list: parallelism, flexibility, a robust, fast, and maintained AD backend, and an ecosystem of math functions, all of which JAX checks off. Additional benefits of being in a feature-rich programming environment, where Stan is limiting, include access to dictionary types, ragged arrays, and in-place object modification for memory savings.
Besides the poor documentation, I still see a lack of stylistic clarity for the model and the mathematics. Stan is particularly good for seeing the model laid out in the nude. I’d like to see more principled ways to denote maps/charts between measures. Even better would be a sampler that samples directly from the measure for the constrained variables; right now I don’t think there’s a great sampler that can mix and match parameters on different measures, but maybe we’ll get something. I think Stan did the right thing by declaring the constrained variables upfront. Discrete parameters, and multivariate discrete parameters in particular, are a pain point. I’ve been custom writing tree-based sampling for highly dependent multivariate discrete parameters intertwined with Stan-sampled parameters. The model needs to learn the best approximation of the discrete multivariate distribution by decomposing it into smaller conditional marginal distributions that I do have data on. These come out to be graph-type problems that being in JAX can really help with (though I’m doing it in Julia).
I do hope that many of these get implemented in a future PPL. In the meantime, I think there will be many quality-of-life improvements to Stan over the next few years. It’s possible we can get much better handling of complex index scenarios through dictionary-like types. Constraint handling is getting better with the recent inclusion of the constraining/unconstraining functions. It seems not too far off to get the Nutpie-style adaptation with the 2x speedup in. There’s a PR in the works for Gaussian copula support. We could add a discrete type using “data augmentation”; that would be cool!
We can get even more than 2x speedup using Nutpie-like techniques. Adrian Seyboldt has been using very conservative settings. If you only need an ESS of 100 or so, then I’m measuring a version of it that’s more like 8x faster. Then Nikolas Siccha has a really nice step size adaptation using Bayesian optimization to maximize expected square jump distance per cost rather than having the user fix a step size. All of this is really independent of Stan, but we can bring in the benefits. Following Adrian’s implementation of Nutpie, we plan to roll out sampler implementations that can work over Stan, NumPy, NumPyro, JAX, etc. models. This is a little different than coding the samplers in JAX—that’s a challenge for our current NUTS-based samplers.
I think we can achieve enough stylistic clarity on the models, at least to the extent that there’s any consensus on how to write a Stan model. I think the advantage of tools embedded in Python is that they can be more modular than Stan. I know you’ve been working on a really nifty function library, which is great, but it’s impossible to write a standalone hierarchical prior in Stan, for example. That’s why Maria Gorinova developed SlicStan. If I had an infinite developer budget, we’d be moving to a blockless Stan that would clean up a lot of these issues. As is, I think it’ll be easier to write the generative process in something like JAX in a natural order, binding data variables.
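For contrast with Stan’s blocks, here is a minimal sketch of a standalone hierarchical prior in JAX. The component owns its hyperparameters, hyperpriors, and Jacobian, and returns both the group effects and its log density contribution, so any model can call it like a function. The model, priors, data, and the names `hierarchical_normal` and `model_log_density` are all hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def hierarchical_normal(params):
    # Reusable non-centered prior: effects ~ normal(mu, tau), with its own hyperpriors.
    tau = jnp.exp(params["log_tau"])
    effects = params["mu"] + tau * params["z"]            # non-centered parameterization
    lp = norm.logpdf(params["mu"], 0.0, 5.0)
    lp += norm.logpdf(tau, 0.0, 1.0) + params["log_tau"]  # half-normal prior + Jacobian
    lp += jnp.sum(norm.logpdf(params["z"]))               # z ~ normal(0, 1)
    return effects, lp


def model_log_density(params, group, y):
    # The component brings its hyperparameters and hyperpriors along with it.
    alpha, lp = hierarchical_normal(params["intercepts"])
    sigma = jnp.exp(params["log_sigma"])
    lp += norm.logpdf(sigma, 0.0, 1.0) + params["log_sigma"]
    lp += jnp.sum(norm.logpdf(y, alpha[group], sigma))
    return lp


params = {
    "intercepts": {"mu": 0.0, "log_tau": 0.0, "z": jnp.zeros(3)},
    "log_sigma": 0.0,
}
group = jnp.array([0, 0, 1, 2, 2])              # made-up 0-based group indices
y = jnp.array([0.5, 0.7, -0.1, 1.2, 0.9])       # made-up data
grads = jax.grad(model_log_density)(params, group, y)
```

Adding varying slopes would just mean calling `hierarchical_normal` a second time on its own nested dict of parameters; the gradient comes back in the same nested structure.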
PyTorch could be mentioned in this context. For example, the probabilistic programming framework Pyro is built on top of PyTorch. PyTorch is now a Linux Foundation project, so it is fully open source.
I mostly hear about NumPyro these days, not Pyro. I don’t know why they switched focus. I’m not even sure who’s managing the project now that Eli Bingham moved to industry. As far as I can tell, it’s picking up users quickly.
I am late to this thread, but the code examples from my book have been available as NumPyro/JAX for a while now. I’ve thought about doing the next edition of Statistical Rethinking with both R/Stan and NumPyro examples, and maybe I will. But at the same time I feel like the most important modeling issues – logically connecting statistical algorithms to scientific models – is essentially unsupported in both environments. I’m not sure what tools to build to fill that gap, nor which environment has users eager to see them. Lots to do still.
It’s challenging to write a book using more than one language (R and Python), because readers are often starting from the tool they know and use. Do you have any idea what the uptake of your book is among Python users? I know the models have also been translated to PyMC, which can also generate JAX code (though I mostly see people using NumPyro for that). I tend to see the more programmer-oriented introductions to Bayesian statistics in Python recommended on the PyMC forums.
There are some nice website technologies that Brian Ward has used to let you see examples in multiple languages under tabs; presenting the same examples in several languages is a problem we face with BridgeStan. I think being able to see the same thing done in different languages is super useful, but also a huge pain.
Did you have any ideas on how to connect statistics to the scientific models? I don’t have any ideas on how to support that better in a PPL, but it’s the key thing we have to teach all the scientists to do. I just worked on a project with a group of biologists who were really really surprised that they could code their forward model and I could just solve the inverse problem for them with Stan.
For me as a user, one of the main things that’s missing in Stan is modularity. I’d really like to be able to define a time-series prior with all its components, including hyperprior variable declarations and hyperpriors. That’s a little easier to do in something like PyMC, but I still find it a bit clunky with the variable naming and contexts, though that may just be my relative naivete in Python. I don’t see many uses of PyMC or NumPyro that take advantage of their flexibility this way.
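To make the modularity wish concrete, here is a minimal sketch in plain Python (not PyMC or NumPyro syntax; `random_walk_prior`, the name-prefix convention, and the model are all hypothetical) of a time-series prior packaged as one reusable unit that owns its hyperparameter, its hyperprior, and the process prior itself, and returns a log density contribution that a larger model can add to its own.

```python
import math

def normal_lpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def random_walk_prior(params, prefix):
    """Hypothetical self-contained time-series prior module.

    Reads its own hyperparameter and latent states from `params` using a
    name prefix, and returns hyperprior + random-walk log density.
    """
    log_tau = params[prefix + "_log_tau"]   # hyperparameter owned by this module
    z = params[prefix + "_z"]               # latent states owned by this module
    tau = math.exp(log_tau)
    lp = normal_lpdf(log_tau, 0.0, 1.0)     # hyperprior on the log scale
    lp += normal_lpdf(z[0], 0.0, 1.0)       # initial state
    for t in range(1, len(z)):
        lp += normal_lpdf(z[t], z[t - 1], tau)  # random-walk increments
    return lp

def model_log_density(params, y, obs_sigma=1.0):
    """Uses the same module twice under different prefixes, plus a likelihood."""
    lp = random_walk_prior(params, "long_run")
    lp += random_walk_prior(params, "short_run")
    for t, y_t in enumerate(y):
        mean = params["long_run_z"][t] + params["short_run_z"][t]
        lp += normal_lpdf(y_t, mean, obs_sigma)
    return lp
```

Because the module is just a function keyed by a name prefix, it can be instantiated more than once in the same model, which is the kind of reuse Stan's block structure makes awkward.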
Hi Bob — thanks for the great post and discussion.
I’m curious what you think about the Julia + Turing.jl ecosystem as an alternative to JAX. Turing feels very flexible — you can integrate all of pure Julia and its many scientific libraries directly in your models. For example, you could approximate a surface with radial basis functions and estimate its parameters within Turing.
I also find the language itself pleasant: models are readable, and you don’t have to juggle Python environments. Julia’s built-in reproducibility and project-based package management make it easy to keep analyses consistent — something I always found a bit messy with Python or Anaconda.
What I’m less sure about is the AD side. Julia has several backends, and some (like Mooncake.jl or Enzyme.jl) are said to be quite fast, but I haven’t seen systematic comparisons to JAX. Do you have any thoughts on how Julia’s AD and Turing’s approach stack up against what you like about JAX?
I don’t know the state of autodiff in Julia now. The Enzyme approach of autodiffing at the LLVM level (a low-level pre-assembly intermediate representation) is really compelling, but that’s going to need to be mixed with expression autodiff somehow so as not to blow out memory autodiffing through fiddly algorithms like Cholesky decomposition (you don’t want an O(N^4) Jacobian built explicitly there or implicitly through links in reverse mode) or differential equation solvers (you don’t want to have to autodiff the algorithm as opposed to using the implicit function theorem).
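As a small illustration of the implicit-function-theorem point, here is a sketch in plain Python that uses a scalar root-finding problem in place of an ODE solver (an assumption for brevity). If x(θ) solves f(x, θ) = x³ + x − θ = 0, then dx/dθ = −(∂f/∂θ)/(∂f/∂x) = 1/(3x² + 1), so the derivative needs only the converged solution, not a record of every Newton iteration.

```python
def solve(theta, x0=0.0, tol=1e-12, max_iter=100):
    """Newton's method for x**3 + x - theta = 0 (increasing in x, so one real root)."""
    x = x0
    for _ in range(max_iter):
        fx = x**3 + x - theta
        step = fx / (3 * x**2 + 1)   # f(x) / f'(x)
        x -= step
        if abs(step) < tol:
            break
    return x

def dsolve_dtheta(theta):
    """Derivative of the root via the implicit function theorem.

    f(x, theta) = 0 implies dx/dtheta = -(df/dtheta) / (df/dx) = 1 / (3x^2 + 1).
    No derivative information is propagated through the Newton iterations.
    """
    x = solve(theta)
    return 1.0 / (3 * x**2 + 1)
```

For θ = 2 the root is x = 1, so the implicit derivative is 1/4, which matches a finite-difference check without ever taping the solver's loop.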
Julia itself is orders of magnitude faster than Python or R, but all the heavy lifting in something like JAX is being done on the back end outside of Python, so this doesn’t come into play so much. Same for Stan: R and Python just call Stan’s C++. So I don’t see that using Julia is a route to making things go faster. Has anyone built a competitor to PyTorch or JAX directly in Julia? In some ways JAX is similar to Julia in that you just-in-time compile; it’s just that the target is XLA rather than LLVM.
Turing.jl is written entirely in Julia, as is the AD library Mooncake.jl.
Even if Turing.jl itself is coded in Julia the way that JAX is coded in Python, the back end is going to call non-Julia code in many cases. If I use Enzyme for autodiff, that’s coded at the LLVM level. I believe Zygote.jl was mostly Julia; I haven’t heard of Mooncake.jl. If you use the standard matrix libraries in Julia, those call out to BLAS and LAPACK codes in Fortran (cf. https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/). Some of the standard ODE solvers, like lsoda, also call out to Fortran, although I don’t imagine those are even available in Turing.jl because of the foreign function interface.
Bob,
Turing is a “Turing macro to Julia” compiler basically. Essentially you write normal Julia code within the @model macro, and have some additional statements like a ~ Distribution(b,c,d) which it converts to calls to logpdf to produce a function that lets you evaluate the logpdf of a model.
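Here is a rough Python analogue of that translation (not Turing.jl's actual implementation; the `LogDensityAccumulator` class, its `tilde` method, and the regression model are hypothetical). The model body is ordinary code, and each `a ~ Distribution(b, c)` statement becomes a call that adds the corresponding log density to a running total. The observation scale `sigma` is treated as known to keep it short.

```python
import math

def normal_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

class LogDensityAccumulator:
    """Collects log density contributions from each 'tilde' statement."""
    def __init__(self):
        self.total = 0.0

    def tilde(self, value, logpdf, *args):
        # Stands in for `value ~ Dist(args...)`: add logpdf(value | args).
        self.total += logpdf(value, *args)
        return value

def regression_model(acc, alpha, beta, xs, ys, sigma):
    """Ordinary code; the tilde calls mark the probabilistic statements."""
    acc.tilde(alpha, normal_logpdf, 0.0, 10.0)   # alpha ~ Normal(0, 10)
    acc.tilde(beta, normal_logpdf, 0.0, 10.0)    # beta ~ Normal(0, 10)
    for x, y in zip(xs, ys):
        acc.tilde(y, normal_logpdf, alpha + beta * x, sigma)  # y ~ Normal(alpha + beta*x, sigma)

def log_density(alpha, beta, xs, ys, sigma=1.0):
    """The function a sampler would call: run the model, return the summed log density."""
    acc = LogDensityAccumulator()
    regression_model(acc, alpha, beta, xs, ys, sigma)
    return acc.total
```

Since the body is plain code, loops, helper functions, and library calls can appear anywhere between the tilde statements, which is the source of the flexibility discussed here.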
Yes, some of the code you use in Julia calls out to Fortran libraries under the hood. Mooncake, as I understand it, is a source-to-source autodiff system designed to handle code that uses mutation, which Zygote couldn’t. As such, you can think of it as a successor to Zygote that supports a broader class of code. Because of this, it also handles code that calls out to Fortran and such. Enzyme also allows you to call out to Fortran code, I think. That said, raw Julia code has been clocked as very competitive with BLAS-type routines in many cases (for example, code auto-generated using Tullio.jl macros for linear algebra calculations).
Here are the Mooncake docs: https://chalk-lab.github.io/Mooncake.jl/dev/
For differential equation solving there’s the SciML ecosystem, which includes pretty much the top-of-the-line differential equation solvers in the world. If you need to solve ODEs today, regardless of what language you write in, the best way to get it done is to call out to Julia’s DifferentialEquations.jl through foreign function interfaces within your language (in this sense it plays the role of a modern Fortran). That’s how diffeqpy works in Python, and the diffeqr package in R.
You’re right, though, that just differentiating through the operations of the ODE solver is not necessarily a great plan. However, I do think it works for a wide variety of problems, and I think Mooncake can in theory do something else (such as use a custom gradient function), whereas Enzyme, since it only sees LLVM, probably can’t.
There’s been a lot of effort to make Julia work with GPUs but it’s ongoing.
If you need to work in industry, I don’t think there’s an easy way to get away from Python + Jax type stuff. If you work in one of the select industry groups that use Julia, or are in academia, or have other reasons why Julia is accessible to you, then I think it offers an extremely good mix of flexibility, a well-thought-out language, and functionality.
Where Julia shines is if you’re doing something non-standard and need your custom code to run very fast within the broader ecosystem. My impression is that if you do all the “standard” stuff (linear and logistic regressions, neural nets, classifiers, whatever) that you can use Jax for, you’re doing fine, but as soon as you need Jax to do something it can’t automatically do, you come to a full halt. Is that fair? I think that’s why you said you probably couldn’t do a general-purpose Stan-to-JAX compiler.
So, Jax is gonna be enough for anyone doing maybe linear and some nonlinear regressions and fitting neural networks and things, but say wouldn’t be good for someone who does something like a time series analysis where each time step comes from the result of a custom simulation of some “physics” (micro-scale interactions of things according to rules, physical, economic, or whatever) or other similar stuff (say, flows of information along nodes in a graph that changes structure in time, or whatever).
My impression is the goal of Jax is to be a toolbox to do everyday engineering fast, while Julia is a tool for experts to push the envelope in what can be done. Julia will remain more niche because of that, but I don’t think it’ll go away.
Thanks, Daniel—that’s super helpful—it’s really hard to keep up with all the activity in Julia! You’re not the only huge fan of Julia’s diff eq solvers. We started in Stan by just autodiffing through a Runge-Kutta solver—it works, but it takes up a huge amount of memory for the expression graph.
How does Mooncake autodiff through foreign function calls? I don’t see how that’s possible. I only called out lsoda because it’s the default ODE solver in R. I’m sure you could reimplement it in Julia.
I don’t work in industry and am free to use whatever tools I want. I’ve mainly been put off using Julia because of its niche status compared to Python.
The reason we can’t transpile all of Stan to JAX is that Stan allows branching based on parameter values, including general recursive functions. That means you can’t just tape the function once and re-use the way JAX does, but you have to dynamically generate the expression graph the way PyTorch does.
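A toy illustration of the problem, using a hand-rolled trace-and-replay in plain Python (the `Tracer`, `record`, and `replay` names are hypothetical, and this is not how JAX or Stan work internally): if a function branches on a parameter value and we record the operations executed for one input, replaying that recording at an input that would take the other branch silently computes the wrong function. JAX's `jit`, by contrast, raises an error when Python code branches on a traced value and expects structured control flow such as `jax.lax.cond`, while Stan and PyTorch rebuild the expression graph on every evaluation.

```python
class Tracer:
    """Records the arithmetic operations applied to a scalar input."""
    def __init__(self, value, ops=None):
        self.value = value
        self.ops = [] if ops is None else ops   # shared list of recorded ops

    def square(self):
        self.ops.append(lambda v: v * v)
        return Tracer(self.value * self.value, self.ops)

    def neg(self):
        self.ops.append(lambda v: -v)
        return Tracer(-self.value, self.ops)

def f(x):
    # Python-level branch on the *current value* of the input.
    if x.value > 0:
        return x.square()
    return x.neg()

def record(fn, x0):
    """Run fn once at x0 and return the list of operations it executed."""
    t = Tracer(x0)
    fn(t)
    return t.ops

def replay(ops, x):
    """Re-apply the recorded operations to a new input without re-running fn."""
    for op in ops:
        x = op(x)
    return x
```

Recording at x = 2 captures only the squaring branch, so replaying at x = 3 gives the correct 9, but replaying at x = −3 also gives 9, whereas f(−3) is 3.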
JAX excels at SIMD operations on matrices on GPU. But even if you don’t have large well-structured matrix operations, if you can keep everything on kernel (Stan cannot, by the way, but JAX can), then you can massively parallelize and get big speedups in all of your generic code. This is what we were doing with Justin Domke implementing normalizing flows for very heterogeneously structured models—we could run 50K iterations for each nested KL divergence (ELBO) gradient evaluation using Monte Carlo, which we couldn’t get close to on CPU.
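A simplified sketch of why Monte Carlo ELBO gradients parallelize so well, in plain Python and assuming a one-dimensional Gaussian variational family rather than the normalizing flows described above (`elbo_grad` and the target are hypothetical). With the reparameterization z = μ + e^ω ε, ε ~ Normal(0, 1), each draw's gradient contribution is computed independently, so on a GPU the loop would become one vectorized computation (for example via `jax.vmap`); here it is just a loop.

```python
import math
import random

def elbo_grad(mu, omega, grad_logp, n_draws, rng):
    """Monte Carlo reparameterization gradient of the ELBO for q = Normal(mu, exp(omega)).

    ELBO = E_eps[log p(mu + exp(omega) * eps)] + entropy(q), and entropy(q) is
    omega + 0.5 * log(2 * pi * e), so it adds exactly 1 to the omega gradient.
    """
    sigma = math.exp(omega)
    g_mu, g_omega = 0.0, 0.0
    for _ in range(n_draws):              # independent draws: the parallel part
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps              # reparameterized draw from q
        g = grad_logp(z)                  # d log p / dz at the draw
        g_mu += g                         # dz/dmu = 1
        g_omega += g * sigma * eps        # dz/domega = sigma * eps
    return g_mu / n_draws, g_omega / n_draws + 1.0
```

For a Normal(1, 2) target evaluated at μ = 0, ω = 0, the exact gradients are 0.25 for μ and 0.75 for ω, and the estimate approaches them as the number of draws grows; more draws per gradient is exactly what massive parallelism buys.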
Bob,
Not an expert on Mooncake at all, but my impression is that Mooncake works through construction of differentiation rules objects, and that these can be either hand-written or automatically written. So when Julia calls out to known Fortran routines, there are probably hand-written rules that generate the derivatives by calling the necessary code to calculate the derivative of that piece of Fortran. Then, these compose, so that code that calls code that calls code that calls Fortran can have its derivative computed automatically through the composition of these objects.
So basically I THINK Mooncake gets you derivatives of Fortran by special casing, and then you can get the derivative of anything Julia by rules that differentiate Julia, including all the known Fortran that is used by base-Julia.
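A minimal sketch of the composition idea described above, in plain Python (this is not Mooncake.jl's API; `RULES`, `foreign_log1p_exp`, and `value_and_derivative` are hypothetical). Each primitive carries a derivative rule, the opaque "foreign" routine gets a hand-written rule because the system cannot look inside it, and the chain rule composes the rules so that anything built from these primitives is differentiated automatically.

```python
import math

def foreign_log1p_exp(x):
    """Stands in for an opaque external routine (e.g. compiled Fortran)."""
    return math.log1p(math.exp(x))

# Each rule maps an input to (value, local derivative).
RULES = {
    "sin": lambda x: (math.sin(x), math.cos(x)),
    "square": lambda x: (x * x, 2 * x),
    # Hand-written rule for the opaque routine: d/dx log(1 + e^x) = 1 / (1 + e^-x).
    "foreign_log1p_exp": lambda x: (foreign_log1p_exp(x), 1.0 / (1.0 + math.exp(-x))),
}

def value_and_derivative(pipeline, x):
    """Apply primitives in order, multiplying local derivatives (chain rule)."""
    deriv = 1.0
    for name in pipeline:
        x, local = RULES[name](x)
        deriv *= local
    return x, deriv
```

Only the foreign routine needs special casing; a composite such as `["square", "foreign_log1p_exp", "sin"]` gets its derivative for free, which matches a finite-difference check.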
I don’t think you get differentiation of user-level calls to Fortran for free, though if at the user level you call the same routines that Base Julia does, you might get that for free. Most of that stuff has a Julia function wrapping it anyway, so I doubt there’s a reason to call the Fortran directly.
Turing also allows branching on parameters when you use ForwardDiff.jl, and I think also with ReverseDiff, but not when you use ReverseDiff and precompile the tape. My impression is that, unfortunately, if you do precompile the tape you just get wrong answers; there are big warnings about it, but it can’t really detect automatically when it’s “not allowed.” I guess Mooncake always supports branching, as does Enzyme (I think).
It is really hard to keep up with this stuff, in fact I’d never heard of Mooncake until recently.
When it comes to GPU, I think there’s a bunch of approaches and Julia gives you some access to all of it, but I think it’s a bit fragmented.
https://juliagpu.org/ has some info
I haven’t yet really needed to get into the weeds on GPU and I don’t have any super fast GPU hardware anyway, so I’m not the best person to comment on that.
“Mooncake” lol. This thread reminds me of nothing so much as this YouTube short from Michael Spicer: https://youtube.com/shorts/Pbw5SjxFS8s?si=7x0m3e9f4BUtckHb
I use Stan every day at work. I like the flexibility and the ease of reading and writing it. For my work, most of the time I don’t need any speed ups. What I need is flexibility and expressibility in something relatively easy to read and write. So I like Stan a lot. The JAX example of a linear regression isn’t super compelling to me.
I’m a lot less concerned about the speed and efficiency of algorithms (note, I do want them to fail loudly when they fail), and I don’t care if they run some typical regression, neural net, or whatnot any faster. I’m a lot more concerned about being able to write bespoke models that are closer (where “closer” is to some reasonable resolution) to the narratively generative process, so for me, Stan is a really straightforward tool, and I hope it stays around for a long time.
JD, I’m with you about expressiveness being primary and speed being secondary. Though I’ll always take a speedup, getting a wrong answer fast because I used a default model is useless.
All that said, have a look at Turing.jl. It’s really quite good and incredibly flexible, more so than Stan, as pretty much any Julia code can go in the model.
That was my conclusion in my ProbProg presentation a couple years ago. I find Turing.jl to be the current system that’s most like BUGS in just defining a clean joint model and therefore very flexible if you can stick to directed graphical models. PyMC is similar if you stick to the graphical modeling subset of it. I understand both Turing.jl and PyMC can go beyond graphical models, but then pieces of workflow automation start to break.
My understanding from talking to some of the Turing.jl devs and users is that each of Julia’s umpteen autodiff systems handles a slightly different subset of Julia. I could never find any doc for which subset that was, but I never looked very hard. I have the same issue with PyMC and JAX, which advertise “just write Python” but in fact come with numerous restrictions on how that Python must be written and which functions it can use. There I have looked for doc, but can’t find it—I don’t think it exists.
As a rule of thumb, I think you can do ForwardDiff.jl for anything less than around 10 or 20 parameters, and it’ll just work, it doesn’t have any real strong requirements. It’s just not super fast (still WAY faster than plain Python).
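The rule of thumb follows from how forward mode works. Here is a minimal dual-number sketch in plain Python that illustrates the technique, not ForwardDiff.jl's implementation (the `Dual` class and `gradient` helper are hypothetical). Each pass propagates one directional derivative, so a full gradient over N parameters takes N passes. As I understand it, ForwardDiff.jl batches several directions per pass in chunks, but the cost still grows with N, which is why it is fine for a handful of parameters and slow for thousands.

```python
import math

class Dual:
    """Number a + b*eps with eps**2 = 0; `der` carries the derivative part b."""
    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + b eps)(c + d eps) = ac + (bc + ad) eps
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def exp(x):
    """exp for dual numbers: derivative is exp(val) times the incoming derivative."""
    v = math.exp(x.val)
    return Dual(v, v * x.der)

def gradient(f, xs):
    """Forward-mode gradient: one pass per coordinate, seeding der = 1 on that input."""
    grad = []
    for i in range(len(xs)):
        duals = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        grad.append(f(duals).der)
    return grad

def f(v):
    # f(x, y) = x * y + exp(x), so the gradient is (y + exp(x), x)
    return v[0] * v[1] + exp(v[0])
```

At (x, y) = (0, 2) this returns [3.0, 0.0], after two passes, one per parameter.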
For more parameters, if you don’t branch on parameter values, you can do ReverseDiff.jl with precompiled tape and it works unless there’s some stochastic algorithm underlying something you’re doing (deep inside the implementation of some library for example). I’ve done thousands of parameters with this system and it works fine, it happens at basically full machine code speed.
It is *theoretically* possible to get wrong answers because you call some of the vast Julia ecosystem where some algorithm has for example a randomized algorithm that approximates something by doing RNG and therefore the compiled tape doesn’t represent the actual calculation. I’d suggest this is pretty rare for anything similar to what you’d do in Stan, but there’s an entire giant Julia ecosystem available to you, so for example if you did a deterministic sort you’d get the right answer, but if you did a randomized quicksort you’d get the wrong answer… to give an example.
You can test if you get different results *without* tape-compilation and that’s an indicator of a problem. I have never actually seen this problem in practice. I don’t think you’d get this problem doing anything you’d normally do in Stan.
You can always try Enzyme (autodiff of LLVM) or Mooncake (a Julia-level approach to autodiff) for code that branches on parameters, if you need a speedup, or if you want to check that you get similar results and haven’t been led astray by limitations. I think they’re both still in the process of being developed, but are usable. Read the docs, especially if you’re worried about branching on parameters.
If you follow that basic script, and you’re someone who’s pretty happy with Stan, then I think you’d also be pretty happy with Turing and get a certain amount of increased expressiveness for free.
Also, if you want to do sampling algorithm development, it’s pretty easy to take a Turing model and create something that will give you logpdf and gradient calculations, so then you can write your own samplers. I did that like 6 months ago to try to sample tempered distributions when we had a model that had ultra-tight tolerances on some coefficients. It took a few lines of code to create a tempered distribution from a Turing model.
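Tempering is a small wrapper once you have log density and gradient functions. Here is a minimal sketch in plain Python (the tight normal target, `make_tempered`, and the random-walk Metropolis sampler are illustrative assumptions, not the code described above). The tempered density is p(x)^(1/T), so its log density and gradient are the originals divided by T; the gradient would feed a gradient-based sampler, while the simple sampler here needs only the log density.

```python
import math
import random

def make_tempered(logpdf, grad_logpdf, temperature):
    """Return log density and gradient of the unnormalized density p(x)**(1/T)."""
    def tempered_logpdf(x):
        return logpdf(x) / temperature
    def tempered_grad(x):
        return grad_logpdf(x) / temperature
    return tempered_logpdf, tempered_grad

def rw_metropolis(logpdf, x0, n_draws, step, rng):
    """Random-walk Metropolis on a scalar target using only logpdf."""
    x, lp = x0, logpdf(x0)
    draws = []
    for _ in range(n_draws):
        prop = x + rng.gauss(0.0, step)
        lp_prop = logpdf(prop)
        # 1 - rng.random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < lp_prop - lp:
            x, lp = prop, lp_prop
        draws.append(x)
    return draws

# Example target: a very tight normal (sd = 0.01), unnormalized.
def logpdf(x):
    return -0.5 * (x / 0.01) ** 2

def grad_logpdf(x):
    return -x / 0.01 ** 2
```

At T = 100, the tempered version of this target is a normal with sd 0.01 × √100 = 0.1, which a crude random-walk sampler can explore easily.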
You will have an easier time using Julia if you have experience with one or more of Common Lisp, Matlab, Fortran, or Scheme. It’s absolutely not required, but it could help.
NumPyro is wonderfully flexible and complete, and it has already been adopted in many cutting-edge research projects. It is also low-level, which means it can be plugged into any JAX workflow. It would be an amazing contribution if the Stan team adopted the NumPyro project and contributed to it. Recreating Stan in JAX from scratch, when NumPyro already exists, would seem like a mistake.