The problem that made me look for an alternative
My work involves taking models of the Universe – dark energy equations of state, modified gravity, tachyonic fields – and asking: what do the data actually say about the parameters? The tool for that question is Bayesian inference. I usually run dynesty nested sampling for a few thousand to a few hundred thousand likelihood evaluations, depending on the complexity of the model.
For most of my PhD, I didn't think much about the ODE solver inside the likelihood, because solve_ivp worked. It was reliable, so I used it and moved on.
Then I started working on a tachyonic DBI dark energy model where the dark energy field is governed by a non-standard kinetic term, and the background and perturbation equations form a coupled, stiff-ish system. Every likelihood call solved these ODEs, computed the comoving distance, and evaluated the distance modulus at the redshifts of 30 supernovae.
I profiled it. The ODE solve alone was taking 0.4 ms per call. In a nested sampling run with 10⁵ evaluations, that's 40 seconds just in ODE calls, before you count any bookkeeping. And for a 10-parameter model, getting a gradient via central finite differences costs 20 extra forward solves, turning those 0.4 ms into 8 ms per gradient. Over the same 10⁵ evaluations that's 800 seconds, or about 13 minutes, just for the gradients. For a single nested sampling run.
Something had to change.
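The arithmetic behind that budget is short enough to write down. The sketch below uses only the figures quoted above (0.4 ms per solve, 10⁵ evaluations, 10 parameters); the function name `fd_budget` is mine and purely illustrative.

```python
def fd_budget(t_solve_s, n_calls, n_params):
    """Forward-model time budget when gradients come from central differences.

    Central differences perturb each parameter once up and once down,
    so a single gradient costs 2 * n_params forward solves.
    """
    solves_per_gradient = 2 * n_params
    return {
        "forward_total_s": n_calls * t_solve_s,          # plain likelihood calls
        "solves_per_gradient": solves_per_gradient,
        "gradient_s": solves_per_gradient * t_solve_s,   # one gradient
    }

budget = fd_budget(t_solve_s=0.4e-3, n_calls=100_000, n_params=10)
print(budget)
```

The point of writing it as a function is that the gradient cost scales with `n_params`, while the per-solve time is the only number a faster solver can attack.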

What I found: diffrax
After a day of searching, I landed on diffrax [1], a library of numerical ODE solvers written entirely in JAX. Not a neural surrogate. Not an approximation. The same embedded Runge–Kutta algorithms I already use in scipy (Tsit5 instead of RK45, but the same family of methods), just compiled, differentiable, and vectorisable.
Three properties come from the "written entirely in JAX" design:
- JIT compilation – The entire adaptive-stepping loop compiles to a single XLA kernel. No Python overhead inside the solve after the first call.
- Autodiff – Because every operation inside the solver is a JAX primitive, jax.grad propagates gradients through the solve. Exact gradients. One backward pass. Regardless of how many parameters.
- vmap – An entire batch of parameter vectors can be solved in parallel with jax.vmap. Essential for nested sampling.
Installing it takes 10 seconds:
pip install jax diffrax
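To see the three properties in isolation, here is a minimal sketch that needs only JAX. It is not diffrax: it uses a hand-rolled fixed-step RK4 integrator (no adaptive stepping, no error control) on the comoving-distance integrand c/H(z), and the names `chi_rk4`, `Z_MAX` and `N_STEPS` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

C_KMS = 299792.458       # speed of light [km/s]
Z_MAX, N_STEPS = 1.0, 64  # illustrative integration range and step count

def inv_hubble(z, theta):
    Om, H0 = theta[0], theta[1]
    return C_KMS / (H0 * jnp.sqrt(Om * (1 + z) ** 3 + (1 - Om)))

def chi_rk4(theta):
    """Comoving distance to Z_MAX [Mpc] with fixed-step RK4 (toy integrator)."""
    h = Z_MAX / N_STEPS
    z_left = jnp.arange(N_STEPS) * h

    def step(chi, z):
        k1 = inv_hubble(z, theta)
        k2 = inv_hubble(z + 0.5 * h, theta)  # equals k3: the RHS does not depend on chi
        k4 = inv_hubble(z + h, theta)
        return chi + h * (k1 + 4.0 * k2 + k4) / 6.0, None

    chi, _ = jax.lax.scan(step, jnp.zeros(()), z_left)
    return chi

chi_jit = jax.jit(chi_rk4)              # JIT: the whole loop becomes one kernel
grad_chi = jax.jit(jax.grad(chi_rk4))   # autodiff: d(chi)/d(Om, H0)
batch_chi = jax.jit(jax.vmap(chi_rk4))  # vmap: many parameter vectors at once

theta = jnp.array([0.3, 70.0])
chi = chi_jit(theta)
g = grad_chi(theta)
batch = batch_chi(jnp.array([[0.25, 68.0], [0.30, 70.0], [0.35, 72.0]]))
print(chi, g, batch)
```

Because χ is proportional to 1/H₀, the second gradient component should equal −χ/H₀, which is a convenient sanity check on the autodiff result.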
The test problem: flat ΛCDM from supernovae
To make the comparison concrete, let me show the exact problem I was working with. In a flat ΛCDM universe, the comoving distance satisfies:
dχ/dz = c / H(z), with H(z) = H₀ √[Ωₘ(1+z)³ + (1 − Ωₘ)] and χ(0) = 0.
The distance modulus follows: μ(z) = 5 log₁₀[(1+z)χ(z) / 10 pc]. With χ in Mpc, dividing by 10 pc is a multiplication by 10⁵, which is the factor 1e5 in the code. I want to infer (Ωₘ, H₀) from 30 mock SNIa distance-modulus observations.
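Since the right-hand side does not depend on χ, the ODE is just an integral, which gives an independent check on any solver output. The sketch below evaluates it by trapezoidal quadrature with NumPy; the function names and the grid size `n` are my own choices, not part of the original pipeline.

```python
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]

def comoving_distance(z, Om, H0, n=2001):
    """chi(z) in Mpc by trapezoidal quadrature of c/H(z') from 0 to z."""
    zz = np.linspace(0.0, z, n)
    integrand = C_KMS / (H0 * np.sqrt(Om * (1 + zz) ** 3 + (1 - Om)))
    dz = zz[1] - zz[0]
    return dz * (integrand.sum() - 0.5 * (integrand[0] + integrand[-1]))

def distance_modulus(z, Om, H0):
    d_lum = (1 + z) * comoving_distance(z, Om, H0)  # luminosity distance [Mpc]
    return 5 * np.log10(d_lum * 1e5)                # 10 pc = 1e-5 Mpc

print(distance_modulus(1.0, 0.3, 70.0))
```

For Ωₘ = 1 the integral has the closed form χ = (2c/H₀)(1 − 1/√(1+z)), which makes a handy unit test for both this quadrature and an ODE-based forward model.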
The old way: SciPy
from scipy.integrate import solve_ivp
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]

def rhs(z, chi, Om, H0):
    # d(chi)/dz = c / H(z) for flat LCDM; chi is in Mpc when H0 is in km/s/Mpc
    return C_KMS / (H0 * np.sqrt(Om*(1+z)**3 + (1-Om)))

def forward_scipy(Om, H0, z_obs):
    sol = solve_ivp(rhs, t_span=(0, z_obs[-1]),
                    y0=[0.0], t_eval=z_obs,
                    args=(Om, H0), method="RK45",
                    rtol=1e-8, atol=1e-10)
    chi = sol.y[0]
    return 5 * np.log10((1 + z_obs) * chi * 1e5)  # distance modulus
The new way: Diffrax
import jax, jax.numpy as jnp
import diffrax as dfx

# Non-negotiable: enable 64-bit (more on this below)
jax.config.update("jax_enable_x64", True)

C_KMS = 299792.458  # speed of light [km/s]

def H_jax(z, Om, H0):
    return H0 * jnp.sqrt(Om*(1+z)**3 + (1-Om))

@jax.jit  # compile once, call fast forever
def forward_diffrax(theta, z_obs):
    Om, H0 = theta[0], theta[1]
    sol = dfx.diffeqsolve(
        dfx.ODETerm(lambda z, chi, a: C_KMS / H_jax(z, a[0], a[1])),
        dfx.Tsit5(),
        t0=0.0, t1=z_obs[-1],        # initial and final value (z_obs is traced under jit, so no float())
        dt0=1e-3,                    # initial step size
        y0=jnp.array(0.0),           # initial condition
        args=(Om, H0),
        saveat=dfx.SaveAt(ts=z_obs),
        stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
        max_steps=10_000,
    )
    chi = sol.ys
    return 5 * jnp.log10((1 + z_obs) * chi * 1e5)
The physics is identical. The solver algorithm is nearly identical (Tsit5 is a close relative of RK45). The only structural differences are @jax.jit and the diffrax API. Let us look at what these two changes buy.
Surprise 1: the speed
solve_ivp: 404 μs per call. diffrax post-JIT: 59 μs per call. That's about 7× faster.
I stared at this number for a few seconds the first time I saw it. Let me be honest about where the speedup actually comes from, because it isn't magic.
In solve_ivp, the RK45 stepper is written in Python on top of NumPy, and every call starts from scratch. Memory is allocated fresh. The adaptive while-loop goes through the Python interpreter asking: "is the local error too large? reject; else grow the step; repeat." For a 12-step solve, that's 12 rounds of Python dispatch, 12 allocations, 12 error-estimate computations sitting behind the interpreter lock.
In diffrax, the first @jax.jit call traces the entire computation – including the adaptive while-loop, which is lowered to a lax.while_loop – and hands it to XLA to compile into a machine-code kernel. Every subsequent call executes that kernel directly: no Python inside the loop, no per-step allocation, and no dispatch.

For 100,000 likelihood evaluations, 404 μs vs 59 μs translates to 40.4 seconds vs 5.9 seconds. That gap widens as the model complexity increases.
Surprise 2: gradients become free
This was the part that changed not just my workflow but how I think about inference. With scipy, getting one gradient of the log-likelihood with respect to two parameters (Ωₘ, H₀) costs 4 forward solves (central finite differences). Once you start turning the dial up, it gets expensive fast: 10 parameters means 20 forward solves, 50 parameters means 100. The bill grows linearly with the number of parameters.
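The linear bill is easy to verify by counting calls. The sketch below is a generic central-difference gradient in NumPy, with a cheap quadratic standing in for the expensive log-likelihood; `central_diff_grad`, `toy_loss` and the relative step size are illustrative assumptions, not the code used for the timings in this article.

```python
import numpy as np

def central_diff_grad(f, theta, rel_step=1e-6):
    """Central finite-difference gradient; returns (gradient, number of f calls)."""
    theta = np.asarray(theta, dtype=float)
    grad = np.empty_like(theta)
    n_calls = 0
    for i in range(theta.size):
        h = rel_step * max(1.0, abs(theta[i]))
        up, down = theta.copy(), theta.copy()
        up[i] += h
        down[i] -= h
        grad[i] = (f(up) - f(down)) / (2.0 * h)
        n_calls += 2  # one forward evaluation each for up and down
    return grad, n_calls

# Hypothetical stand-in for an expensive log-likelihood.
def toy_loss(theta):
    return 0.5 * np.sum(theta ** 2)

for n_params in (2, 10, 50):
    _, n_calls = central_diff_grad(toy_loss, np.ones(n_params))
    print(n_params, "parameters ->", n_calls, "forward evaluations")
```

The loop prints 4, 20 and 100 evaluations, matching the counts quoted above.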
With diffrax, I write:
def loss(theta):
    mu_pred = forward_diffrax(theta, z_obs)
    return 0.5 * jnp.sum(((mu_pred - mu_obs) / sigma_mu)**2)

grad_fn = jax.jit(jax.grad(loss))      # that's the entire change
g = grad_fn(jnp.array([0.3, 70.0]))    # exact gradient
Under the hood, JAX's reverse-mode autodiff backpropagates through the ODE solve – by default diffrax differentiates the solver's own steps with checkpointing, and the continuous adjoint equations of [2] are available as an alternative – but I never have to write any of it by hand. The result is a gradient that is exact for the discretised solve, at the cost of a few forward passes (195 μs against 57 μs in my timings), independent of the number of parameters.

How to choose a solver
When it comes to choosing a solver, you have to be a little careful. I defaulted to Tsit5 for almost everything, and it handled about 95% of my problems without complaint. If you want the full decision process, here it is:
- Non-stiff ODE (most cosmological problems) → dfx.Tsit5() ← start here
- Very tight tolerances (< 10⁻⁹) → dfx.Dopri8()
- Stiff ODE (many steps, solver seems slow) → dfx.Kvaerno5()
- Stiff + non-stiff terms (IMEX) → dfx.KenCarp4()
- SDE → dfx.EulerHeun() or dfx.SPaRK()
A quick way to tell if your problem is stiff: print sol.stats["num_steps"]. If it is 10–100× higher than you expect, the problem is probably stiff and you need an implicit solver.
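The step-count symptom is easy to reproduce on a textbook stiff problem. The sketch below uses SciPy's solvers rather than diffrax so that it runs without JAX; I am assuming that `len(sol.t) - 1` plays the same diagnostic role there as `sol.stats["num_steps"]` does in diffrax. The test equation and its rate constant are illustrative choices.

```python
import numpy as np
from scipy.integrate import solve_ivp

def stiff_rhs(t, y):
    # Fast relaxation (rate 1000) onto a slowly varying target cos(t).
    return -1000.0 * (y - np.cos(t))

def count_steps(method):
    sol = solve_ivp(stiff_rhs, (0.0, 10.0), [0.0], method=method)
    return len(sol.t) - 1  # number of accepted steps

steps_explicit = count_steps("RK45")   # explicit: step size limited by stability
steps_implicit = count_steps("Radau")  # implicit: step size limited by accuracy only
print("RK45 steps:", steps_explicit, " Radau steps:", steps_implicit)
```

The solution itself is smooth after a brief transient, so the large step count of the explicit method reflects stability limits rather than accuracy, which is the signature of stiffness.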
The payoff: cosmological inference end-to-end
Now, let me show the full inference comparison. I start both pipelines from the same bad initial guess (Ωₘ, H₀) = (0.10, 60), well away from the truth (0.30, 70), and run 350 gradient steps.
- scipy pipeline: gradient from central finite differences, simple gradient descent, fixed learning rate.
- diffrax pipeline: gradient from autodiff, Adam optimiser with a cosine-decay learning-rate schedule.
import optax  # optimisers for JAX

# Scale parameters so Adam can treat them equally
# Om ~ 0.3, h = H0/100 ~ 0.7 -- both O(1) now
def loss_scaled(theta_s):
    theta = jnp.array([theta_s[0], 100.0 * theta_s[1]])
    return loss(theta)

grad_scaled = jax.jit(jax.grad(loss_scaled))
schedule = optax.cosine_decay_schedule(
    init_value=0.05, decay_steps=350, alpha=0.04)
opt = optax.adam(schedule)
theta = jnp.array([0.10, 0.60])  # start far from truth
state = opt.init(theta)
for step in range(350):
    g = grad_scaled(theta)
    updates, state = opt.update(g, state)
    theta = optax.apply_updates(theta, updates)
    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: Om={theta[0]:.3f} H0={100*theta[1]:.2f}")

While the diffrax pipeline recovers physically sensible parameters, the scipy pipeline cannot move both parameters at once – a textbook failure of gradient descent on poorly-scaled problems. Adam handles this automatically through its per-parameter adaptive learning rates, and cheap, exact autodiff gradients are what make it practical to feed.
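The scaling failure can be reproduced without any ODE at all. The sketch below runs plain gradient descent on a toy diagonal quadratic whose curvatures differ by 10⁴; these curvatures are hypothetical numbers chosen for illustration, not fitted to the supernova likelihood. It compares raw (Ωₘ, H₀) coordinates with the rescaled (Ωₘ, h = H₀/100) coordinates.

```python
import numpy as np

TRUTH = np.array([0.30, 70.0])
CURV = np.array([1.0e4, 1.0])  # hypothetical curvatures along Om and H0

def grad(theta, scale):
    """Gradient of 0.5 * sum(CURV * (scale*theta - TRUTH)**2) w.r.t. theta."""
    return scale * CURV * (scale * theta - TRUTH)

def gradient_descent(theta0, scale, lr, n_steps=350):
    theta = np.array(theta0, dtype=float)
    for _ in range(n_steps):
        theta = theta - lr * grad(theta, scale)
    return scale * theta  # report in physical units (Om, H0)

# Raw coordinates: the learning rate is capped by the stiff Om direction.
raw = gradient_descent([0.10, 60.0], scale=np.array([1.0, 1.0]), lr=1e-4)
# Scaled coordinates: optimise (Om, h) so both curvatures are comparable.
scaled = gradient_descent([0.10, 0.60], scale=np.array([1.0, 100.0]), lr=1e-4)

print("raw    (Om, H0):", raw)
print("scaled (Om, H0):", scaled)
```

With the same learning rate and step count, the raw run fixes Ωₘ almost immediately but leaves H₀ near its starting value, while the rescaled run reaches both targets. In this toy, the rescaling alone is enough for plain gradient descent to converge.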
Three things I got wrong (so you don't have to)

Caveat 1: forgetting 64-bit precision. JAX defaults to 32-bit floats. If you push the tolerances (rtol < 10⁻⁷), that can lead to some very odd results: on my ODE the solver needs 69 steps in 32-bit, but only 12 in 64-bit. Tighten the tolerances further and it can fail outright. The fix is simple: enable 64-bit before you do anything else.
jax.config.update("jax_enable_x64", True)  # must be first
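The reason tight tolerances misbehave in 32-bit is visible from machine epsilon alone. The short NumPy check below (an illustration of floating-point resolution, not a diffrax run) shows that a relative tolerance of 10⁻⁸ is below what float32 can represent next to 1.0.

```python
import numpy as np

eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps
print(f"float32 eps = {eps32:.2e}, float64 eps = {eps64:.2e}")

rtol = 1e-8
one32 = np.float32(1.0)
one64 = np.float64(1.0)

# In float32 a relative change of 1e-8 is rounded away entirely...
lost_in_32 = (one32 + np.float32(rtol)) == one32
# ...while float64 still resolves it.
kept_in_64 = (one64 + np.float64(rtol)) != one64
print(lost_in_32, kept_in_64)
```

An error controller asked to hit rtol = 10⁻⁸ in float32 is therefore chasing differences the arithmetic cannot express, which is consistent with the inflated step counts described above.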
Caveat 2: benchmarking without warming up. The first call to any @jax.jit-decorated function includes a one-off compilation cost of about 90–100 ms. If you include that in your timings, diffrax will look slower than scipy for the wrong reason. The fix is to warm up once and throw away that first run:
_ = forward_diffrax(theta, z_obs).block_until_ready()  # compile
# NOW benchmark -- this is the real speed
Also: JAX dispatches asynchronously. Always call .block_until_ready() in timing loops, or you measure the time to submit work, not to finish it.
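Both rules fit into one small timing helper. The sketch below uses only the standard library; `benchmark` and its `sync` argument are hypothetical names of my own. For a JAX function you would pass `sync=lambda out: out.block_until_ready()`, while for NumPy or SciPy code the default no-op is enough.

```python
import time

def benchmark(fn, *args, n_warmup=1, n_repeat=200, sync=lambda out: out):
    """Mean wall-clock seconds per call, excluding warm-up calls.

    sync(out) must block until the result is actually computed.
    """
    for _ in range(n_warmup):      # discard compilation and cache effects
        sync(fn(*args))
    start = time.perf_counter()
    for _ in range(n_repeat):
        sync(fn(*args))            # wait for completion inside the timed loop
    return (time.perf_counter() - start) / n_repeat

# Usage with a plain Python function (no JAX required for the demo):
per_call = benchmark(sum, range(1000), n_warmup=3, n_repeat=100)
print(f"{per_call * 1e6:.1f} us per call")
```

Keeping the synchronisation inside the timed loop is the important design choice: it charges each call for the work it actually did rather than for the time taken to enqueue it.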
Caveat 3: the argument-order trap. scipy.integrate.odeint expects f(y, t) (state first, time second). Almost everything else (solve_ivp, diffrax) expects f(t, y). If you port old odeint code to diffrax without swapping the arguments, you end up solving a different ODE and you usually won't get an error. You'll just get the wrong answer.
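The trap is easy to demonstrate inside SciPy itself, since odeint and solve_ivp already disagree on the order. In the sketch below the test equation dy/dt = 2t and the function name are illustrative; passing the odeint-style function unchanged to solve_ivp silently solves dy/dt = 2y instead.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp

def rhs_odeint_style(y, t):
    # dy/dt = 2t, written in odeint's (state, time) order
    return 2.0 * np.atleast_1d(t)

# Correct use with odeint: y(1) = y0 + 1 = 2
y_odeint = odeint(rhs_odeint_style, [1.0], [0.0, 1.0])[-1, 0]

# Passed unchanged to a solver that calls f(t, y): no error, wrong equation
y_swapped = solve_ivp(rhs_odeint_style, (0.0, 1.0), [1.0],
                      rtol=1e-8, atol=1e-10).y[0, -1]

# A thin wrapper that swaps the arguments restores the intended equation
y_fixed = solve_ivp(lambda t, y: rhs_odeint_style(y, t), (0.0, 1.0), [1.0],
                    rtol=1e-8, atol=1e-10).y[0, -1]

print(y_odeint, y_swapped, y_fixed)
```

The swapped call returns e² ≈ 7.39 rather than 2, with no warning. The same wrapper pattern applies when porting to diffrax, whose vector field takes (t, y, args).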
Should you make the switch?
The honest answer is this: if you're solving a one-off ODE and you don't need gradients, solve_ivp is perfectly fine, and there's no need to learn a new API. But if you're doing inference (repeated likelihood evaluations, parameter gradients, or batched solves), the switch is worth the effort.
| Scenario | solve_ivp | odeint | diffrax |
|---|---|---|---|
| One-off solve, no inference | ✓ | ✓ | fine too |
| Nested sampling / MCMC | slow | slow | YES |
| Need gradients | FD only | FD only | exact, free |
| Batch over parameter grid | for-loop | for-loop | vmap |
| Stiff system | Radau | auto (LSODA) | Kvaerno5 |
| SDE or Neural ODE | no | no | YES |
| GPU/TPU | no | no | YES |
The migration itself is small. The forward model changes by about six lines. The gradient appears by adding one more line. The rest of the inference code stays identical.
One thing I must mention here: diffrax is not "ML-based" in the sense of using a neural network. It is the same classical Runge–Kutta mathematics, written in JAX. The "ML acceleration" comes from JIT compilation and autodiff – both infrastructure tools from the ML world applied to a classical numerical solver. The only genuinely ML-based approach would be a neural surrogate that learns θ → μ(z) from training data – a separate and more advanced topic.
The complete working code
Everything above in one self-contained script (pip install jax diffrax optax scipy):
"""
flat_lcdm_inference.py
Infer (Omega_m, H0) from 30 mock supernovae using diffrax + Adam.
pip install jax diffrax optax scipy
"""
import jax, jax.numpy as jnp, numpy as np
import diffrax as dfx, optax
from scipy.integrate import solve_ivp  # only for generating mock data

jax.config.update("jax_enable_x64", True)

# -- Constants and data -----------------------------------------------
C_KMS = 299792.458  # speed of light [km/s]
z_obs = jnp.linspace(0.05, 1.5, 30)
SIGMA = 0.10        # distance-modulus uncertainty [mag]

# Mock data at truth (Om=0.30, H0=70)
def chi_np(Om, H0):
    sol = solve_ivp(lambda z, y: C_KMS/(H0*np.sqrt(Om*(1+z)**3+(1-Om))),
                    (0, 1.5), [0.], t_eval=np.array(z_obs), rtol=1e-10)
    return sol.y[0]

mu_true = 5*np.log10((1+np.array(z_obs))*chi_np(0.3, 70.)*1e5)
mu_obs = jnp.array(mu_true + SIGMA*np.random.default_rng(42).standard_normal(30))

# -- diffrax forward model --------------------------------------------
@jax.jit
def forward(theta):
    Om, H0 = theta[0], theta[1]
    chi = dfx.diffeqsolve(
        dfx.ODETerm(lambda z, chi, a:
                    C_KMS/(a[1]*jnp.sqrt(a[0]*(1+z)**3+(1-a[0])))),
        dfx.Tsit5(),
        t0=0., t1=1.5, dt0=1e-3, y0=jnp.array(0.),
        args=(Om, H0),
        saveat=dfx.SaveAt(ts=z_obs),
        stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
        max_steps=10_000,
    ).ys
    return 5*jnp.log10((1+z_obs)*chi*1e5)

# -- Loss and gradient ------------------------------------------------
def loss(th_s):  # optimise in scaled coords (Om, h=H0/100)
    mu = forward(jnp.array([th_s[0], 100.*th_s[1]]))
    return 0.5*jnp.sum(((mu - mu_obs)/SIGMA)**2)

grad_fn = jax.jit(jax.grad(loss))

# Warm up the JIT compiler (forward takes physical units: Om, H0)
theta_init = jnp.array([0.10, 0.60])
_ = forward(jnp.array([0.3, 70.0])).block_until_ready()
_ = grad_fn(theta_init).block_until_ready()

# -- Adam optimiser with cosine LR schedule ---------------------------
sched = optax.cosine_decay_schedule(init_value=0.05, decay_steps=350, alpha=0.04)
opt = optax.adam(sched)
theta = theta_init
state = opt.init(theta)

print(f"{'Step':>5} {'Om':>7} {'H0':>7} {'Loss':>8}")
for step in range(350):
    g = grad_fn(theta)
    upd, state = opt.update(g, state)
    theta = optax.apply_updates(theta, upd)
    if (step + 1) % 70 == 0 or step == 0:
        L = float(loss(theta))
        print(f"{step+1:5d} {float(theta[0]):7.4f} {100*float(theta[1]):7.3f} {L:8.2f}")

Om_fit, H0_fit = float(theta[0]), 100*float(theta[1])
print(f"\nFinal: Om = {Om_fit:.3f} H0 = {H0_fit:.2f}")
print(f"Truth: Om = 0.300 H0 = 70.00")
Numbers at a glance
| Measurement | scipy | diffrax | Speedup |
|---|---|---|---|
| Single forward call | 0.4 ms | 57 μs | ~7× |
| Gradient (2 params) | 1.62 ms | 195 μs | ~8× |
| 10⁵ forward calls | 40 s | 5.9 s | ~7× |
| 10⁵ gradient calls | ~98 s | ~19.6 s | ~5× |
| Final Ωₘ (350 steps) | 0.652 (wrong) | 0.270 | n/a |
| Final H₀ (350 steps) | 60.10 (stuck) | 70.94 | n/a |
The "wrong" scipy result is not a solver failure – it reflects that simple gradient descent with finite-difference gradients cannot handle the roughly 200× scale mismatch between Ωₘ and H₀.
Final thought
Switching my forward model to diffrax didn't change the physics or the inference method. It changed the practical feasibility of doing that inference at all. A nested-sampling run that was heading towards a large forward-model time budget became one that takes less than a minute. The gradients that were going to cost 20 extra solves per step became essentially free.
The learning curve was about one afternoon. The debugging was mostly the 64-bit caveat and the JIT warm-up confusion. The payoff has been real and immediate.
If you are a physicist using scipy for repeated likelihood evaluations and you have not looked at diffrax yet, I hope this gives you a reason to.
A note on reproducibility: the exact timings you see will differ on your machine and even between runs on the same machine. On my Mac (MacBook Air M3, base model), the diffrax forward call varied between 55 µs and 62 µs across sessions, and scipy varied between 400 µs and 407 µs. This is normal – CPU thermal state, OS scheduling, and memory cache conditions all shift the absolute numbers by 10–15%. What stays stable is the ratio: diffrax is consistently 7–8× faster than scipy on this problem. The ratio, not the absolute time, is the number to take away.
The Python code that generated every figure in this article is available at: github.com/Samit1424/ODE_solver_comparison
Note: Except for the featured image, which was produced using an AI tool, all illustrations are the author's original work.
References
[1] P. Kidger, On Neural Differential Equations, DPhil thesis, University of Oxford, 2021. docs.kidger.site/diffrax/
[2] R. T. Q. Chen, Y. Rubanova, J. Bettencourt, D. Duvenaud, Neural Ordinary Differential Equations, NeurIPS 2018.















