feat: diagonal sparse Pijk + discussion on merging limb-dark and starry
This PR is an attempt to understand what can be done to merge the limb-darkening light curve implementation with starry.
Motivation
The idea is that, if only limb darkening is present on a map defined with starry, JAX should figure out which parts of the different matrices are necessary to compute the starry light curve, and these should be similar in number to the ones involved in core.limb_dark.light_curve. Hence the performance of the starry light curve should be similar to the limb-dark one, i.e. experimental.starry.light_curves.surface_light_curve with
surface = Surface(y=None, u=u)
Description of the current modifications
In the Ylm and Pijk interfaces, we defined a diagonal attribute which indicates whether the non-radial coefficients of the spherical harmonics basis, i.e. those not contributing to limb darkening, are all zero. The first commit in this PR fixes that for the Pijk basis and allows passing a diagonal kwarg to the Pijk.tosparse method, so that the computation
# see end of jaxoplanet/experimental/starry/light_curves.py
p_y.tosparse(diagonal=only_u) @ design_matrix_p
takes into account the sparsity of p_y.
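To illustrate, here is a minimal sketch of what tosparse(diagonal=True) could do, assuming the usual starry flattening n = l**2 + l + m (so the m = 0 "diagonal" coefficients sit at n = l * (l + 1)); tosparse_diagonal is a hypothetical stand-in, not the actual method:
import numpy as np
from jax.experimental import sparse

def tosparse_diagonal(dense_p):
    # Hypothetical sketch: keep only the m = 0 entries of a dense coefficient
    # vector of size (deg + 1)**2 and store them in a BCOO sparse vector.
    deg = int(np.sqrt(dense_p.shape[0])) - 1
    idx = np.array([l * (l + 1) for l in range(deg + 1)])
    return sparse.BCOO((dense_p[idx], idx[:, None]), shape=dense_p.shape)
The product with design_matrix_p then only involves the rows matching these indices.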
Results
For now this is not working as hoped: the starry light curve is ~4 times slower than the limb-dark one.
starry
import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np
r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20
surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
439 μs ± 2.82 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
limb-dark
from jaxoplanet.core.limb_dark import light_curve
function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
115 μs ± 2.04 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
The evaluation time is largely dominated by the call to solution_vector in both implementations. But I wonder how to write the starry light curve function so that only the relevant parts of the solution vector are computed when only limb darkening is present.
Something I noticed is that with a scalar b (no vmap) we get 9.42 μs (starry) vs. 7.49 μs (limb-dark).
I added a diagonal kwarg to the solution_vector of starry, which leads to skipping the computation for m != 0. It works pretty well, but I'd like to understand why the vmapped version is still that slow and whether there is anything we can do.
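Roughly, the skipping works like this (a sketch only; term_integral is a hypothetical stand-in for the per-term p and q integrals, not the actual code):
import jax.numpy as jnp

def solution_vector_sketch(l_max, b, r, term_integral, diagonal=False):
    # With diagonal=True, only the m = 0 terms of the solution vector are
    # computed; all other entries stay zero.
    s = jnp.zeros((l_max + 1) ** 2)
    for l in range(l_max + 1):
        for m in range(-l, l + 1):
            if diagonal and m != 0:
                continue  # skipped at trace time, so never compiled
            s = s.at[l**2 + l + m].set(term_integral(l, m, b, r))
    return s
Since the Python loops are unrolled during tracing, the skipped terms never enter the compiled graph.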
Here is the new benchmark:
import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np
r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20
surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
from jaxoplanet.core.limb_dark import light_curve
function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
206 μs ± 4.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
118 μs ± 2.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
But for a single b:
8.26 μs ± 56.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.33 μs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
I was playing around yesterday and noticed that I got a roughly 20% speed increase when I switched the order of the jit and vmap, i.e. vmapping the jitted function!
Thanks @soichiro-hattori!
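For concreteness, the two compositions look like this (a toy sketch; f stands in for any scalar light curve function):
import jax
import jax.numpy as jnp

def f(b):
    # Toy scalar function standing in for the light curve.
    return jnp.sqrt(1.0 + b**2)

f_vmap_of_jit = jax.vmap(jax.jit(f))  # vmapping the jitted function (faster here)
f_jit_of_vmap = jax.jit(jax.vmap(f))  # jitting the vmapped function

b = jnp.linspace(0.0, 1.1, 1000)
jax.block_until_ready(f_vmap_of_jit(b))
jax.block_until_ready(f_jit_of_vmap(b))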
Also, my approach is very wrong... the solution vector is in the Green's basis, so I cannot do what I did the way I did it. I'll work on this.
Never mind, I lied @lgrcia!
I think I managed to skip the computation of the non-diagonal terms. The Green's basis is actually pretty similar to the polynomial basis, so I computed which indices $(l,m)$ correspond to the off-diagonal terms (in the polynomial basis) and skipped the computation of these terms in solution_vector (both in p_integral and q_integral).
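To make the bookkeeping concrete, here is a small sketch enumerating those indices for l_max = 2, under the assumption that the flat index of a degree-l, order-m term is n = l**2 + l + m:
l_max = 2
off_diagonal = [
    l**2 + l + m
    for l in range(l_max + 1)
    for m in range(-l, l + 1)
    if m != 0
]
print(off_diagonal)  # [1, 3, 4, 5, 7, 8]: the terms skipped in p_integral and q_integral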
Here are some benchmarks:
import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.solution import solution_vector
r = 0.1
b = 0.1
order = 20
function = jax.jit(solution_vector(2, order, diagonal=False))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))
function = jax.jit(solution_vector(2, order, diagonal=True))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))
from jaxoplanet.core.limb_dark import solution_vector
function = jax.jit(solution_vector(2, order))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
all terms computed
6.44 μs ± 21.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Only diagonal terms
6.05 μs ± 24.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
core.limb_dark version
5.46 μs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
But if we consider a vmapped version:
import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.solution import solution_vector
r = 0.1
b = jax.numpy.linspace(0, 1 + r, 1000)
order = 20
function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=False)), (0, None)))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))
function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=True)), (0, None)))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))
from jaxoplanet.core.limb_dark import solution_vector
function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order)), (0, None)))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
all terms computed
295 μs ± 1.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Only diagonal terms
190 μs ± 4.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
core.limb_dark version
74.5 μs ± 439 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
It would be nice to know where the difference is coming from.
I'm starting to take a look at this and we'll see how far I get. To begin with, I'm looking at the jaxpr for each calculation. I started with just the zeroth-order computation, which should be identical in both cases, and added jax.make_jaxpr(function)(b, r) after each benchmark. Here are the two jaxprs:
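That is, after each benchmark cell (sketch):
# Print the traced program (jaxpr) of the compiled function for scalar b and r.
print(jax.make_jaxpr(function)(b, r))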
starry
let _where = { lambda ; a:bool[] b:f64[] c:i64[]. let
d:f64[] = convert_element_type[new_dtype=float64 weak_type=True] c
e:f64[] = select_n a d b
in (e,) } in
let _where1 = { lambda ; f:bool[] g:i64[] h:f64[]. let
i:f64[] = convert_element_type[new_dtype=float64 weak_type=True] g
j:f64[] = select_n f h i
in (j,) } in
{ lambda ; k:f64[] l:f64[]. let
m:f64[1] = pjit[
name=impl
jaxpr={ lambda ; n:f64[] o:f64[]. let
p:f64[1] = pjit[
name=impl
jaxpr={ lambda q:f64[20] r:f64[1,20] s:i64[1]; t:f64[] u:f64[]. let
v:f64[] = abs t
w:f64[] = abs u
x:f64[] = integer_pow[y=2] v
y:f64[] = sub w 1.0
z:f64[] = add w 1.0
ba:f64[] = mul y z
bb:f64[] = sub 1.0 w
bc:f64[] = abs bb
bd:bool[] = gt v bc
be:f64[] = add 1.0 w
bf:bool[] = lt v be
bg:bool[] = convert_element_type[new_dtype=bool weak_type=False] bd
bh:bool[] = convert_element_type[new_dtype=bool weak_type=False] bf
bi:bool[] = and bg bh
bj:f64[] = pjit[name=_where jaxpr=_where] bi v 1
bk:f64[] = min w bj
bl:f64[] = max w bj
bm:f64[] = min bl 1.0
bn:f64[] = max bl 1.0
bo:f64[] = min bk bm
bp:f64[] = max bk bm
bq:f64[] = add bp bn
br:f64[] = add bo bq
bs:f64[] = sub bo bp
bt:f64[] = sub bn bs
bu:f64[] = mul br bt
bv:f64[] = sub bo bp
bw:f64[] = add bn bv
bx:f64[] = mul bu bw
by:f64[] = sub bp bn
bz:f64[] = add bo by
ca:f64[] = mul bx bz
cb:f64[] = max 0.0 ca
cc:f64[] = custom_jvp_call[
call_jaxpr={ lambda ; cd:f64[]. let
ce:f64[] = sqrt cd
in (ce,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x176fe2fc0>
num_consts=0
symbolic_zeros=False
] cb
cf:f64[] = pjit[name=_where jaxpr=_where] bi cc 0
cg:f64[] = add x ba
ch:f64[] = atan2 cf cg
ci:f64[] = sub x ba
cj:f64[] = atan2 cf ci
ck:f64[] = integer_pow[y=2] v
cl:f64[] = integer_pow[y=2] w
cm:f64[] = mul 4.0 v
cn:f64[] = mul cm w
co:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] cn
cp:bool[] = lt co 2.220446049250313e-15
cq:f64[] = pjit[name=_where jaxpr=_where1] cp 1 cn
cr:f64[] = sub 1.0 cl
cs:f64[] = sub cr ck
ct:f64[] = mul 2.0 v
cu:f64[] = mul ct w
cv:f64[] = add cs cu
cw:f64[] = div cv cq
cx:f64[] = max 0.0 cw
cy:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] w
cz:bool[] = lt cy 2.220446049250313e-15
da:f64[] = sub v w
db:f64[] = pjit[name=_where jaxpr=_where1] cz 1 w
dc:f64[] = mul 2.0 db
dd:f64[] = div da dc
de:f64[] = mul 0.5 ch
df:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] de
dg:f64[20] = mul df q
dh:f64[] = mul 0.5 ch
di:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] dh
dj:f64[20] = add dg di
_:f64[20] = cos dj
dk:f64[20] = sin dg
dl:f64[20] = integer_pow[y=2] dk
dm:f64[] = sub 1.0 cl
dn:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] cx
do:f64[20] = sub dn dl
dp:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] cq
dq:f64[20] = mul dp do
dr:f64[20] = pjit[
name=_where
jaxpr={ lambda ; ds:bool[] dt:f64[] du:f64[20]. let
dv:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] dt
dw:f64[20] = broadcast_in_dim[
broadcast_dimensions=()
shape=(20,)
] dv
dx:f64[20] = select_n ds du dw
in (dx,) }
] cp dm dq
dy:f64[20] = max 0.0 dr
dz:f64[20] = pow dy 1.5
ea:f64[20] = integer_pow[y=2] dl
eb:f64[20] = sub dl ea
ec:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] dd
ed:f64[20] = add ec dl
ee:f64[20] = pjit[
name=_where
jaxpr={ lambda ; ef:bool[] eg:i64[] eh:f64[20]. let
ei:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] eg
ej:f64[20] = broadcast_in_dim[
broadcast_dimensions=()
shape=(20,)
] ei
ek:f64[20] = select_n ef eh ej
in (ek,) }
] cz 0 ed
el:f64[20] = mul 2.0 dl
_:f64[20] = sub 1.0 el
em:f64[] = mul 2.0 w
en:f64[] = integer_pow[y=-1] em
eo:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] en
_:f64[20] = mul eo dz
ep:f64[] = mul 2.0 w
eq:f64[] = integer_pow[y=2] ep
er:f64[] = mul 2.0 eq
es:f64[20] = pow eb 1.0
et:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] er
eu:f64[20] = mul et es
ev:f64[20] = pow ee 0.0
ew:f64[20] = mul eu ev
ex:f64[1,20] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 20)
] ew
ey:f64[1,20] = mul ex r
ez:f64[1] = reduce_sum[axes=(1,)] ey
fa:f64[] = convert_element_type[
new_dtype=float64
weak_type=False
] de
fb:f64[1] = mul fa ez
fc:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0.0
fd:i64[1] = device_put[devices=[None] srcs=[None]] s
fe:bool[1] = lt fd 0
ff:i64[1] = add fd 1
fg:i64[1] = select_n fe fd ff
fh:i32[1] = convert_element_type[
new_dtype=int32
weak_type=False
] fg
fi:i32[1,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(1, 1)
] fh
fj:f64[1] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
indices_are_sorted=False
mode=GatherScatterMode.FILL_OR_DROP
unique_indices=False
update_consts=()
update_jaxpr=None
] fc fi fb
fk:f64[] = sub 1.5707963267948966 cj
fl:f64[] = cos fk
fm:f64[] = sin fk
fn:f64[] = mul 2.0 fk
fo:f64[] = add fn 3.141592653589793
_:f64[] = mul -2.0 fl
fp:f64[] = integer_pow[y=1] fl
fq:f64[] = mul 2.0 fp
fr:f64[] = integer_pow[y=1] fm
fs:f64[] = mul fq fr
ft:f64[] = mul 1.0 fo
fu:f64[] = add fs ft
fv:f64[] = div fu 2.0
fw:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] fv
fx:f64[1] = convert_element_type[
new_dtype=float64
weak_type=False
] fw
fy:f64[1] = sub fx fj
in (fy,) }
] n o
in (p,) }
] k l
in (m,) }
limb_dark
let _where = { lambda ; a:bool[] b:f64[] c:f64[]. let
d:f64[] = select_n a c b
in (d,) } in
let _where1 = { lambda ; e:bool[] f:f64[] g:f64[]. let
h:f64[] = select_n e g f
in (h,) } in
{ lambda ; i:f64[] j:f64[]. let
k:f64[1] = pjit[
name=impl
jaxpr={ lambda ; l:f64[] m:f64[]. let
n:f64[] = abs l
o:f64[] = abs m
p:f64[] = integer_pow[y=2] n
q:f64[] = sub o 1.0
r:f64[] = add o 1.0
s:f64[] = mul q r
t:f64[] = sub 1.0 o
u:f64[] = abs t
v:bool[] = gt n u
w:f64[] = add 1.0 o
x:bool[] = lt n w
y:bool[] = convert_element_type[new_dtype=bool weak_type=False] v
z:bool[] = convert_element_type[new_dtype=bool weak_type=False] x
ba:bool[] = and y z
bb:f64[] = pjit[name=_where jaxpr=_where] ba n 1.0
bc:f64[] = min o bb
bd:f64[] = max o bb
be:f64[] = min bd 1.0
bf:f64[] = max bd 1.0
bg:f64[] = min bc be
bh:f64[] = max bc be
bi:f64[] = add bh bf
bj:f64[] = add bg bi
bk:f64[] = sub bg bh
bl:f64[] = sub bf bk
bm:f64[] = mul bj bl
bn:f64[] = sub bg bh
bo:f64[] = add bf bn
bp:f64[] = mul bm bo
bq:f64[] = sub bh bf
br:f64[] = add bg bq
bs:f64[] = mul bp br
bt:f64[] = max 0.0 bs
bu:f64[] = custom_jvp_call[
call_jaxpr={ lambda ; bv:f64[]. let bw:f64[] = sqrt bv in (bw,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x176fe28e0>
num_consts=0
symbolic_zeros=False
] bt
bx:f64[] = pjit[name=_where jaxpr=_where] ba bu 0.0
by:f64[] = add p s
bz:f64[] = atan2 bx by
ca:f64[] = sub p s
cb:f64[] = atan2 bx ca
cc:f64[] = add 1.0 o
cd:bool[] = ge n cc
ce:f64[] = add 1.0 n
cf:bool[] = le ce o
cg:bool[] = convert_element_type[new_dtype=bool weak_type=False] cd
ch:bool[] = convert_element_type[new_dtype=bool weak_type=False] cf
ci:bool[] = or cg ch
cj:f64[] = pjit[name=_where jaxpr=_where] ci 1.0 n
ck:f64[] = integer_pow[y=2] cj
cl:f64[] = integer_pow[y=2] o
cm:f64[] = add cj o
cn:f64[] = add 1.0 cm
co:f64[] = sub 1.0 cm
cp:f64[] = mul cn co
cq:f64[] = mul 0.5 cl
cr:f64[] = mul 2.0 ck
cs:f64[] = add cl cr
ct:f64[] = mul cq cs
cu:f64[] = sub 1.0 cl
cv:f64[] = mul 3.141592653589793 cu
cw:f64[] = mul 2.0 cv
cx:f64[] = sub ct 0.5
cy:f64[] = mul 12.566370614359172 cx
cz:f64[] = add cw cy
da:f64[] = mul cl bz
db:f64[] = add cb da
dc:f64[] = mul bx 0.5
dd:f64[] = sub db dc
de:f64[] = sub 3.141592653589793 dd
df:f64[] = mul 2.0 de
dg:f64[] = sub 3.141592653589793 cb
dh:f64[] = neg dg
di:f64[] = mul 2.0 ct
dj:f64[] = mul di bz
dk:f64[] = add dh dj
dl:f64[] = mul 0.25 bx
dm:f64[] = mul 5.0 cl
dn:f64[] = add 1.0 dm
do:f64[] = add dn ck
dp:f64[] = mul dl do
dq:f64[] = sub dk dp
dr:f64[] = mul 2.0 dq
ds:f64[] = add df dr
dt:f64[] = mul 4.0 cj
du:f64[] = mul dt o
dv:f64[] = add cp du
dw:bool[] = gt dv du
dx:f64[] = pjit[name=_where jaxpr=_where1] dw cv de
dy:f64[] = pjit[name=_where jaxpr=_where1] dw cz ds
dz:f64[] = pjit[name=_where jaxpr=_where1] cd 3.141592653589793 dx
ea:f64[] = pjit[name=_where jaxpr=_where1] cf 0.0 dz
_:f64[] = pjit[name=_where jaxpr=_where] ci 0.0 dy
eb:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ea
in (eb,) }
] i j
in (k,) }
They seem to start off the same, but in the middle I'm seeing a lot more device_puts and scatters in the starry version, so I'm going to see if I can track those down!
Oh - I think I know what it is! We should probably just merge the two implementations: we should use the closed-form solutions for s0 and s2, and then keep the numerical solutions for the others. I still think we might be able to optimize the implementation in starry for the numerical solutions too, though. Where are those scatters coming from?!
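In sketch form, the merge could look like this (closed_form_s0, closed_form_s2, and numerical_sn are hypothetical placeholders, not existing functions):
def merged_solution_vector(l_max, b, r):
    # Numerical quadrature for every term, then overwrite s0 and s2 with
    # their exact closed-form values.
    s = numerical_sn(l_max, b, r)
    s = s.at[0].set(closed_form_s0(b, r))
    s = s.at[2].set(closed_form_s2(b, r))
    return s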
See https://github.com/exoplanet-dev/jaxoplanet/pull/204
The matrices involved in the light curve computation from Agol 2019 (polynomial limb darkening) and Luger 2019 (more general) are inherently of different sizes, even if the Luger matrices are reduced to the minimal case of a limb-darkened surface (see https://github.com/exoplanet-dev/jaxoplanet/pull/204#issuecomment-2323046728).
So I suspect we will never really get the same performance from the two. But maybe I'm wrong.
One thing I did to bridge the gap is to compute the change-of-basis matrix from Agol's Green's basis to Luger's polynomial basis, so that the smaller solution vector from Agol can be used directly in the limb-darkened case of a starry surface. Combined with excluding the non-diagonal values in Luger's matrices, performance is much better. But again, I'm not sure we can push it further.
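For reference, here is a generic sketch of how such a change-of-basis matrix can be obtained, assuming both bases are expressed as coefficient columns in a shared monomial basis (P and G are illustrative, not the actual jaxoplanet matrices):
import numpy as np

def change_of_basis(P, G):
    # Columns of G: Agol's Green's-basis functions; columns of P: Luger's
    # polynomial-basis functions, both as coefficients in a common monomial
    # basis. Assuming P is square and invertible, the matrix M solving
    # G = P @ M maps Green's-basis coordinates to polynomial-basis ones.
    return np.linalg.solve(P, G)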
With these changes, processing times for single values of b and r are equal! But I still don't get why the vmapped version of starry behaves differently (although its performance is now closer to the limb-dark one).
Benchmark of the new version
vmapped
import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np
r = 1.
u = (1.0, 1.0)
b = np.linspace(0, 1 + r, 1000)
order = 20
surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
from jaxoplanet.core.limb_dark import light_curve
function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
207 μs ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
132 μs ± 6.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
This time difference is the worst we will get, as the processing times become increasingly comparable for larger values of order and higher degrees of the polynomial limb-darkening law.
Single b
import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np
r = 0.1
u = (0.1, 0.2)
b = 0.1
order = 20
surface = Surface(u=u)
function = jax.jit(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
from jaxoplanet.core.limb_dark import light_curve
function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit jax.block_until_ready(function(b))
7.27 μs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.42 μs ± 80.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
This time, the evaluation times for a single b are equal, and the fact that we don't get the same behavior in the vmapped version is still a mystery...
I think these changes capture the idea of #204, since we are working with the minimal possible size of the starry matrices (without the zeros). Looking forward to hearing your ideas on this.
I will close this PR for now, as we decided that both implementations should remain separate.