This is Bob. And I’d like to know the best way for us to code a bunch of models in JAX to use to evaluate parallel algorithms including normalizing flows. I’m going to dump out my current thinking, but I’m really hoping to get feedback from experts on the best way to do this without starting a flame war in the comments.
Why not Stan? Ask Elizaveta!
The bottom line is that in order to evaluate the parallel algorithms we’re considering, we need fast parallel execution in-kernel on the GPU. Stan has some ability to offload compute to GPU, but not to the extent that we can parallelize entire model evaluations.
Elizaveta Semenova’s words at StanCon are still ringing in my ears—she started her live interview with Alex Andorra by saying, “I don’t use Stan any more.”
Why JAX?
Elizaveta needed to integrate neural networks for the Bayesian optimization she’s doing and for that turned to JAX. (The interview with Elizaveta and Chris Wymant will soon be up on Alex’s podcast, Learn Bayes Stats, along with the interview of Brian Ward and Mitzi Morris in another segment that also took place live at StanCon—the podcast is a ton of fun and both Andrew and I have done interviews).
The real reason for JAX isn’t that all the cool kids are using it (though everyone I know on the CS side has pretty much switched to JAX, including my own personal bellwether, Matt Hoffman). JAX is beautifully compositional in the same way as Unix. I suppose we could’ve used PyTorch, but JAX just feels much more natural to a computer scientist like me. I just love the way it can compose JIT and autodiff to enable massively parallel differentiable programs. There are really two applications I have in mind: normalizing flows (the main topic of this post) and parallelized MCMC of the form Matt Hoffman’s been propounding lately (Charles Margossian, a former Ph.D. student of Andrew’s and one of our postdocs here, did an internship with Matt at Google working out how to do R-hat in a massively parallel setting with 1000+ chains that communicate with each other to accelerate convergence, after which a single draw is taken from each in the limiting case).
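As a minimal sketch of that composition, here is a toy log density whose value and gradient are evaluated for a whole batch of parameter vectors in one compiled call. The target (a standard normal) and every name in the block are made up for illustration and are not taken from any of the projects mentioned in this post.

```python
import jax
import jax.numpy as jnp

def log_density(theta):
    # Toy target (an assumption): unnormalized standard normal log density.
    return -0.5 * jnp.sum(theta ** 2)

# Compose the transforms: differentiate, then batch, then compile.
batched_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(log_density)))

key = jax.random.PRNGKey(0)
thetas = jax.random.normal(key, (1000, 5))   # 1000 parameter vectors of dimension 5
values, grads = batched_value_and_grad(thetas)
```

The same three-transform pattern should apply to any log density written as a pure function of an array; only `log_density` changes.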
Normalizing flows
I think there is a good chance that normalizing flow-based variational inference will displace MCMC as the go-to method for Bayesian posterior inference as soon as everyone gets access to good GPUs. I’ve been looking into normalizing flows with Gilad Turok, Sifan Liu, Justin Domke, and Abhinav Agrawal. Justin visited Flatiron for five months and during that time, we didn’t manage to program a distribution in JAX that his and Abhinav’s take on realNVP, as coded in the repo vistan, couldn’t fit well. They’re busy writing up a more extensive evaluation in a follow-up paper and the results only look better. Gilad was able to port their vistan code to Blackjax and replicate all their results on our clusters here—he’ll be submitting a PR to Blackjax soon.
My thinking on normalizing flows was inspired by the last model we fit with Justin—a centered parameterization of a hierarchical IRT 2PL model with around 1000 total parameters (this is a nice example due to additive non-identifiability, multiplicative non-identifiability, and funnels from the hierarchical priors). With this parameterization, Stan struggles to the point where I’d say it can’t really fit the model. Justin and Abhinav’s RealNVP fit it quite well—much better than Stan managed. It just took a massive number of flops on a state-of-the-art GPU. One of the things Justin and Abhinav’s approach to flows relies on for convergence is a massive number of evaluations of the log density and gradients for computing the approximate KL-divergence stochastic gradient (i.e., the ELBO). So we needed to code the models in JAX to run entirely on the GPU. So I’m looking for an easier way to do this.
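To make the "massive number of evaluations" concrete, here is a rough sketch of a stochastic ELBO gradient in JAX. It swaps in a diagonal Gaussian approximation for RealNVP to keep the block short, so it is not Justin and Abhinav’s method; the toy target, the step size, and all names are assumptions. Each optimization step evaluates the target log density once per draw, which is where the GPU batching matters.

```python
import jax
import jax.numpy as jnp

def log_density(theta):
    # Hypothetical target: independent normal(1, 2) in each dimension, unnormalized.
    return jnp.sum(-0.5 * ((theta - 1.0) / 2.0) ** 2)

def negative_elbo(phi, key, num_draws):
    # phi holds the parameters of a diagonal Gaussian approximation q.
    mu, log_sd = phi["mu"], phi["log_sd"]
    eps = jax.random.normal(key, (num_draws, mu.shape[0]))
    theta = mu + jnp.exp(log_sd) * eps        # reparameterized draws from q
    log_p = jax.vmap(log_density)(theta)      # one target evaluation per draw
    entropy = jnp.sum(log_sd)                 # entropy of q up to an additive constant
    return -(jnp.mean(log_p) + entropy)

loss_and_grad = jax.jit(jax.value_and_grad(negative_elbo), static_argnums=2)

phi = {"mu": jnp.zeros(3), "log_sd": jnp.zeros(3)}
key = jax.random.PRNGKey(0)
for step in range(300):
    key, subkey = jax.random.split(key)
    loss, grad = loss_and_grad(phi, subkey, 256)
    # Plain stochastic gradient descent with a fixed step size (an assumption).
    phi = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, phi, grad)
```

With a flow in place of the diagonal Gaussian, `theta` would come from pushing `eps` through the flow and the entropy term would use the flow’s log determinant, but the batched call to the target log density stays the same.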
Workflow in JAX
Colin Carroll (Google employee, PyMC dev) just presented a talk about Bayes and JAX at PyData Vermont. He covers the whole workflow in JAX and talks about his bayeux repository. Colin also talks about Adrian Seyboldt’s new nutpie sampler in Rust, which Adrian also just presented at StanCon. There’s no write-up, but we’re looking into reverse engineering the Rust into C++ for Stan, because it works quite well. Adrian’s agreed to come out and give a talk here at Flatiron on his sampler in the new year. But that’s a different topic.
For now, I want to do a lot more evaluations of Justin and Abhinav’s take on realNVP, and we’re trying to figure out how to code up a couple dozen models in JAX. There are many possibilities.
PyMC
PyMC can produce JAX output. The PyMC devs just did a little hackathon and created about ten pull requests in the posteriordb repository for PyMC implementations.
with pm.Model() as hierarchical:
    eta = pm.Normal("eta", 0, 1, shape=J)
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)
    theta = pm.Deterministic("theta", mu + tau * eta)
    obs = pm.Normal("obs", theta, sigma=sigma, observed=y)
All of the approaches in Python wind up having to name variables and then provide string-based names. I don’t know if the sigma=sigma thing is necessary for the scale parameter. I like that the distributions are vectorized here. It’s too bad that there’s an observed= in the data models. I think that means the models as defined aren’t as flexible as the BUGS models in terms of specifying what’s data at run time. At the same time, Thomas Wiecki was telling me you could use NaN to code the equivalent of R’s NA and do inference, so I think that observed value can have missingness.
Not all of the PyMC models look so much like a graphical model.
NumPyro
NumPyro is the version of Pyro that generates JAX on the back end. NumPyro looks like BUGS (or Turing.jl), which is not necessarily a bad thing. Here’s the NumPyro version of Andrew’s favorite example model, eight schools (the arguments to the top-level function are the data):
def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
pangolin
pangolin can produce JAX output. This is an “early-stage probabilistic inference project” rather than a longstanding embedded PPL like PyMC or NumPyro. Specifically, it’s a graphical modeling language that looks a lot like the others, and it has back ends for Stan, JAGS, and JAX. It’s very experimental and a work in progress, but the models look nice. Python doesn’t let you use ~ as a binary operator (it’s the unary bitwise complement), so sampling statements are written as assignments. Here it’s not so clear that y and stddevs are the data.
mu = pg.normal(0,10)                                              # μ ~ normal(0,10)
tau = pg.exp(pg.normal(5,1))                                      # τ ~ lognormal(5,1)
theta = [pg.normal(mu,tau) for i in range(num_schools)]           # θ[i] ~ normal(μ,τ)
y = [pg.normal(theta[i],stddevs[i]) for i in range(num_schools)]  # y[i] ~ normal(θ[i],stddevs[i])
No names here, but they have to get introduced later if you want to do I/O. The doc also makes it clear how things line up. Unlike the other approaches, this uses standard Python comprehensions, which I don’t think are super efficient in JAX judging from the JAX doc I’ve read. But I think there are lots of ways to code in pangolin. The problem is when you release “Hello, World!” code, people read it as what your project does rather than as a simple example.
postjax
We can just code models in JAX. Bernardo Williams (Ph.D. student at U. Helsinki) just coded a bunch of models directly in JAX in his GitHub postjax. I couldn’t find eight schools, but here’s a simple logistic regression model as a class with a method defined as follows.
def logp(self, theta):
    sqrt_alpha = jnp.sqrt(self.alpha_var)
    data = self.data
    X = data["X"]
    y = data["y"]
    assert len(theta) == self.D
    return jnp.sum(jss.norm.logpdf(theta, 0.0, sqrt_alpha)) + jnp.sum(
        jss.bernoulli.logpmf(y, sigmoid(jnp.dot(X, theta)))
    )
The variable self.alpha_var is set as data in the constructor, as is the data dictionary data. I’d have been tempted to put alpha_var into the data input.
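Since eight schools isn’t in postjax, here is a rough sketch of what a hand-coded version might look like in plain JAX, following the NumPyro model’s priors. This is not code from postjax: the function name, the parameter layout, and the log transform on tau are assumptions, and the data are the classic eight schools values.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import cauchy, norm

y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def eight_schools_logp(params_unc):
    # Assumed unconstrained layout: [mu, log(tau), theta[0], ..., theta[7]].
    mu, log_tau, theta = params_unc[0], params_unc[1], params_unc[2:]
    tau = jnp.exp(log_tau)
    lp = log_tau                                        # log Jacobian of tau = exp(log_tau)
    lp += norm.logpdf(mu, 0.0, 5.0)
    lp += cauchy.logpdf(tau, 0.0, 5.0) + jnp.log(2.0)   # half-Cauchy(5) on tau > 0
    lp += jnp.sum(norm.logpdf(theta, mu, tau))          # centered hierarchical prior
    lp += jnp.sum(norm.logpdf(y, theta, sigma))         # likelihood
    return lp

logp_and_grad = jax.jit(jax.value_and_grad(eight_schools_logp))
value, grad = logp_and_grad(jnp.zeros(10))
```

Here the data are captured by closure; passing them as an argument instead would make the same function reusable across data sets.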
Other options?
I’d really like to hear about other options for coding statistical models in JAX.
Straight to XLA?
Both JAX and TensorFlow run by compilation down to XLA (stands for “accelerated linear algebra”). Mattijs Vákár, who coded a lot of the Stan parser and code generator, is working on autodiff down at that level. That may be a good eventual target for a compiler, but it’s a lot easier to start in JAX. Similarly, we could have targeted LLVM with Stan rather than C++, but we rely on so much pre-existing C++ infrastructure that it would have been challenging. Likewise, I think coding directly at the XLA level would be painful at this stage, not that I’ve ever tried it or even know what it looks like. I just know we’re going to need a lot more than linear algebra.
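For anyone curious what that level looks like, JAX can print both its own intermediate representation and the lowered program it hands to the XLA compiler. This is only a small sketch with a made-up toy function; the exact text printed depends on the JAX version.

```python
import jax
import jax.numpy as jnp

def log_density(theta):
    # Toy function (an assumption), just to have something to lower.
    return -0.5 * jnp.sum(theta ** 2)

theta = jnp.ones(3)

# JAX's own intermediate representation of the traced function.
jaxpr = jax.make_jaxpr(log_density)(theta)
print(jaxpr)

# Text of the lowered program that is passed on to XLA for compilation.
lowered_text = jax.jit(log_density).lower(theta).as_text()
print(lowered_text)
```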
Stan
For comparison, I really wish we could just use Stan. Here’s what eight schools looks like in Stan. This includes all the data declarations that were implicit in the other programs (which used either a closure or function argument to capture data directly).
data {
  int<lower=0> J;
  vector[J] y;
  vector<lower=0>[J] sigma;
}
parameters {
  real mu;
  real<lower=0> tau;
  vector[J] theta;
}
model {
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
  mu ~ normal(0, 5);
}
I’m thinking the way I would code something that follows Stan’s execution logic in JAX directly would be something like this:
import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import expon, norm

class LinearRegression:
    def __init__(self, data):
        self._data = data

    def num_params_unc(self):
        return 3

    def log_density(self, params_unc):
        # Deserialize the unconstrained vector; the reader accumulates the log Jacobian.
        r = Reader(params_unc)
        alpha = r.real()
        beta = r.real()
        sigma = r.real_lb(0)
        log_jacobian = r.log_jacobian()
        log_prior = 0
        log_prior += norm.logpdf(alpha, 0, 1)
        log_prior += norm.logpdf(beta, 0, 1)
        log_prior += expon.logpdf(sigma)  # exponential(1) prior
        log_likelihood_fun = lambda x, y: norm.logpdf(y, alpha + beta * x, sigma)
        log_likelihood = jnp.sum(vmap(log_likelihood_fun)(self._data['x'], self._data['y']))
        return log_jacobian + log_prior + log_likelihood
where I’m relying on a Reader class that follows the reader I first coded for Stan in order to define the log density over a vector. It’s really a deserializer. I’m wondering if I can just lean more on the pytree construct in JAX to simplify my interfaces, but I’m just getting started with JAX myself.
import jax

class Reader:
    def __init__(self, params):
        self._params = params
        self._lj = 0
        self._next = 0

    def real(self):
        x = self._params[self._next]
        self._next += 1
        return x

    def real_lb(self, lb):
        x_unc = self.real()
        self._lj += x_unc  # log Jacobian of the transform lb + exp(x_unc)
        return lb + jax.numpy.exp(x_unc)

    # ... other constraining transforms ...

    def log_jacobian(self):
        return self._lj
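On the pytree question, one possible direction is `jax.flatten_util.ravel_pytree`, which returns a flat vector together with a function that maps a flat vector back to the original structure. The sketch below is only an illustration of that idea for the three-parameter regression above; the template dictionary and the `constrain` helper are hypothetical names, and the constraining transform is still done by hand.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Template pytree: only the structure and shapes matter, the values are placeholders.
template = {"alpha": jnp.zeros(()), "beta": jnp.zeros(()), "log_sigma": jnp.zeros(())}
flat_template, unravel = ravel_pytree(template)

def constrain(params_unc):
    p = unravel(params_unc)            # flat vector -> dict of named parameters
    sigma = jnp.exp(p["log_sigma"])    # lower bound at 0
    log_jacobian = p["log_sigma"]      # log |d sigma / d log_sigma|
    return p["alpha"], p["beta"], sigma, log_jacobian

alpha, beta, sigma, log_jacobian = constrain(jnp.array([0.5, -1.0, 0.0]))
```

Dictionary leaves are flattened in sorted key order, so the flat layout here is alpha, beta, log_sigma. Compared with a sequential reader, the names replace the read order, at the cost of declaring the shapes up front.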
The ML libraries are all interesting too. Haiku/Flax/Equinox — I used Haiku the most and like it though it’s been deprecated.
The jax models I’ve written end up looking a little rough to read, but I like that I can use the functions I write outside of the models themselves. It makes it easier to test things and whatnot.
Half of enjoying jax I think is enjoying the numpy way of doing things. For encoding factor variables I like this:
animal_types, animals_as_ints = numpy.unique(["rabbit", "rabbit", "fish"], return_inverse=True)
That gives me back animal_types == ['fish', 'rabbit'] and animals_as_ints == [1, 1, 0]
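A small sketch of how those integer codes might then be used inside a model: indexing a vector of per-level effects with the codes gives one effect per observation. The effect values below are hypothetical.

```python
import numpy as np
import jax.numpy as jnp

animal_types, animals_as_ints = np.unique(
    ["rabbit", "rabbit", "fish"], return_inverse=True
)

# Hypothetical per-level effects, ordered like animal_types: [fish, rabbit].
effects = jnp.array([0.3, -1.2])

# Gather one effect per observation by indexing with the integer codes.
per_row_effect = effects[animals_as_ints]
```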
Bob:
I remain impressed by your unbiased takes. You wrote Stan, and Stan is still great, but you’re open to talking about how other programs can be better than Stan in some settings. That’s just the way you are—you can call it how you see it without being constrained, or soft-constrained, by defensiveness. It’s my impression that most researchers don’t have this admirable trait.
Hear, Hear
Although I may not have anything to add, I just wanted to say thanks Bob for being so open about what other languages offer and clear about the possible headaches we might face. Always nice to hear what you and the folks at Flatiron are thinking about.
You probably know this (because nested rhat is already in Stan), but the JAX code from Charles’s visit to Google is available at https://github.com/charlesm93/nested-rhat.
I’m a little unsure what you’re hoping to get by reverse-engineering nutpie in C++ — from the README it’s just a minor wrapper around NUTS, and surely Stan’s builtin NUTS is very good by now?
[edit: tried to escape code, because the blog doesn’t support markdown, but the blog overall is pretty broken and won’t let me commit markdown here.]
Thanks for the nice article, Bob!
In terms of other JAX ppls, I have grown to love the TFP joint distribution syntax (whose biggest drawback is the long name):
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def tfp_model():
    avg_effect = yield tfd.Normal(0., 10., name='avg_effect')
    avg_stddev = yield tfd.HalfNormal(10., name='avg_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(0., 1.), sample_shape=8, name='school_effects')
    yield tfd.Normal(loc=avg_effect + avg_stddev * school_effects,
                     scale=treatment_stddevs, name='observed')
The bayeux documentation additionally has examples for defining and fitting
– state space models in dynamax (Kevin Murphy and Scott Linderman were prominent contributors to that library): https://jax-ml.github.io/bayeux/examples/dynamax_and_bayeux/
– bayesian neural networks in oryx (Sharad Vikram started this, and it has some nice functional transforms like `inverse` and `inverse log det jacobian`): https://jax-ml.github.io/bayeux/examples/oryx_and_bayeux/
Just wanted to add that you don’t actually need to pass the data with the `observed` keyword at model creation. You can use `pymc.observe` to “observe” data after the model is created, which lets you reuse the model quite easily. For example:
```
import pymc as pm

y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
J = len(y)

with pm.Model() as hierarchical:
    eta = pm.Normal("eta", 0, 1, shape=J)
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)
    theta = pm.Deterministic("theta", mu + tau * eta)
    obs = pm.Normal("obs", theta, sigma=sigma)

with pm.observe(hierarchical, {"obs": y}):
    idata = pm.sample()
```
This is really cool, thanks Tomas.
I’ve tried non-ML stats models in JAX and tbh I’ve never really been that impressed in that context. It’s amazing for what it does and there’s definitely some fun things you can do with parallel chains etc, but the dev and optimizations are so heavily focussed on neural nets and their ilk that I always feel like I’m fighting against its design to make some of my models work.
That said, when you can make everything rectangular enough, you can do some really cool things. We did a paper where we just threw a GPU at cross validation and it was a lovely thing.
But I don’t think it’s either trying to be or succeeding at being a general solution.
Wow—I knew the blog had a lot of reach, but I wasn’t expecting so many experts in this space to write back. Thanks so much!
@Kiran Gauthier: I hardly speak for the folks at Flatiron—people are going in all sorts of different directions for VI, sampling, amortized inference, and surrogate modeling, even among my collaborators.
@rif a. saurous: I didn’t know the JAX code was available for nested R-hat. That might be useful later, but for now, I’m just trying to figure out how to code models efficiently enough in JAX to use for normalizing flow experiments. nutpie substantially changes the way warmup is done compared to NUTS and it seems to lead to much more efficient sampling. End to end, it’s about twice as fast as Stan’s implementation of NUTS when using Stan-supplied log densities and gradients (Adrian built a beautiful interface that lets you use BridgeStan). We only realized recently (D’oh!) that the inverse of the diagonal of the covariance, which is what NUTS is estimating, isn’t a good inverse mass matrix approximation—we want the diagonal of the precision matrix, which is usually different. nutpie’s using things like gradient outer products to estimate that, similarly to L-BFGS, and it converges to what looks like a better solution faster. Adrian did a presentation at the recent StanCon in Oxford. But there’s no writeup at the pseudocode level of the algorithm!
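A two-dimensional example shows how far apart those two diagonals can be. This is only an illustration of the distinction, not nutpie’s estimator; the correlation value is made up.

```python
import jax.numpy as jnp

rho = 0.9  # hypothetical posterior correlation
cov = jnp.array([[1.0, rho], [rho, 1.0]])
precision = jnp.linalg.inv(cov)

# Marginal variances: the diagonal of the covariance.
diag_cov = jnp.diag(cov)

# Conditional variances: the inverse of the diagonal of the precision.
inv_diag_precision = 1.0 / jnp.diag(precision)
```

For this covariance the marginal variances are both 1, while the inverted precision diagonal is 1 - rho^2 in each coordinate, so the two candidate scales differ by roughly a factor of five. With no correlation they coincide.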
@Colin Carroll: Thanks for the video and bayeux and for the link to dynamax. Hopefully that’ll make it a bit easier and more efficient to implement HMMs than my having to code the forward algorithm in JAX. As a newbie, I’m having a bit of a hard time seeing how all the pieces fit together. For instance, bayeux has a bunch of pages about using it with other PPLs. I’d like to have some kind of simple solution that doesn’t depend on half a dozen different packages.
@Tomas Capretto: Is there doc for PyMC somewhere that goes over all these patterns? I was also told you can pass in NaN values in a data array and it’ll treat it like a parameter the same way BUGS does.
@Dan Simpson: What specifically are you fighting against in the design? Lack of loops and associated control flow like breaks? Lack of ability to branch on parameters? Otherwise, I would think you could do just about everything in JAX that you can do in Stan. I have been curious about how much speedup there is to be had on GPU when you’re not doing a lot of matrix multiplies. For the normalizing flow application, I just need a bunch of SIMD executions of the log density and gradient functions to crank out a lot of iterations for the ELBO estimation.
Indirect indexing is SLOW! And loops over ragged structures don’t work in lax.scan: you need to pad everything.
Basically, things that don’t turn into matrix multiplies don’t speed up for single evals. For SIMD evals it depends on stuff.
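As a sketch of what "pad everything" can look like, here is a random-walk log likelihood over two series of different lengths, padded to a common length and scanned with a mask so the padded steps contribute nothing. The model, the data, and the names are all assumptions chosen to keep the example small.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def series_logp(y_padded, mask):
    # Random-walk likelihood y[t] ~ normal(y[t-1], 1); padded steps are skipped.
    def step(carry, inputs):
        prev, total = carry
        y_t, valid = inputs
        lp = norm.logpdf(y_t, prev, 1.0)
        total = total + jnp.where(valid, lp, 0.0)   # padded steps add zero
        prev = jnp.where(valid, y_t, prev)          # padded steps leave the state alone
        return (prev, total), None

    init = (y_padded[0], jnp.zeros(()))
    (_, total), _ = jax.lax.scan(step, init, (y_padded[1:], mask[1:]))
    return total

# Two series of lengths 4 and 2, padded with zeros to a common length of 4.
y_padded = jnp.array([[0.0, 0.5, 1.0, 0.2],
                      [1.0, 1.5, 0.0, 0.0]])
mask = jnp.array([[True, True, True, True],
                  [True, True, False, False]])

logps = jax.jit(jax.vmap(series_logp))(y_padded, mask)
```

The padded steps still cost compute, which is part of the complaint: the work scales with the longest series rather than with the total number of observations.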
Unfortunately, I think this is the only place where you can see both `pm.do()` and `pm.observe()` used together, although we use them a lot in client projects, https://www.pymc.io/welcome.html#example-from-linear-regression.
I’ll try to find other examples.
I can’t find them in the online documentation, but the docstrings in the code for `observe` and `do` are both pretty readable:
https://github.com/pymc-devs/pymc/blob/main/pymc/model/transform/conditioning.py
Here’s a post I wrote about these PyMC features. I think we’ll incorporate it (or a version of it) into the PyMC docs.
https://tomicapretto.com/posts/2024-11-01_pymc-data-simulation/
Since StanCon I’ve experimented a lot with combining normalizing flows with the ideas from the nutpie mass matrix adaptation. It is still early, but it seems to be working way better than I would ever have imagined.
It can for instance sample a correlated funnel like this:
```
with pm.Model() as model:
    mu = pm.Normal("mu")
    log_sigma = pm.Normal("log_sigma", sigma=1)
    pm.Normal("x", sigma=np.exp(log_sigma), mu=mu, shape=100)
```
without any trouble, giving plenty of effective samples, and it takes only ~40_000 gradient evaluations per chain. (Which is much less than I ever managed with VI, but I wouldn’t be surprised if I’m not using it the way I should).
The optimization of the normalizing flow is a bit expensive, it takes about 10 minutes with my horrible implementation. But at least it doesn’t need any logp evaluations of its own, it just uses the gradient evals from MCMC.
I hope I can get an experimental version released pretty soon, so that people can give it a go and tell me of all the different ways it immediately breaks on real problems…
I feel much the same way as Dan Simpson when it comes to applications that feel more statistics-y than machine learning. The multi-array focused design obviously helps to efficiently pack SIMD operations for GPU memory, but becomes a pain when dimensionality is not necessarily known ahead of time. As a concrete example, timeseries analysis where the series varies in length will trigger XLA recompilation whenever you get an unknown series length. You can specify a maximum length and pad or unpad, or do some vmapping and other acrobatics to restructure the problem, but you end up expressing things in a really unnatural way. If the problem also isn’t so big that a GPU helps or has some kind of intrinsic recursive structure, I end up finding it a lot faster to do in Stan. Though, the CSV file based interface to CmdStan can also end up biting me in the ass; having direct memory access to the results is also very nice. Tradeoffs.
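The recompilation issue and the pad-to-a-maximum workaround can be seen in a few lines. In this sketch (toy function and hypothetical names throughout), a Python counter increments only when JAX traces the function, so it counts compilations; padding every series to the same length keeps it at one.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def masked_sum_of_squares(y_padded, mask):
    global trace_count
    trace_count += 1   # Python side effects run only while JAX traces the function
    return jnp.sum(jnp.where(mask, y_padded ** 2, 0.0))

def pad(y, max_len):
    # Pad a series with zeros to max_len and return a mask of the real entries.
    y = jnp.asarray(y)
    mask = jnp.arange(max_len) < y.shape[0]
    return jnp.pad(y, (0, max_len - y.shape[0])), mask

# Three series of different lengths share one compiled function after padding.
for series in ([1.0, 2.0], [1.0, 2.0, 3.0], [0.5]):
    result = masked_sum_of_squares(*pad(series, max_len=8))
```

Calling the jitted function on the unpadded series instead would trace once per distinct length, which is the recompilation described above.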
I feel a little mixed about the direction of scientific programming lately. Obviously things are getting better and we can solve more problems, but it feels like maybe a couple of wrong turns have been made. One thing is the GPU stuff; while GPUs and big models are amazing, the dominance of the nVidia CUDA backend makes me really uncomfortable. OpenCL sucks and Intel/AMD apparently don’t want their stock prices to go up, so if you want massively parallel compute you’re basically tied to the proprietary nvidia stack with no competition even threatening to emerge. The popular libraries have made an effort to make MacOS metal based builds for Apple’s amazing Arm chips (though without taking advantage of the unified memory, they just pointlessly copy data around), but that’s coupled to an even more vertically integrated proprietary stack.
Jax makes me a little uncomfortable for a similar reason. Google is great and does great work, but its eternal alpha status and Google’s habit of just abruptly abandoning or forking projects when they stop being interesting makes me wary of committing fully to it, especially when they’re maintaining tensorflow 2 at the same time.
Often the issues I run into with these sorts of things also make me wonder about the choice to use library interfaces on high level languages at all. When I’m digging through some kind of torchinductor-dynamo/XLA error message to get to an open github issue from 2 years ago, or dispatching work into a short lived subprocess to sidestep a memory leak in the traced computation graph, it starts to feel no easier than programming in C++ did. It’s certainly harder for me than Stan’s very literal syntax. Is it really worth the overhead of thousands of cycles of python jit interpretation between minibatches?
Come on over to the dark side of Julia and Turing.jl, one of these days soon Enzyme will be stable released, in the meantime it’s actually usable as is, and ReverseDiff and ForwardDiff both work well too. There’s an excellent profiler for code that’s easy to use, and you have access to essentially all of the language from within Turing.
Julia the language has a huge advantage: even though it’s a smaller community, it’s had massive growth in capabilities in the 5 yrs I’ve been using it. In 2019 it was dynamically compiled only, and starting up Julia and loading data analysis libraries took something like a minute of startup time every time. Today that code is cached, startup time is in milliseconds, and the juliac compiler will soon have its first release, allowing you to build small binaries and distribute them.
It’s easy to call R or Python from Julia to do work where specialized libraries are needed. And Julia has good integration with Jupyter, Quarto, VSCode/VSCodium, DuckDB, Arrow, etc etc
Daniel:
Under the let’s-not-let-the-blog-overflow-with-repeat-comments-principle, you are allowed at most one plug of Julia per month!
Noted! If only there were a julia plug wordpress plugin to keep track of that. Will try to keep that enthusiasm to a dull roar in the 1 per month range at least.
>My thinking on normalizing flows was inspired by the last model we fit with Justin—a centered parameterization of a hierarchical IRT 2PL model with around 1000 total parameters (this is a nice example due to additive non-identifiability, multiplicative non-identifiability, and funnels from the hierarchical priors). With this parameterization, Stan struggles to the point where I’d say it can’t really fit the model.
To a very helpful extent from the modeler’s perspective, the loud complaints of a flailing HMC are a feature not a bug. It can rather helpfully suggest parameterization changes that reflect an improved model (improved in the sense of closer to the narratively generative process, not just in the sense of the estimation algorithm). In some cases, I would wonder if HMC cannot fit the model then if you really want to fit that model anyway.
Jd:
You’re talking about !
We are working on improving Stan’s warmup so that it reveals this sort of problem faster. Better to run 10 iterations and see the problem than have to run 1000 iterations and only then give up. Especially given that, with current default settings, Stan runs slower per iteration for models that have problems.
There are some errors that are like that and some that aren’t. For things like non-identifiability leading to an improper posterior, I think it’s great that HMC fails fast. For a centered parameterization of a hierarchical model, I’d rather the sampler just works. The only improvement you get out of a reparameterization is the ability to run on a sampler that’s not robust to the original parameterization—it’s not changing the posterior you’re trying to fit. So when HMC fails on a centered parameterization of a hierarchical model, I don’t blame the model, I blame HMC. Switching to a non-centered parameterization isn’t changing the inferences, just coercing the model into a form that HMC can handle. I’d much rather write a better sampler.
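The claim that the two parameterizations define the same posterior can be checked numerically: the non-centered prior density on eta equals the centered prior density on theta = mu + tau * eta times the Jacobian of that linear map, and the likelihood is identical in both forms. The values below are arbitrary and the function names are made up for this sketch.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def centered_logp(mu, tau, theta):
    # Centered prior on the group effects: theta ~ normal(mu, tau).
    return jnp.sum(norm.logpdf(theta, mu, tau))

def non_centered_logp(eta):
    # Non-centered prior: eta ~ normal(0, 1), with theta = mu + tau * eta.
    return jnp.sum(norm.logpdf(eta, 0.0, 1.0))

mu, tau = 1.5, 0.3                      # arbitrary values for the check
eta = jnp.array([-1.0, 0.2, 2.0])
theta = mu + tau * eta

# Change of variables: p(eta) = p(theta) * |d theta / d eta| = p(theta) * tau, per element.
lhs = non_centered_logp(eta)
rhs = centered_logp(mu, tau, theta) + eta.shape[0] * jnp.log(tau)
```

The two sides agree up to floating point, so any difference in sampler behavior comes from the geometry of the coordinates, not from the model.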
There are also lots of reasonable models that we don’t know how to reparameterize so that HMC can fit them efficiently and a lot more where the reparameterization is just a huge pain (like non-centered multivariate priors). That is, we still haven’t solved Andrew’s unfolding flower problem.
what does “state of the art GPU” mean here? I’d be shocked if you need H100s or the new H200s to do a lot of useful statistics. and there is an enormous dropoff in cost for even slightly old models. plus it sounds like for this application (low parameter count, huge “batch size” so to speak) it may be viable to scale horizontally. Also to degrade gracefully — if your GPU isn’t big enough, that slows you down, but the computation is not impossible.
autograd GPU libraries are cool and fun, so hopefully lots more people will have the opportunity to play around with them sooner rather than later.
By “state of the art”, I’m mostly talking about something with enough GPU memory to run a lot of parallel density and gradient evals. The only GPUs I have free access to are the H100s in our cluster at Flatiron. I guess I could buy a desktop machine with a more consumer-grade GPU, but I’m more thinking about designing for me and the future rather than for Stan users in 2024.
Very nice post!
A few years ago we built a transpiler from Stan to NumPyro and used it to try NumPyro’s autoguides on Stan models in posteriorDB (including AutoBNAFNormal and AutoIAFNormal, which are parameterized by normalizing flows).
https://arxiv.org/abs/2110.11790
Unfortunately, the project is now outdated both w.r.t. Stan and NumPyro, but, maybe this approach (keeping Stan’s frontend and static analyses) can be useful to quickly try new models.
Numpyro is such a beautiful program. Flexible, fast, and lightweight; it plays nice with the rest of the JAX ecosystem. The functionality under `numpyro.handlers` allows one to do precise “surgery” on a model, be it conditioning certain sampling sites, apply a do-operator, mask some terms from the log-probability function, etc. I love(!) Stan, and what a bounty of fun it was as my introduction to probabilistic programming, but once I found Numpyro, I never left. I now work in industry, where I work with very big datasets and have access to multi-GPU compute. Numpyro just works there with minimal changes.
Thanks for the feedback. Unfortunately, I’m not sure what any of these are or what the use cases are: conditioning certain sampling sites, apply a do-operator, mask some terms from the log-probability function. The first I have no clue about—I don’t even know what a sampling site is. I think “do” is some Pearl thing for causation? Masking is for dropout in neural nets?
Are you working on Bayesian inference in industry or something else? My goal is to be able to express the kinds of models we can express in Stan. I’m having trouble understanding the doc in all these packages and what can and can’t be expressed. For NumPyro, there’s a ton of API doc, but I can’t find anything like a “getting started” tutorial or a reference manual—just endless examples. Am I just not looking in the right place?
Is there a plan to make vistan available for use in R?
Terrance:
I don’t know—you should try asking on the Stan forums!