Performance issue with SDE solver
Hello,
When solving the (trivial) SDE $d y_t = -y_t\ dt + 0.2\ dW_t$, the Diffrax Euler solver is ~200x slower than a naive for loop. Am I doing something wrong? The speed difference is consistent across various SDEs, solvers, time steps dt, and number of trajectories, and it appears to be specific to SDE solvers.
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2
# === diffrax euler
brownian_motion = dx.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(ts=jnp.linspace(t0, t1, ndt))
@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat).ys
# === homemade euler
@jax.jit
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))
    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y
    return jax.lax.scan(step, 1.0, dWs)[-1]
# === plot a single trajectory
y = diffrax_simu()
plt.plot(y)
y = homemade_simu()
plt.plot(y)
# === benchmark
%timeit diffrax_simu().block_until_ready()
%timeit homemade_simu().block_until_ready()
5.39 ms ± 261 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 μs ± 899 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
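For reference, both functions above aim to apply the Euler-Maruyama update $y_{n+1} = y_n - y_n\,\Delta t + 0.2\,\Delta W_n$ with $\Delta W_n \sim \mathcal{N}(0, \Delta t)$. A minimal dependency-free sketch of that update follows. It is plain Python rather than JAX, so it says nothing about speed, and the name `euler_maruyama` is only illustrative.

```python
import math
import random

def euler_maruyama(y0, t0, t1, n_steps, sigma, rng):
    """Simulate dy = -y dt + sigma dW on [t0, t1] with n_steps Euler-Maruyama steps."""
    dt = (t1 - t0) / n_steps
    ys = [y0]
    for _ in range(n_steps):
        dW = math.sqrt(dt) * rng.gauss(0.0, 1.0)  # Brownian increment ~ N(0, dt)
        ys.append(ys[-1] - ys[-1] * dt + sigma * dW)
    return ys

# One trajectory with the parameters used in the question (y0 = 1, sigma = 0.2).
path = euler_maruyama(1.0, 0.0, 1.0, 100, 0.2, random.Random(42))
```

With `sigma=0` the scheme reduces to the deterministic recursion $y_{n+1} = (1 - \Delta t)\,y_n$, which is a handy sanity check for any implementation.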
I get them to be a lot closer by using UnsafeBrownianPath, which has less overhead than VBT. Diffrax is still a bit slower with this change on my machine, but the difference is smaller (and probably due to other overheads that diffrax does to enable more features).
There are also some risky (but often useful) changes to UBP we've made internally that I've been meaning to put in the fork, so you can definitely do a fair amount with modifications to UBP (being able to get through all 3 stated requirements).
Yup, VBT is often the cause of poor SDE performance. Really we need some kind of LRU caching to make it behave properly, but that doesn't seem to be easy in JAX -- I'm pretty sure it'd require both a new primitive ('cached_call_p') and a new transform. That's a fairly advanced project for someone to take on!
In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications.
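To give a feel for where the VBT overhead comes from: a virtual Brownian tree reconstructs W(t) by bisecting the time interval down to `tol` and drawing one Brownian-bridge sample per level, so each query costs on the order of log2(1/tol) random draws instead of one. The following is a toy pure-Python sketch of that idea, not Diffrax's implementation. The function names, the per-node seeding scheme, and the use of `functools.lru_cache` as a stand-in for the caching mentioned above are all assumptions made for illustration.

```python
import math
import random
from functools import lru_cache

@lru_cache(maxsize=1024)
def node_normal(seed, node_id):
    # Deterministic N(0, 1) sample per tree node (stand-in for PRNG key splitting).
    return random.Random(seed * 1_000_003 + node_id).gauss(0.0, 1.0)

def tree_brownian_value(t, seed=0, tol=1e-3):
    """Approximate W(t) for t in [0, 1]; returns (value, nodes_visited)."""
    lo, hi = 0.0, 1.0
    w_lo, w_hi = 0.0, node_normal(seed, 0)  # W(0) = 0 and W(1) ~ N(0, 1)
    node, visited = 1, 1
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        # Brownian bridge: given both endpoints, W(mid) has mean (w_lo + w_hi) / 2
        # and variance (hi - lo) / 4.
        w_mid = 0.5 * (w_lo + w_hi) + math.sqrt((hi - lo) / 4.0) * node_normal(seed, node)
        visited += 1
        if t < mid:
            hi, w_hi, node = mid, w_mid, 2 * node
        else:
            lo, w_lo, node = mid, w_mid, 2 * node + 1
    return w_lo, visited  # value at the left end of the final sub-interval

value, visited = tree_brownian_value(0.37)  # visits 11 nodes for tol=1e-3
neighbour, _ = tree_brownian_value(0.38)    # reuses the cached upper levels of the tree
```

An increment W(t1) - W(t0) needs two such descents, which is at least consistent with VBT being much slower per step than drawing a single normal. The cache only helps here because nearby queries share the upper part of the tree; doing the equivalent inside a jitted JAX program is the hard part discussed above.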
I think a lot of people get turned off by the "Unsafe" in the name; it may be worth adding a sentence like this to the docs ("In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications.").
Thanks. Indeed using UBP does help but I understand it's quite restricted in terms of usage.
> Diffrax is still a bit slower with this change on my machine, but the difference is smaller (and probably due to other overheads that diffrax does to enable more features).
It seems there is still a factor ~10-20 difference (irrespective of number of time steps) between the homemade solver and diffrax with UBP. I would have naively thought that any irrelevant computation would be jitted away. Could you elaborate on what diffrax with UBP does compared to the naive solver?
Diffrax (VBT): 7.51 ms ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Diffrax (UBP): 637 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Naive: 28.5 µs ± 147 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Diffrax has a lot more checking/shaping/logging than the naive implementation. You can see it reflected in the jaxprs:
diffrax
let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let
d:i32[] = select_n a c b
in (d,) } in
let _where1 = { lambda ; e:bool[] f:f32[] g:f32[]. let
h:f32[] = select_n e g f
in (h,) } in
let _where2 = { lambda ; i:bool[] j:f32[] k:f32[]. let
l:f32[] = select_n i k j
in (l,) } in
let _where3 = { lambda ; m:bool[] n:i32[] o:f32[]. let
p:f32[] = convert_element_type[new_dtype=float32 weak_type=False] n
q:f32[] = select_n m o p
in (q,) } in
{ lambda ; . let
r:f32[4096] = pjit[
name=diffrax_simu
jaxpr={ lambda s:u32[2]; . let
_:i32[] = add 1 1
_:i32[] _:f32[] _:f32[] _:f32[4096] t:f32[4096] _:i32[] _:i32[] _:i32[]
_:i32[] = pjit[
name=diffeqsolve
jaxpr={ lambda u:bool[] v:bool[] w:bool[] x:bool[]; y:u32[2]. let
_:i32[] = add 1 1
_:i32[] = pjit[
name=branched_error_if_impl
jaxpr={ lambda ; z:f32[]. let in (0,) }
] 0.009999999776482582
ba:bool[] = lt 0.0 1.0
bb:i32[] = pjit[name=_where jaxpr=_where] ba 1 -1
bc:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bb
bd:f32[] = mul 0.0 bc
be:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bb
bf:f32[] = mul 1.0 be
bg:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bb
bh:f32[] = mul 0.009999999776482582 bg
bi:f32[] = add bd bh
bj:f32[] = min bi bf
bk:f32[] = convert_element_type[
new_dtype=float32
weak_type=True
] bb
bl:f32[] = mul bk inf
bm:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bl
bn:f32[4096] = broadcast_in_dim[
broadcast_dimensions=()
shape=(4096,)
] bm
bo:f32[4096] = broadcast_in_dim[
broadcast_dimensions=()
shape=(4096,)
] inf
bp:f32[] = copy 1.0
bq:f32[] = copy bd
br:f32[] = copy bj
bs:f32[] = copy bh
bt:f32[4096] = copy bn
bu:f32[4096] = copy bo
bv:bool[] = lt bq bf
bw:bool[] = and bv u
bx:bool[] = copy bw
_:i32[] _:bool[] _:bool[] by:f32[] bz:f32[] ca:f32[] _:bool[] cb:f32[]
cc:i32[] cd:i32[] ce:i32[] cf:i32[] cg:i32[] ch:f32[4096] ci:f32[4096]
cj:i32[] = while[
body_jaxpr={ lambda ; ck:i32[] cl:u32[2] cm:f32[] cn:f32[] co:bool[]
cp:i32[] cq:bool[] cr:bool[] cs:f32[] ct:f32[] cu:f32[] cv:bool[]
cw:f32[] cx:i32[] cy:i32[] cz:i32[] da:i32[] db:i32[] dc:f32[4096]
dd:f32[4096] de:i32[]. let
df:bool[] = eq ck 1
dg:f32[] = neg cu
dh:f32[] = pjit[name=_where jaxpr=_where1] df ct dg
di:bool[] = eq ck 1
dj:f32[] = neg ct
dk:f32[] = pjit[name=_where jaxpr=_where1] di cu dj
dl:f32[] = sub dk dh
dm:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] ck
dn:f32[] = mul dm dl
do:bool[] = eq ck 1
dp:f32[] = neg cu
dq:f32[] = pjit[name=_where jaxpr=_where1] do ct dp
dr:bool[] = eq ck 1
ds:f32[] = neg ct
dt:f32[] = pjit[name=_where jaxpr=_where1] dr cu ds
_:i32[] = add 1 1
_:i32[] du:f32[] = pjit[
name=evaluate
jaxpr={ lambda ; dv:u32[2] dw:f32[] dx:f32[]. let
dy:f32[] = custom_jvp_call[
call_jaxpr={ lambda ; dz:f32[]. let in (dz,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7ef9d5f12440>
num_consts=0
symbolic_zeros=False
] dw
ea:f32[] = custom_jvp_call[
call_jaxpr={ lambda ; eb:f32[]. let in (eb,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7ef9d5f125f0>
num_consts=0
symbolic_zeros=False
] dx
ec:i32[] = bitcast_convert_type[new_dtype=int32] dy
ed:i32[] = bitcast_convert_type[new_dtype=int32] ea
ee:key<fry>[] = random_wrap[impl=fry] dv
ef:u32[] = convert_element_type[
new_dtype=uint32
weak_type=False
] ec
eg:key<fry>[] = random_fold_in ee ef
eh:u32[2] = random_unwrap eg
ei:key<fry>[] = random_wrap[impl=fry] eh
ej:u32[] = convert_element_type[
new_dtype=uint32
weak_type=False
] ed
ek:key<fry>[] = random_fold_in ei ej
el:u32[2] = random_unwrap ek
em:key<fry>[] = random_wrap[impl=fry] el
en:key<fry>[1] = random_split[shape=(1,)] em
eo:u32[1,2] = random_unwrap en
ep:u32[1,2] = slice[
limit_indices=(1, 2)
start_indices=(0, 0)
strides=(1, 1)
] eo
eq:u32[2] = squeeze[dimensions=(0,)] ep
er:f32[] = sub ea dy
es:f32[] = sqrt er
_:f32[] = sub ea dy
et:key<fry>[] = random_wrap[impl=fry] eq
eu:f32[] = pjit[
name=_normal
jaxpr={ lambda ; ev:key<fry>[]. let
ew:f32[] = pjit[
name=_normal_real
jaxpr={ lambda ; ex:key<fry>[]. let
ey:f32[] = pjit[
name=_uniform
jaxpr={ lambda ; ez:key<fry>[] fa:f32[]
fb:f32[]. let
fc:u32[] = random_bits[
bit_width=32
shape=()
] ez
fd:u32[] = shift_right_logical fc 9
fe:u32[] = or fd 1065353216
ff:f32[] = bitcast_convert_type[
new_dtype=float32
] fe
fg:f32[] = sub ff 1.0
fh:f32[] = sub fb fa
fi:f32[] = mul fg fh
fj:f32[] = add fi fa
fk:f32[] = reshape[
dimensions=None
new_sizes=()
] fj
fl:f32[] = max fa fk
in (fl,) }
] ex -0.9999999403953552 1.0
fm:f32[] = erf_inv ey
fn:f32[] = mul 1.4142135381698608 fm
in (fn,) }
] ev
in (ew,) }
] et
fo:f32[] = mul eu es
in (0, fo) }
] cl dq dt
fp:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] ck
fq:f32[] = mul fp du
fr:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] ck
_:f32[] = mul ct fr
fs:f32[] = neg cs
ft:f32[] = mul dn fs
fu:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] ck
_:f32[] = mul ct fu
fv:f32[] = dot_general[
dimension_numbers=(([], []), ([], []))
preferred_element_type=float32
] 0.20000000298023224 fq
fw:f32[] = add ft fv
fx:f32[] = add cs fw
fy:f32[] = add cu cw
fz:f32[] = min cu cm
ga:f32[] = sub cm 9.999999974752427e-07
gb:bool[] = gt fy ga
gc:f32[] = sub cm fz
gd:f32[] = mul 0.5 gc
ge:f32[] = add fz gd
gf:f32[] = pjit[name=_where jaxpr=_where2] True cm ge
gg:f32[] = pjit[name=_where jaxpr=_where2] gb gf fy
gh:bool[] = eq cn cm
gi:f32[] = sub fz cn
gj:f32[] = pjit[name=_where jaxpr=_where3] gh 0 gi
gk:f32[] = sub cm cn
gl:f32[] = pjit[name=_where jaxpr=_where3] gh 1 gk
_:f32[] = div gj gl
gm:f32[] = pjit[name=_where jaxpr=_where2] True fx cs
gn:i32[] = add cy 1
go:i32[] = pjit[name=_where jaxpr=_where] True 1 0
gp:i32[] = add cz go
gq:i32[] = pjit[name=_where jaxpr=_where] True 0 1
gr:i32[] = add da gq
gs:bool[] = and True cq
gt:f32[] = copy fz
gu:f32[4096] = maybe_set[
i_static=None
i_treedef=PyTreeDef(*)
kwargs={}
makes_false_steps=False
] gs dc gt de
gv:bool[] = and True cq
gw:f32[] = copy gm
gx:f32[4096] = maybe_set[
i_static=None
i_treedef=PyTreeDef(*)
kwargs={}
makes_false_steps=False
] gv dd gw de
gy:i32[] = pjit[name=_where jaxpr=_where] True 1 0
gz:i32[] = add de gy
ha:f32[] = copy gm
hb:f32[] = copy fz
hc:f32[] = copy gg
hd:f32[] = copy cw
he:i32[] = copy gn
hf:i32[] = copy gp
hg:i32[] = copy gr
hh:i32[] = copy db
hi:f32[4096] = copy gu
_:bool[] = copy cq
hj:f32[4096] = copy gx
_:bool[] = copy cq
hk:i32[] = copy gz
hl:bool[] = copy cq
hm:f32[] = copy ha
hn:f32[] = copy cs
ho:f32[] = select_if_vmap hl hm hn
hp:bool[] = copy cq
hq:f32[] = copy hb
hr:f32[] = copy ct
hs:f32[] = select_if_vmap hp hq hr
ht:bool[] = copy cq
hu:f32[] = copy hc
hv:f32[] = copy cu
hw:f32[] = select_if_vmap ht hu hv
hx:bool[] = copy cq
hy:bool[] = copy False
hz:bool[] = copy cv
ia:bool[] = select_if_vmap hx hy hz
ib:bool[] = copy cq
ic:f32[] = copy hd
id:f32[] = copy cw
ie:f32[] = select_if_vmap ib ic id
if:bool[] = copy cq
ig:i32[] = copy 0
ih:i32[] = copy cx
ii:i32[] = select_if_vmap if ig ih
ij:bool[] = copy cq
ik:i32[] = copy he
il:i32[] = copy cy
im:i32[] = select_if_vmap ij ik il
in:bool[] = copy cq
io:i32[] = copy hf
ip:i32[] = copy cz
iq:i32[] = select_if_vmap in io ip
ir:bool[] = copy cq
is:i32[] = copy hg
it:i32[] = copy da
iu:i32[] = select_if_vmap ir is it
iv:bool[] = copy cq
iw:i32[] = copy hh
ix:i32[] = copy db
iy:i32[] = select_if_vmap iv iw ix
iz:bool[] = copy cq
ja:i32[] = copy hk
jb:i32[] = copy de
jc:i32[] = select_if_vmap iz ja jb
jd:i32[] = add cp 1
je:bool[] = lt hb cm
jf:bool[] = and je co
jg:bool[] = and cq jf
in (jd, jg, cq, ho, hs, hw, ia, ie, ii, im, iq, iu, iy, hi, hj,
jc) }
body_nconsts=5
cond_jaxpr={ lambda ; jh:f32[] ji:bool[] jj:i32[] jk:bool[] jl:bool[]
jm:f32[] jn:f32[] jo:f32[] jp:bool[] jq:f32[] jr:i32[] js:i32[]
jt:i32[] ju:i32[] jv:i32[] jw:f32[4096] jx:f32[4096] jy:i32[]. let
jz:bool[] = lt jn jh
ka:bool[] = and jz ji
kb:bool[] = unvmap_any ka
kc:bool[] = lt jj 4096
kd:bool[] = convert_element_type[
new_dtype=bool
weak_type=False
] kc
ke:bool[] = and kb kd
kf:bool[] = nonbatchable[
allow_constant_across_batch=True
msg=Nonconstant batch. `equinox.internal.while_loop` has received a batch of values that were expected to be constant. This is probably an internal error in the library you are using.
] ke
in (kf,) }
cond_nconsts=2
] bf v bb y bf bd w 0 bx True bp bq br False bs 0 0 0 0 0 bt bu 0
kg:bool[] = lt bz bf
kh:bool[] = and kg x
ki:i32[] = pjit[
name=_where
jaxpr={ lambda ; kj:bool[] kk:i32[] kl:i32[]. let
km:i32[] = select_n kj kl kk
in (km,) }
] kh 1 cc
_:f32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] by
_:f32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] bz
_:f32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] ca
_:f32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] cb
kn:i32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] ki
ko:i32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] cd
kp:i32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] ce
kq:i32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] cf
_:i32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] cg
kr:f32[4096] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] ch
ks:f32[4096] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] ci
_:i32[] = nondifferentiable_backward[
msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
symbolic=True
] cj
kt:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bb
ku:f32[4096] = mul kr kt
kv:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bb
kw:f32[] = mul bd kv
kx:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bb
ky:f32[] = mul bf kx
kz:bool[] = eq kn 0
la:bool[] = eq kn 8
lb:bool[] = or kz la
lc:bool[] = not lb
_:i32[] = add 1 1
_:i32[] ld:f32[] le:f32[] lf:f32[4096] lg:f32[4096] lh:i32[] li:i32[]
lj:i32[] lk:i32[] = pjit[
name=branched_error_if_impl
jaxpr={ lambda ; ll:f32[] lm:f32[] ln:f32[4096] lo:f32[4096] lp:i32[]
lq:i32[] lr:i32[] ls:i32[] lt:bool[] lu:i32[]. let
lv:bool[] = unvmap_any lt
lw:i32[] = unvmap_max lu
lx:f32[] ly:f32[] lz:f32[4096] ma:f32[4096] mb:i32[] mc:i32[]
md:i32[] me:i32[] = custom_jvp_call[
call_jaxpr={ lambda ; mf:f32[] mg:f32[] mh:f32[4096] mi:f32[4096]
mj:i32[] mk:i32[] ml:i32[] mm:i32[] mn:bool[] mo:i32[]. let
mp:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] mn
mq:f32[] mr:f32[] ms:f32[4096] mt:f32[4096] mu:i32[]
mv:i32[] mw:i32[] mx:i32[] = cond[
branches=(
{ lambda ; my_:i32[] mz:f32[] na:f32[] nb:f32[4096]
nc:f32[4096] nd:i32[] ne:i32[] nf:i32[] ng:i32[]. let
in (mz, na, nb, nc, nd, ne, nf, ng) }
{ lambda ; nh:i32[] ni_:f32[] nj_:f32[] nk_:f32[4096]
nl_:f32[4096] nm_:i32[] nn_:i32[] no_:i32[] np_:i32[]. let
nq:f32[] nr:f32[] ns:f32[4096] nt:f32[4096] nu:i32[]
nv:i32[] nw:i32[] nx:i32[] = pure_callback[
callback=_FlatCallback(callback_func=<function _error.<locals>.raises at 0x7ef9d5b48160>, in_tree=PyTreeDef(((*,), {})))
result_avals=(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[4096]), ShapedArray(float32[4096]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]))
sharding=None
vectorized=False
] nh
ny:f32[] nz:f32[] oa:f32[4096] ob:f32[4096] oc:i32[]
od:i32[] oe:i32[] of:i32[] = pure_callback[
callback=_FlatCallback(callback_func=<function _error.<locals>.tpu_msg at 0x7ef9d5b481f0>, in_tree=PyTreeDef(((CustomNode(Solution[('t0', 't1', 'ts', 'ys', 'interpolation', 'stats', 'result', 'solver_state', 'controller_state', 'made_jump', 'event_mask'), (), ()], [*, *, *, *, None, {'max_steps': None, 'num_accepted_steps': *, 'num_rejected_steps': *, 'num_steps': *}, CustomNode(EnumerationItem[('_value',), ('_enumeration',), (<class 'diffrax._solution.RESULTS'>,)], [*]), None, None, None, None]), *), {})))
result_avals=(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[4096]), ShapedArray(float32[4096]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]))
sharding=None
vectorized=False
] nq nr ns nt nu nv nw nx nh
in (ny, nz, oa, ob, oc, od, oe, of) }
)
] mp mo mf mg mh mi mj mk ml mm
in (mq, mr, ms, mt, mu, mv, mw, mx) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7ef9d5b48790>
num_consts=0
symbolic_zeros=True
] ll lm ln lo lp lq lr ls lv lw
in (0, lx, ly, lz, ma, mb, mc, md, me) }
] kw ky ku ks kp kq ko kn lc kn
in (0, ld, le, lf, lg, lh, li, lj, lk) }
] s
in (t,) }
]
in (r,) }
pure jax
{ lambda ; . let
a:f32[101] = pjit[
name=homemade_simu
jaxpr={ lambda b:u32[2]; . let
c:f32[] = sqrt 0.01
d:key<fry>[] = random_wrap[impl=fry] b
e:f32[101] = pjit[
name=_normal
jaxpr={ lambda ; f:key<fry>[]. let
g:f32[101] = pjit[
name=_normal_real
jaxpr={ lambda ; h:key<fry>[]. let
i:f32[101] = pjit[
name=_uniform
jaxpr={ lambda ; j:key<fry>[] k:f32[] l:f32[]. let
m:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
] k
n:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
] l
o:u32[101] = random_bits[bit_width=32 shape=(101,)] j
p:u32[101] = shift_right_logical o 9
q:u32[101] = or p 1065353216
r:f32[101] = bitcast_convert_type[new_dtype=float32] q
s:f32[101] = sub r 1.0
t:f32[1] = sub n m
u:f32[101] = mul s t
v:f32[101] = add u m
w:f32[101] = max m v
in (w,) }
] h -0.9999999403953552 1.0
x:f32[101] = erf_inv i
y:f32[101] = mul 1.4142135381698608 x
in (y,) }
] f
in (g,) }
] d
z:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
ba:f32[101] = mul z e
_:f32[] bb:f32[101] = scan[
_split_transpose=False
jaxpr={ lambda ; bc:f32[] bd:f32[]. let
be:f32[] = neg bc
bf:f32[] = mul be 0.01
bg:f32[] = mul 0.20000000298023224 bd
bh:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bf
bi:f32[] = add bh bg
bj:f32[] = convert_element_type[
new_dtype=float32
weak_type=False
] bc
bk:f32[] = add bj bi
in (bk, bc) }
length=101
linear=(False, False)
num_carry=1
num_consts=0
reverse=False
unroll=1
] 1.0 ba
in (bb,) }
]
in (a,) }
I believe most of this comes from the UBP, since if I do
@jax.jit
def homemade_simu():
    ts = jnp.linspace(t0, t1, ndt)
    def step(y, t):
        dw = brownian_motion.evaluate(t, t + dt)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return y + dy, y
    return jax.lax.scan(step, 1.0, ts)[-1]
I see the times are pretty much the same. Perhaps this does indicate that there is room for cutting down the UBP-related overhead.
FWIW I think the speed difference here does seem unacceptably large. This seems like it should be improved.
Starting with the low-hanging fruit, to be sure we're making more of an equal comparison: can you try setting EQX_ON_ERROR=nan and diffeqsolve(throw=False), to disable all error checks? Those are fairly slow.
Also, can you try using stepsize_controller=StepTo(...)? By default Diffrax does not recompile if the number of steps changes (e.g. because t1 changes), but a lax.scan implementation does. Diffrax pays a small amount of runtime cost for this generality. Using StepTo instead bakes in the discretisation in the same way as a lax.scan.
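The distinction between the default constant-step loop and `StepTo` can be pictured as a loop that decides at run time when to stop, versus a loop over a grid fixed up front. Below is a plain-Python sketch of the two control flows (drift only, to keep it deterministic). The function names are illustrative, and the 1e-6 end-of-interval tolerance is an assumption loosely modelled on the constant visible in the jaxpr dump earlier in the thread.

```python
def solve_dynamic(y0, t0, t1, dt):
    """Step until t reaches t1; the step count is only known at run time."""
    t, y, n_steps = t0, y0, 0
    while t < t1:
        t_next = t + dt
        if t_next > t1 - 1e-6:  # clip the final step onto t1
            t_next = t1
        y = y + (-y) * (t_next - t)
        t, n_steps = t_next, n_steps + 1
    return y, n_steps

def solve_on_grid(y0, ts):
    """Step over a precomputed grid, as a scan over fixed-length inputs would."""
    y = y0
    for t_prev, t_next in zip(ts[:-1], ts[1:]):
        y = y + (-y) * (t_next - t_prev)
    return y

n = 100
ts = [i / n for i in range(n + 1)]
y_dyn, steps = solve_dynamic(1.0, 0.0, 1.0, 1.0 / n)
y_grid = solve_on_grid(1.0, ts)
```

Both reach the same answer up to rounding; the dynamic version simply carries extra bookkeeping (the time comparison and the clipping) on every step, which is the kind of generality cost described above.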
With throw=False, EQX_ON_ERROR=nan and StepTo, this is what I see:
code
import os
os.environ["EQX_ON_ERROR"] = "nan"
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2
steps = jnp.linspace(t0, t1, ndt)
brownian_motion = dx.UnsafeBrownianPath(shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(steps=True)
@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(
        terms, solver, t0, t1, dt0=None, y0=y0, saveat=saveat,
        adjoint=dx.DirectAdjoint(), throw=False,
        stepsize_controller=dx.StepTo(ts=steps),
    ).ys
@jax.jit
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))
    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y
    return jax.lax.scan(step, 1.0, dWs)[-1]
y = diffrax_simu().block_until_ready()
plt.plot(y)
y = homemade_simu().block_until_ready()
plt.plot(y)
plt.show()
%timeit _ = diffrax_simu().block_until_ready()
%timeit _ = homemade_simu().block_until_ready()
(diffrax top, custom bottom)
2.18 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
109 µs ± 25.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
(without any of those things I had):
2.43 ms ± 666 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
110 µs ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
(all on CPU, just a slower CPU, but the 20-30x slowdown seems of the same scale)
So you definitely don't want DirectAdjoint: this is actually really slow and should be avoided if possible. (It exists to handle some autodiff edge cases, I'd love to remove it sometime...) Use the default instead.
Make sure you include an argument (say y0) to both jitted functions -- XLA may have different behavior around constant folding.
I'd also try with and without SaveAt(steps=True). (And adjusting the scan appropriately.) I think this should be equivalent either way but I'm not 100% certain.
With all of the above in, then at that point there shouldn't actually be that much difference between the two implementations. (And if there is then we should figure out what.)
The default actually errors with UBP, which is why I changed to DirectAdjoint:
ValueError: `adjoint=RecursiveCheckpointAdjoint()` does not support `UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` instead.
Ah, right. I've just checked and in the case of an unsafe SDE we do actually arrange for DirectAdjoint to do a scan so that should be fine:
https://github.com/patrick-kidger/diffrax/blob/ada5229c46ed041b30090b969f9943082e5300d6/diffrax/_adjoint.py#L352
(In retrospect I think we could have arranged for the default adjoint to also do the same thing, that might be a small usability improvement.)
Anyway, that's everything off the top of my head -- I might be forgetting something but with these settings then I think Diffrax should be doing something similar to the simple lax.scan implementation. But clearly we're missing something!
(EDIT: we still have one discrepancy I have just noticed: generating the Brownian samples in advance vs on-the-fly.)
If you'd like to dig into this then it might be time to stare at some jaxprs or HLO for the two programs. If you want to do this at the jaxpr level then you might find eqxi.finalise_jaxpr (and friends) to be a useful set of tools here:
https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_finalise_jaxpr.py
Many primitives exist just to add e.g. an autodiff rule, so we can simplify our jaxprs down to what actually gets lowered by ignoring that and tracing through their impl rules instead.
DirectAdjoint does slow things down, but it doesn't account for the whole gap. If I switch to a branch that allows for UBP + recursive adjoint, it's faster, but there is still a ~4x gap. If I account for the fact that UBP has to split keys but the homemade version doesn't, I get the gap down to ~1.1-1.2x (which maybe isn't ideal, but seems much more reasonable to me, given there are probably some other if statements/logging in Diffrax).
from timeit import Timer
x = Timer(lambda: diffrax_simu(y0).block_until_ready())
print(x.timeit(number=100))
x = Timer(lambda: homemade_simu(y0).block_until_ready())
print(x.timeit(number=100))
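One thing to keep in mind when reading such numbers: `Timer.timeit(number=100)` returns the total time for all 100 calls, not a per-call average. A small standard-library helper that reports the best per-call time over a few repeats could look like this (the name `per_call_seconds` is made up for this sketch):

```python
from timeit import Timer

def per_call_seconds(fn, number=100, repeat=5):
    """Best-of-`repeat` average wall time of one call to `fn`, in seconds."""
    totals = Timer(fn).repeat(repeat=repeat, number=number)  # total seconds per batch
    return min(totals) / number

# Stand-in workload; for a jitted function, time `lambda: f(x).block_until_ready()`.
elapsed = per_call_seconds(lambda: sum(range(1000)))
```

For JAX functions, keep `.block_until_ready()` inside the timed callable as in the snippet above, and call the function once beforehand so that compilation is not included in the measurement.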
Output is listed as Diffrax first, homemade second (the order of the print calls above).
With all of the above (NAN, steps, function input, StepTo, max steps, etc.) and direct adjoint: 0.002462916076183319 0.0005935421213507652
With checkpoint adjoint (on an internal branch that had some UBP changes to work with checkpointing): 0.002062791958451271 0.0005716248415410519
With both splitting keys: 0.0019747079350054264 0.001669874880462885
(code changed to:
@jax.jit
def homemade_simu(yy):
    def step(y1, dW):
        y, k = y1
        k, subkey = jax.random.split(k)
        dw = jnp.sqrt(dt) * jax.random.normal(subkey)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return (y + dy, k), y
    return jax.lax.scan(step, (yy, key), steps)[-1]
)
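The two variants compared here differ only in where the random numbers are produced: all at once before the loop, or one per step inside it. As a plain-Python analogy (not JAX; `random.Random` stands in for the PRNG key handling, and the function names are illustrative), the two arrangements give the same path when fed the same random stream, so any timing gap comes from the sampling machinery rather than from the numerics:

```python
import math
import random

def simulate_precomputed(y0, dt, n_steps, sigma, seed):
    rng = random.Random(seed)
    # All noise is generated up front and stored (n_steps values in memory).
    dWs = [math.sqrt(dt) * rng.gauss(0.0, 1.0) for _ in range(n_steps)]
    y, ys = y0, []
    for dW in dWs:
        ys.append(y)
        y = y - y * dt + sigma * dW
    return ys

def simulate_on_the_fly(y0, dt, n_steps, sigma, seed):
    rng = random.Random(seed)
    y, ys = y0, []
    for _ in range(n_steps):
        ys.append(y)
        dW = math.sqrt(dt) * rng.gauss(0.0, 1.0)  # noise drawn inside the loop
        y = y - y * dt + sigma * dW
    return ys

a = simulate_precomputed(1.0, 0.01, 100, 0.2, seed=42)
b = simulate_on_the_fly(1.0, 0.01, 100, 0.2, seed=42)
```

The precomputed version trades memory for a simpler loop body, which matches the trade-off raised in the thread.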
Aha, interesting! Good to have more-or-less gotten to the bottom of the cause of this.
So:
- I'd be curious to see what your version of RecursiveCheckpointAdjoint does, and how that compares to the unsafe-SDE branch of DirectAdjoint.
- I suppose generating the Brownian samples in advance, rather than on-the-fly, is very plausibly much faster. (Although I note that it will be more memory-intensive.) Off the top of my head I'm not immediately sure how to arrange it so that the case of using a constant step size controller and an UnsafeBrownianPath could make it possible to precompute things.
On point 2, I suspect the solution may require allowing the control to have additional state. (Which is also what we'd need to make VBT faster.) Perhaps it's time to bite that bullet and allow for that to happen. Happy to hear suggestions on this one!
- That is something I want to investigate as well (and also get more of it organized and pushed to a fork for others to check); admittedly it will take me a little while to get to.
- Would it be possible to add a "precompute" flag (or something to that effect) to UBP? It would generate the noise ahead of time (with the size determined by the max steps or by user input), without requiring a stateful approach. This might (if the dt multiplication is still done in the loop) also be compatible with adaptive solvers that don't reject steps ("previsible", I think James calls them).
- I am in general an advocate of stateful controls (also discussed in #490), although I haven't thought much more on it since the discussion in that issue (which is very similar to how my stateful UBP is implemented).
- Okay, lmk what you find.
- I'm not sure. The way the controls are called at the moment is with the t, not the step index. We'd also have to have a way to pass the number of steps etc to the control. FWIW I'd probably lean towards not having a flag and just always doing this when possible.
- I think to do this 'properly' we might need to have AbstractSolver.step also accept the control state, and then pipe it through appropriately. Then also return the updated state. Unfortunately I think we're looking at a hard break to both the control and the solver APIs here, but c'est la vie.
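As a rough picture of what "control state" threaded through the solver could look like, here is a plain-Python sketch. Everything in it (`PrecomputedNoiseState`, `init_state`, `evaluate`, `euler_step`) is hypothetical and is not the Diffrax API. It only illustrates the shape of the idea: the control returns its increment together with an updated state, and the step function passes that state along.

```python
import math
import random
from dataclasses import dataclass
from typing import List, Tuple

@dataclass(frozen=True)
class PrecomputedNoiseState:
    normals: List[float]  # standard normals drawn once, up front
    index: int            # how many of them have been consumed so far

def init_state(seed: int, max_steps: int) -> PrecomputedNoiseState:
    rng = random.Random(seed)
    return PrecomputedNoiseState([rng.gauss(0.0, 1.0) for _ in range(max_steps)], 0)

def evaluate(t0: float, t1: float,
             state: PrecomputedNoiseState) -> Tuple[float, PrecomputedNoiseState]:
    # Scaling by sqrt(t1 - t0) happens per call, so the step size may vary between steps.
    dW = math.sqrt(t1 - t0) * state.normals[state.index]
    return dW, PrecomputedNoiseState(state.normals, state.index + 1)

def euler_step(y, t0, t1, state):
    # The solver step receives the control state and returns the updated one.
    dW, state = evaluate(t0, t1, state)
    return y + (-y) * (t1 - t0) + 0.2 * dW, state

n = 100
state = init_state(seed=0, max_steps=n)
y = 1.0
for i in range(n):
    y, state = euler_step(y, i / n, (i + 1) / n, state)
```

Note that the control is still called with times rather than a step index; the step counter lives in the state, which is one way around the issue raised in the second bullet.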
> I'm not sure. The way the controls are called at the moment is with the t, not the step index. We'd also have to have a way to pass the number of steps etc to the control. FWIW I'd probably lean towards not having a flag and just always doing this when possible.
Yes, looking at it more, this would probably need a change/add-on to support passing the "step" counter around. If this is an acceptable change, I don't think it would be too much for me to get a PR up.
> I think to do this 'properly' we might need to have AbstractSolver.step also accept the control state, and then pipe it through appropriately. Then also return the updated state. Unfortunately I think we're looking at a hard break to both the control and the solver APIs here, but c'est la vie.
This was my conclusion as well, and I started drafting a branch for this, but figured it would require a pretty noticeable breaking change (at least internally), and I figured diffrax was more fait accompli than c'est la vie when it came to this level of breaking changes.
It's true, I try to avoid breaking changes where possible! They're no fun for anyone. But the performance issues discussed here genuinely are quite severe, so I think they're actually strong enough to motivate making a breaking change of this nature.
> If this is an acceptable change, I don't think it would be too much for me to get a PR up.
Awesome, I'm looking forward to it! Let's see if we can get the stateful controls done at the same time? I'd like to contain the breaking changes to a single release, ideally.
> Awesome, I'm looking forward to it! Let's see if we can get the stateful controls done at the same time? I'd like to contain the breaking changes to a single release, ideally.
Sure I can work on them simultaneously.
A design question in advance (since I have now gone back and forth on it). Should AbstractPath be stateful (pros: most general, it is the highest-level abstract class; cons: requires all the interpolations and everything else that inherits from it to now, for the most part, take another empty input), or just AbstractBrownianPath (pros: less overall impact, more focused on what the statefulness is actually for; cons: breaks the inheritance pattern, with evaluate now returning more information, which causes other knock-on effects as well)? My first reaction was to do it at the AbstractPath level to be maximally general, but going through every interpolation and adding a path_state: Optional[_PathState] = None to everything made it feel a little forced, so I figured I'd just check before committing too much to one way.
I think at the AbstractPath level makes sense to me.
Perhaps we could tackle two birds with one stone here: we could put the new stateful API under __call__ (#160), to mostly-preserve backward compatibility.
@pierreguilmin feel free to try out the draft I have currently: https://github.com/patrick-kidger/diffrax/pull/559. This (in some situations) allows a 100x speedup over VBT. For small problems with very small numbers of steps (<100), scan is still faster (due to diffrax overhead), but for longer problems diffrax can be faster.
See the benchmark file where I get
Results on Mac M1 CPU:
VBT: 0.184882
Old UBP: 0.016347
New UBP: 0.013731
New UBP + Precompute: 0.002430
Pure Jax: 0.002799
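Taking the numbers above at face value, the relative timings work out as follows. This is simple arithmetic on the reported figures; the dictionary keys are just labels copied from the list.

```python
timings = {
    "VBT": 0.184882,
    "Old UBP": 0.016347,
    "New UBP": 0.013731,
    "New UBP + Precompute": 0.002430,
    "Pure Jax": 0.002799,
}

# Express every timing relative to the fastest Diffrax configuration.
baseline = timings["New UBP + Precompute"]
ratios = {name: t / baseline for name, t in timings.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.1f}x the time of New UBP + Precompute")
```

On this particular benchmark the ratio for VBT comes out at roughly 76, and the precomputed UBP run is slightly faster than the pure JAX loop.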
Hello @lockwo, thanks for the great work, that's amazing news! I'll try to take a look at your branch this week to benchmark it on the problems we really care about. FYI, we're planning to use SDE solvers to simulate continuous quantum measurements, and add these solvers to our library dynamiqs.
Sounds good, let me know if you encounter any issues. I will try to get it into a more final state by EoW so Patrick can start reviewing (though I suspect it will take some time to get to main, since there's also Andraz's PR, and the fact that there are many breaking changes in my PR that will likely require some back and forth).