jaxoplanet

feat: diagonal sparse Pijk + discussion on merging limb-dark and starry

lgrcia opened this PR (9 comments)

This PR is an attempt to understand what can be done to merge the limb-darkening light curve implementation with starry.

Motivation

The idea is that, if only limb darkening is present on a map defined with starry, JAX should be able to figure out which parts of the different matrices are needed to compute the starry light curve, and these should be similar in number to the ones involved in core.limb_dark.light_curve. Hence the performance of the starry light curve should be similar to that of the limb-dark one, i.e. experimental.starry.light_curves.surface_light_curve with

surface = Surface(y=None, u=u)
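To make concrete how few terms are actually involved in that case, here is a hypothetical sketch (the helper names are made up, and starry's flux normalization is ignored) that expands a polynomial limb-darkening law I(z) = 1 - Σ u_k (1 - z)^k into the polynomial basis of Luger et al. (2019), where a monomial x^i y^j z^k (k ∈ {0, 1}) sits at flat index n = l² + l + m with l = i + j + k and m = j - i. Every z² is replaced by 1 - x² - y², so only monomials with even powers of x and y ever receive a coefficient.

```python
import math

import numpy as np


def monomial_index(i, j, k):
    """Flat index n = l**2 + l + m of x**i * y**j * z**k (k in {0, 1}) in the polynomial basis."""
    l = i + j + k
    m = j - i
    return l * l + l + m


def limb_darkening_to_poly(u):
    """Unnormalized polynomial-basis coefficients of I(z) = 1 - sum_k u_k (1 - z)**k."""
    deg = len(u)
    # a[p] is the coefficient of z**p, from the binomial expansion of each (1 - z)**k
    a = np.zeros(deg + 1)
    a[0] = 1.0
    for k, uk in enumerate(u, start=1):
        for p in range(k + 1):
            a[p] -= uk * math.comb(k, p) * (-1) ** p
    coeffs = np.zeros((deg + 1) ** 2)
    for p, ap in enumerate(a):
        q, k = divmod(p, 2)  # z**p = z**k * (1 - x**2 - y**2)**q with k in {0, 1}
        for s in range(q + 1):  # trinomial expansion of (1 - x**2 - y**2)**q
            for t in range(q - s + 1):
                c = math.comb(q, s) * math.comb(q - s, t) * (-1) ** (s + t)
                coeffs[monomial_index(2 * s, 2 * t, k)] += ap * c
    return coeffs


coeffs = limb_darkening_to_poly((0.1, 0.2))
print(np.flatnonzero(coeffs))  # [0 2 4 8]: the 1, z, x**2 and y**2 terms
```

For the quadratic law only 4 of the 9 degree-2 coefficients are nonzero, which is the kind of sparsity JAX would need to exploit.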

Description of the current modifications

In the Ylm and Pijk interfaces, we defined a diagonal attribute that indicates whether the non-radial coefficients of the spherical harmonics basis (i.e. those not contributing to limb darkening) are all zero. The first commit in this PR fixes that for the Pijk basis and makes it possible to pass a diagonal kwarg to the Pijk.tosparse method, so that the computation

# see end of jaxoplanet/experimental/starry/light_curves.py

p_y.tosparse(diagonal=only_u) @ design_matrix_p

takes into account the sparsity of p_y.
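A toy version of the idea, assuming the starry flat indexing n = l² + l + m for spherical harmonics, so that the m = 0 terms (the only ones a purely limb-darkened map populates) sit at n = l(l + 1). This is not the Pijk.tosparse implementation and the helper names are hypothetical; it only shows that when a coefficient vector is zero outside a known index set, the product can be restricted to those rows, and that using a NumPy index array keeps all shapes static under jax.jit.

```python
import numpy as np
import jax
import jax.numpy as jnp


def ylm_m0_indices(l_max):
    """Flat indices n = l**2 + l + m of the m = 0 (azimuthally symmetric) terms."""
    return np.array([l * l + l for l in range(l_max + 1)])


def make_restricted_product(idx):
    """Return a jitted coeffs @ matrix that only touches the rows listed in idx."""
    idx = np.asarray(idx)  # NumPy array: static indices, so shapes are known at trace time

    @jax.jit
    def product(coeffs, matrix):
        # Equal to coeffs @ matrix whenever coeffs is zero outside idx
        return coeffs[idx] @ matrix[idx, :]

    return product


# Usage: a degree-3 coefficient vector that is nonzero only on m = 0 terms
l_max = 3
idx = ylm_m0_indices(l_max)
coeffs = jnp.zeros((l_max + 1) ** 2).at[idx].set(jnp.arange(1.0, len(idx) + 1))
matrix = jax.random.normal(jax.random.PRNGKey(0), ((l_max + 1) ** 2, 5))
restricted = make_restricted_product(idx)(coeffs, matrix)
```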

Results

For now, this does not work: the starry light curve is ~4 times slower than the limb-dark one.

starry

import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
439 μs ± 2.82 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

limb-dark

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
115 μs ± 2.04 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
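A note on the timing lines: `%timeit function(jax.block_until_ready(function(b)))` calls `function` on `b`, waits for it, and then calls `function` again on the resulting light curve without waiting, so each loop covers roughly two evaluations. The ratio between the two implementations is probably still meaningful, but the absolute numbers are not per-call times. A minimal stand-alone helper, assuming one blocked call per repetition is what we want to measure (`benchmark` is a made-up name):

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_repeat=100):
    """Median wall-clock time in seconds of one blocked call fn(*args)."""
    jax.block_until_ready(fn(*args))  # compile and warm up outside the timed loop
    times = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for the asynchronous result
        times.append(time.perf_counter() - start)
    return float(np.median(times))


# Usage with a stand-in function; swap in the light-curve functions from the thread
f = jax.jit(jax.vmap(jnp.sin))
elapsed = benchmark(f, jnp.linspace(0.0, 1.1, 1000))
print(f"{elapsed * 1e6:.1f} us per call")
```

In IPython, `%timeit jax.block_until_ready(function(b))` measures the same thing in one line.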

The evaluation time is largely dominated by the call to solution_vector in both cases. But I wonder how to write the starry light curve function so that only the relevant parts of the solution vectors are computed when only limb darkening is present.

Something I noticed: with a scalar b (no vmap), we get 9.42 μs (starry) vs. 7.49 μs (limb-dark).

lgrcia, Aug 14 '24

I added a diagonal kwarg to the starry solution_vector, which skips the computation for m != 0. It works pretty well, but I'd like to understand why the vmapped version is still that slow and whether there is anything we can do.
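A toy sketch of why a static flag like diagonal can remove work: when it is a static argument of jax.jit, the Python `if` runs at trace time, so skipped terms become constant zeros and never appear in the compiled program. `toy_term` and `toy_solution_vector` are hypothetical stand-ins, not jaxoplanet functions, and the (l, m) decoding assumes the starry convention n = l² + l + m.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp


def toy_term(n, b, r):
    # Hypothetical stand-in for one expensive entry of the solution vector
    return jnp.cos(n * b) * r**n


@partial(jax.jit, static_argnames=("l_max", "diagonal"))
def toy_solution_vector(b, r, l_max, diagonal):
    terms = []
    for n in range((l_max + 1) ** 2):  # Python loop, unrolled at trace time
        l = math.isqrt(n)
        m = n - l * l - l
        if diagonal and m != 0:
            # Decided at trace time: no operations for this term in the compiled graph
            terms.append(jnp.zeros_like(b))
        else:
            terms.append(toy_term(n, b, r))
    return jnp.stack(terms)


full = toy_solution_vector(0.3, 0.1, l_max=3, diagonal=False)
diag = toy_solution_vector(0.3, 0.1, l_max=3, diagonal=True)
```

The m = 0 entries of `diag` match `full`, and the others are exactly zero.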

Here is the new benchmark:

import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 0.1
u = (0.1, 0.2)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
206 μs ± 4.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
118 μs ± 2.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

But for a single b:

8.26 μs ± 56.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.33 μs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

lgrcia, Aug 16 '24

I was playing around yesterday and noticed a roughly 20% speed increase when I switched the order of jit and vmap! That is, vmapping the jitted function.
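For reference, the two compositions being compared, on a hypothetical stand-in function `f` (not a jaxoplanet function). Both produce the same values; which one runs faster depends on the function and the JAX version, so the only reliable check is to time both on the actual light-curve function.

```python
import jax
import jax.numpy as jnp


def f(b, r):
    # Hypothetical scalar function standing in for a light-curve evaluation
    return jnp.sqrt(jnp.maximum(0.0, 1.0 - (b / (1.0 + r)) ** 2))


b = jnp.linspace(0.0, 1.1, 1000)
r = 0.1

jit_of_vmap = jax.jit(jax.vmap(f, in_axes=(0, None)))  # compile the batched function
vmap_of_jit = jax.vmap(jax.jit(f), in_axes=(0, None))  # batch the jitted function

out1 = jit_of_vmap(b, r)
out2 = vmap_of_jit(b, r)
```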

soichiro-hattori, Aug 16 '24

Thanks @soichiro-hattori!

Also, my approach is very wrong... the solution vector is expressed in Green's basis, so I cannot skip the m != 0 terms the way I did. I'll work on this.

lgrcia, Aug 16 '24

Never mind, I lied @lgrcia!

soichiro-hattori, Aug 16 '24

I think I managed to skip the computation of the non-diagonal terms. The Green's basis is actually pretty similar to the polynomial basis, so I computed which indices $(l,m)$ correspond to the off-diagonal terms (in the polynomial basis) and skipped the computation of these terms in solution_vector (in both p_integral and q_integral).
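One way such an index set could be computed, assuming the polynomial basis ordering of Luger et al. (2019): index n maps to l = ⌊√n⌋, m = n - l² - l, and, with μ = l - m and ν = l + m, to the monomial x^(μ/2) y^(ν/2) if ν is even, or x^((μ-1)/2) y^((ν-1)/2) z if ν is odd. A limb-darkening law depends on z only, and z² = 1 - x² - y², so its expansion only involves monomials with even powers of both x and y. The function names are hypothetical, and this criterion is not necessarily the exact one used in this PR.

```python
import math


def poly_exponents(n):
    """Exponents (i, j, k) of x**i * y**j * z**k for flat index n (Luger et al. 2019 basis)."""
    l = math.isqrt(n)
    m = n - l * l - l
    mu, nu = l - m, l + m
    if nu % 2 == 0:
        return mu // 2, nu // 2, 0
    return (mu - 1) // 2, (nu - 1) // 2, 1


def radial_poly_indices(l_max):
    """Indices whose monomials have even powers of both x and y."""
    indices = []
    for n in range((l_max + 1) ** 2):
        i, j, _ = poly_exponents(n)
        if i % 2 == 0 and j % 2 == 0:
            indices.append(n)
    return indices


print(radial_poly_indices(2))  # [0, 2, 4, 8]: 1, z, x**2, y**2
```

Every index outside this list is an off-diagonal term whose integral can be skipped for a purely limb-darkened map.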

Here are some benchmarks:

import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.solution import solution_vector

r = 0.1
b = 0.1
order = 20

function = jax.jit(solution_vector(2, order, diagonal=False))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))

function = jax.jit(solution_vector(2, order, diagonal=True))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))

from jaxoplanet.core.limb_dark import solution_vector

function = jax.jit(solution_vector(2, order))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
all terms computed
6.44 μs ± 21.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Only diagonal terms
6.05 μs ± 24.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

core.limb_dark version
5.46 μs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

But if we consider a vmapped version:

import jax
jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.solution import solution_vector

r = 0.1
b = jax.numpy.linspace(0, 1 + r, 1000)
order = 20

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=False)), (0, None)))
jax.block_until_ready(function(b, r))
print("all terms computed")
%timeit jax.block_until_ready(function(b, r))

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order, diagonal=True)), (0, None)))
jax.block_until_ready(function(b, r))
print("\nOnly diagonal terms")
%timeit jax.block_until_ready(function(b, r))

from jaxoplanet.core.limb_dark import solution_vector

function = jax.jit(jax.vmap(jax.jit(solution_vector(2, order)), (0, None)))
print("\ncore.limb_dark version")
jax.block_until_ready(function(b, r))
%timeit jax.block_until_ready(function(b, r))
all terms computed
295 μs ± 1.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Only diagonal terms
190 μs ± 4.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

core.limb_dark version
74.5 μs ± 439 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

It would be nice to know where the difference is coming from.

lgrcia, Aug 16 '24

I'm starting to take a look at this; we'll see how far I get. To begin with, I'm looking at the jaxpr for each calculation. I started with just the zeroth-order computation, which should be identical in both cases, and added jax.make_jaxpr(function)(b, r) after each benchmark. Here are the two jaxprs:

starry
let _where = { lambda ; a:bool[] b:f64[] c:i64[]. let
    d:f64[] = convert_element_type[new_dtype=float64 weak_type=True] c
    e:f64[] = select_n a d b
  in (e,) } in
let _where1 = { lambda ; f:bool[] g:i64[] h:f64[]. let
    i:f64[] = convert_element_type[new_dtype=float64 weak_type=True] g
    j:f64[] = select_n f h i
  in (j,) } in
{ lambda ; k:f64[] l:f64[]. let
    m:f64[1] = pjit[
      name=impl
      jaxpr={ lambda ; n:f64[] o:f64[]. let
          p:f64[1] = pjit[
            name=impl
            jaxpr={ lambda q:f64[20] r:f64[1,20] s:i64[1]; t:f64[] u:f64[]. let
                v:f64[] = abs t
                w:f64[] = abs u
                x:f64[] = integer_pow[y=2] v
                y:f64[] = sub w 1.0
                z:f64[] = add w 1.0
                ba:f64[] = mul y z
                bb:f64[] = sub 1.0 w
                bc:f64[] = abs bb
                bd:bool[] = gt v bc
                be:f64[] = add 1.0 w
                bf:bool[] = lt v be
                bg:bool[] = convert_element_type[new_dtype=bool weak_type=False] bd
                bh:bool[] = convert_element_type[new_dtype=bool weak_type=False] bf
                bi:bool[] = and bg bh
                bj:f64[] = pjit[name=_where jaxpr=_where] bi v 1
                bk:f64[] = min w bj
                bl:f64[] = max w bj
                bm:f64[] = min bl 1.0
                bn:f64[] = max bl 1.0
                bo:f64[] = min bk bm
                bp:f64[] = max bk bm
                bq:f64[] = add bp bn
                br:f64[] = add bo bq
                bs:f64[] = sub bo bp
                bt:f64[] = sub bn bs
                bu:f64[] = mul br bt
                bv:f64[] = sub bo bp
                bw:f64[] = add bn bv
                bx:f64[] = mul bu bw
                by:f64[] = sub bp bn
                bz:f64[] = add bo by
                ca:f64[] = mul bx bz
                cb:f64[] = max 0.0 ca
                cc:f64[] = custom_jvp_call[
                  call_jaxpr={ lambda ; cd:f64[]. let
                      ce:f64[] = sqrt cd
                    in (ce,) }
                  jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x176fe2fc0>
                  num_consts=0
                  symbolic_zeros=False
                ] cb
                cf:f64[] = pjit[name=_where jaxpr=_where] bi cc 0
                cg:f64[] = add x ba
                ch:f64[] = atan2 cf cg
                ci:f64[] = sub x ba
                cj:f64[] = atan2 cf ci
                ck:f64[] = integer_pow[y=2] v
                cl:f64[] = integer_pow[y=2] w
                cm:f64[] = mul 4.0 v
                cn:f64[] = mul cm w
                co:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] cn
                cp:bool[] = lt co 2.220446049250313e-15
                cq:f64[] = pjit[name=_where jaxpr=_where1] cp 1 cn
                cr:f64[] = sub 1.0 cl
                cs:f64[] = sub cr ck
                ct:f64[] = mul 2.0 v
                cu:f64[] = mul ct w
                cv:f64[] = add cs cu
                cw:f64[] = div cv cq
                cx:f64[] = max 0.0 cw
                cy:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] w
                cz:bool[] = lt cy 2.220446049250313e-15
                da:f64[] = sub v w
                db:f64[] = pjit[name=_where jaxpr=_where1] cz 1 w
                dc:f64[] = mul 2.0 db
                dd:f64[] = div da dc
                de:f64[] = mul 0.5 ch
                df:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] de
                dg:f64[20] = mul df q
                dh:f64[] = mul 0.5 ch
                di:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] dh
                dj:f64[20] = add dg di
                _:f64[20] = cos dj
                dk:f64[20] = sin dg
                dl:f64[20] = integer_pow[y=2] dk
                dm:f64[] = sub 1.0 cl
                dn:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] cx
                do:f64[20] = sub dn dl
                dp:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] cq
                dq:f64[20] = mul dp do
                dr:f64[20] = pjit[
                  name=_where
                  jaxpr={ lambda ; ds:bool[] dt:f64[] du:f64[20]. let
                      dv:f64[] = convert_element_type[
                        new_dtype=float64
                        weak_type=False
                      ] dt
                      dw:f64[20] = broadcast_in_dim[
                        broadcast_dimensions=()
                        shape=(20,)
                      ] dv
                      dx:f64[20] = select_n ds du dw
                    in (dx,) }
                ] cp dm dq
                dy:f64[20] = max 0.0 dr
                dz:f64[20] = pow dy 1.5
                ea:f64[20] = integer_pow[y=2] dl
                eb:f64[20] = sub dl ea
                ec:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] dd
                ed:f64[20] = add ec dl
                ee:f64[20] = pjit[
                  name=_where
                  jaxpr={ lambda ; ef:bool[] eg:i64[] eh:f64[20]. let
                      ei:f64[] = convert_element_type[
                        new_dtype=float64
                        weak_type=False
                      ] eg
                      ej:f64[20] = broadcast_in_dim[
                        broadcast_dimensions=()
                        shape=(20,)
                      ] ei
                      ek:f64[20] = select_n ef eh ej
                    in (ek,) }
                ] cz 0 ed
                el:f64[20] = mul 2.0 dl
                _:f64[20] = sub 1.0 el
                em:f64[] = mul 2.0 w
                en:f64[] = integer_pow[y=-1] em
                eo:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] en
                _:f64[20] = mul eo dz
                ep:f64[] = mul 2.0 w
                eq:f64[] = integer_pow[y=2] ep
                er:f64[] = mul 2.0 eq
                es:f64[20] = pow eb 1.0
                et:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] er
                eu:f64[20] = mul et es
                ev:f64[20] = pow ee 0.0
                ew:f64[20] = mul eu ev
                ex:f64[1,20] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 20)
                ] ew
                ey:f64[1,20] = mul ex r
                ez:f64[1] = reduce_sum[axes=(1,)] ey
                fa:f64[] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] de
                fb:f64[1] = mul fa ez
                fc:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0.0
                fd:i64[1] = device_put[devices=[None] srcs=[None]] s
                fe:bool[1] = lt fd 0
                ff:i64[1] = add fd 1
                fg:i64[1] = select_n fe fd ff
                fh:i32[1] = convert_element_type[
                  new_dtype=int32
                  weak_type=False
                ] fg
                fi:i32[1,1] = broadcast_in_dim[
                  broadcast_dimensions=(0,)
                  shape=(1, 1)
                ] fh
                fj:f64[1] = scatter[
                  dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
                  indices_are_sorted=False
                  mode=GatherScatterMode.FILL_OR_DROP
                  unique_indices=False
                  update_consts=()
                  update_jaxpr=None
                ] fc fi fb
                fk:f64[] = sub 1.5707963267948966 cj
                fl:f64[] = cos fk
                fm:f64[] = sin fk
                fn:f64[] = mul 2.0 fk
                fo:f64[] = add fn 3.141592653589793
                _:f64[] = mul -2.0 fl
                fp:f64[] = integer_pow[y=1] fl
                fq:f64[] = mul 2.0 fp
                fr:f64[] = integer_pow[y=1] fm
                fs:f64[] = mul fq fr
                ft:f64[] = mul 1.0 fo
                fu:f64[] = add fs ft
                fv:f64[] = div fu 2.0
                fw:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] fv
                fx:f64[1] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] fw
                fy:f64[1] = sub fx fj
              in (fy,) }
          ] n o
        in (p,) }
    ] k l
  in (m,) }
limb_dark
let _where = { lambda ; a:bool[] b:f64[] c:f64[]. let
    d:f64[] = select_n a c b
  in (d,) } in
let _where1 = { lambda ; e:bool[] f:f64[] g:f64[]. let
    h:f64[] = select_n e g f
  in (h,) } in
{ lambda ; i:f64[] j:f64[]. let
    k:f64[1] = pjit[
      name=impl
      jaxpr={ lambda ; l:f64[] m:f64[]. let
          n:f64[] = abs l
          o:f64[] = abs m
          p:f64[] = integer_pow[y=2] n
          q:f64[] = sub o 1.0
          r:f64[] = add o 1.0
          s:f64[] = mul q r
          t:f64[] = sub 1.0 o
          u:f64[] = abs t
          v:bool[] = gt n u
          w:f64[] = add 1.0 o
          x:bool[] = lt n w
          y:bool[] = convert_element_type[new_dtype=bool weak_type=False] v
          z:bool[] = convert_element_type[new_dtype=bool weak_type=False] x
          ba:bool[] = and y z
          bb:f64[] = pjit[name=_where jaxpr=_where] ba n 1.0
          bc:f64[] = min o bb
          bd:f64[] = max o bb
          be:f64[] = min bd 1.0
          bf:f64[] = max bd 1.0
          bg:f64[] = min bc be
          bh:f64[] = max bc be
          bi:f64[] = add bh bf
          bj:f64[] = add bg bi
          bk:f64[] = sub bg bh
          bl:f64[] = sub bf bk
          bm:f64[] = mul bj bl
          bn:f64[] = sub bg bh
          bo:f64[] = add bf bn
          bp:f64[] = mul bm bo
          bq:f64[] = sub bh bf
          br:f64[] = add bg bq
          bs:f64[] = mul bp br
          bt:f64[] = max 0.0 bs
          bu:f64[] = custom_jvp_call[
            call_jaxpr={ lambda ; bv:f64[]. let bw:f64[] = sqrt bv in (bw,) }
            jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x176fe28e0>
            num_consts=0
            symbolic_zeros=False
          ] bt
          bx:f64[] = pjit[name=_where jaxpr=_where] ba bu 0.0
          by:f64[] = add p s
          bz:f64[] = atan2 bx by
          ca:f64[] = sub p s
          cb:f64[] = atan2 bx ca
          cc:f64[] = add 1.0 o
          cd:bool[] = ge n cc
          ce:f64[] = add 1.0 n
          cf:bool[] = le ce o
          cg:bool[] = convert_element_type[new_dtype=bool weak_type=False] cd
          ch:bool[] = convert_element_type[new_dtype=bool weak_type=False] cf
          ci:bool[] = or cg ch
          cj:f64[] = pjit[name=_where jaxpr=_where] ci 1.0 n
          ck:f64[] = integer_pow[y=2] cj
          cl:f64[] = integer_pow[y=2] o
          cm:f64[] = add cj o
          cn:f64[] = add 1.0 cm
          co:f64[] = sub 1.0 cm
          cp:f64[] = mul cn co
          cq:f64[] = mul 0.5 cl
          cr:f64[] = mul 2.0 ck
          cs:f64[] = add cl cr
          ct:f64[] = mul cq cs
          cu:f64[] = sub 1.0 cl
          cv:f64[] = mul 3.141592653589793 cu
          cw:f64[] = mul 2.0 cv
          cx:f64[] = sub ct 0.5
          cy:f64[] = mul 12.566370614359172 cx
          cz:f64[] = add cw cy
          da:f64[] = mul cl bz
          db:f64[] = add cb da
          dc:f64[] = mul bx 0.5
          dd:f64[] = sub db dc
          de:f64[] = sub 3.141592653589793 dd
          df:f64[] = mul 2.0 de
          dg:f64[] = sub 3.141592653589793 cb
          dh:f64[] = neg dg
          di:f64[] = mul 2.0 ct
          dj:f64[] = mul di bz
          dk:f64[] = add dh dj
          dl:f64[] = mul 0.25 bx
          dm:f64[] = mul 5.0 cl
          dn:f64[] = add 1.0 dm
          do:f64[] = add dn ck
          dp:f64[] = mul dl do
          dq:f64[] = sub dk dp
          dr:f64[] = mul 2.0 dq
          ds:f64[] = add df dr
          dt:f64[] = mul 4.0 cj
          du:f64[] = mul dt o
          dv:f64[] = add cp du
          dw:bool[] = gt dv du
          dx:f64[] = pjit[name=_where jaxpr=_where1] dw cv de
          dy:f64[] = pjit[name=_where jaxpr=_where1] dw cz ds
          dz:f64[] = pjit[name=_where jaxpr=_where1] cd 3.141592653589793 dx
          ea:f64[] = pjit[name=_where jaxpr=_where1] cf 0.0 dz
          _:f64[] = pjit[name=_where jaxpr=_where] ci 0.0 dy
          eb:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ea
        in (eb,) }
    ] i j
  in (k,) }

They seem to start off the same, but in the middle I'm seeing a lot more device_puts and scatters in the starry version, so I'm going to see if I can track those down!
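Instead of reading the full dumps, the primitives can also be counted programmatically. The sketch below walks a jaxpr recursively, descending into nested sub-jaxprs such as those of pjit and custom_jvp_call, and tallies primitive names, which makes it easy to compare, e.g., the number of scatter and device_put operations in the two versions. Nested jaxprs are detected by duck typing (an object with `eqns`, or one whose `.jaxpr` has `eqns`); this is an assumption made to avoid depending on where `Jaxpr`/`ClosedJaxpr` live in a given JAX version, and `count_primitives` is a made-up name.

```python
from collections import Counter

import jax
import jax.numpy as jnp


def _sub_jaxprs(value):
    """Yield nested jaxprs stored in an equation parameter (duck-typed)."""
    if hasattr(value, "eqns"):  # a Jaxpr
        yield value
    elif hasattr(value, "jaxpr") and hasattr(value.jaxpr, "eqns"):  # a ClosedJaxpr
        yield value.jaxpr
    elif isinstance(value, (tuple, list)):  # e.g. the branches of a cond
        for item in value:
            yield from _sub_jaxprs(item)


def count_primitives(jaxpr, counts=None):
    """Count primitive names in a jaxpr, recursing into nested jaxprs."""
    counts = Counter() if counts is None else counts
    for eqn in jaxpr.eqns:
        counts[eqn.primitive.name] += 1
        for param in eqn.params.values():
            for sub in _sub_jaxprs(param):
                count_primitives(sub, counts)
    return counts


def g(x):
    inner = jax.jit(lambda y: jnp.cos(y))  # appears as a nested jaxpr
    return jnp.sin(x) * 2.0 + inner(x)


counts = count_primitives(jax.make_jaxpr(g)(jnp.ones(3)).jaxpr)
print(counts["scatter"], counts["device_put"], counts["cos"])
```

For the benchmarks above, this would be `count_primitives(jax.make_jaxpr(function)(b, r).jaxpr)` for each version.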

dfm, Aug 17 '24

Oh, I think I know what it is! We should probably just merge the two implementations: use the closed-form solutions for s0 and s2, and keep the numerical solutions for the others. I still think we might be able to optimize the starry implementation of the numerical solutions too, though. Where are those scatters coming from?!
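For the zeroth-order term, the closed form is geometric: the visible area of the unit stellar disk, i.e. π minus the area of the lens where the star and the occultor overlap (the limb_dark jaxpr above appears to compute these quantities, using atan2 and a sorted, Kahan-style product for numerical stability). Below is a plain sketch of that textbook formula; only s0 is covered, the function name is hypothetical, it is not jaxoplanet's implementation, and it is not tuned for gradients at the branch boundaries.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def uniform_visible_area(b, r):
    """Visible area of a unit disk behind an opaque disk of radius r at separation b."""
    b = jnp.abs(b)
    r = jnp.abs(r)
    safe_b = jnp.where(b > 0, b, 1.0)  # avoid 0/0 in branches that are discarded anyway
    safe_r = jnp.where(r > 0, r, 1.0)
    # Lens half-angles seen from the occultor centre (kappa0) and the star centre (kappa1)
    kappa0 = jnp.arccos(jnp.clip((r**2 + b**2 - 1) / (2 * safe_b * safe_r), -1.0, 1.0))
    kappa1 = jnp.arccos(jnp.clip((1 - r**2 + b**2) / (2 * safe_b), -1.0, 1.0))
    # 4 x area of the triangle with sides 1, r, b (Heron's formula)
    root = jnp.sqrt(jnp.maximum(0.0, 4 * b**2 - (1 + b**2 - r**2) ** 2))
    lens = r**2 * kappa0 + kappa1 - 0.5 * root
    return jnp.where(
        b >= 1 + r,
        jnp.pi,  # no overlap
        jnp.where(
            b + 1 <= r,
            0.0,  # star completely covered
            jnp.where(b + r <= 1, jnp.pi * (1 - r**2), jnp.pi - lens),  # inside / partial
        ),
    )


print(jax.vmap(uniform_visible_area, in_axes=(0, None))(jnp.linspace(0.0, 1.2, 5), 0.1))
```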

dfm, Aug 17 '24

See https://github.com/exoplanet-dev/jaxoplanet/pull/204

dfm, Aug 17 '24

The matrices involved in the light curve computation from Agol 2019 (polynomial limb darkening) and Luger 2019 (more general) are inherently of different sizes, even when Luger's matrices are reduced to the minimal case of a limb-darkened surface (see https://github.com/exoplanet-dev/jaxoplanet/pull/204#issuecomment-2323046728).

So I suspect we will never get quite the same performance from the two. But maybe I'm wrong.

One thing I did to bridge the gap is to compute the change-of-basis matrix from Agol's Green's basis to Luger's polynomial basis, so that Agol's smaller solution vector can be used directly for a limb-darkened starry surface. Combined with excluding the non-diagonal values in Luger's matrices, this makes performance much better. But again, I'm not sure whether we can push it further.
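Why this works follows from the flux being linear in the map coefficients. In notation introduced only for this sketch: let $\mathbf{p}$ be the polynomial-basis coefficients, $\mathbf{g} = C\,\mathbf{p}$ the corresponding Green's-basis coefficients for some change-of-basis matrix $C$, $\mathbf{s}_g$ the Green's-basis solution vector, and $\mathcal{D}$ the set of (diagonal) indices where $\mathbf{p}$ can be nonzero. Then

```latex
F = \mathbf{s}_g^{\top}\mathbf{g}
  = \mathbf{s}_g^{\top} C\,\mathbf{p}
  = \bigl(C^{\top}\mathbf{s}_g\bigr)^{\top}\mathbf{p}
  = \sum_{n \in \mathcal{D}} \bigl(C^{\top}\mathbf{s}_g\bigr)_n \, p_n ,
```

so $C^{\top}\mathbf{s}_g$ acts as a polynomial-basis solution vector, and only the columns of $C$ indexed by $\mathcal{D}$ are ever needed.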

With these changes, processing times for single values of b and r are equal! But I still don't understand why the vmapped version of starry behaves differently (although its performance is now closer to that of the limb-dark one).

Benchmark of the new version

vmapped

import jax

jax.config.update("jax_enable_x64", True)
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 1.
u = (1.0, 1.0)
b = np.linspace(0, 1 + r, 1000)
order = 20

surface = Surface(u=u)
function = jax.jit(jax.vmap(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(jax.vmap(lambda b: light_curve(u, b, r, order=order)))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
207 μs ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
132 μs ± 6.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

This is the largest time difference we will get, since the processing times become increasingly comparable for higher values of order and higher degrees of the polynomial limb-darkening law.

Single b

import jax
from jaxoplanet.experimental.starry.surface import Surface
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
import numpy as np

r = 0.1
u = (0.1, 0.2)
b = 0.1
order = 20

surface = Surface(u=u)
function = jax.jit(lambda b: surface_light_curve(surface, r, z=10.0, y=b, order=order))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))

from jaxoplanet.core.limb_dark import light_curve

function = jax.jit(lambda b: light_curve(u, b, r, order=order))
jax.block_until_ready(function(b))
%timeit function(jax.block_until_ready(function(b)))
7.27 μs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.42 μs ± 80.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

This time, the evaluation times for a single b are equal; why this does not carry over to the vmapped version is still a mystery...

I think these changes capture the idea of #204, since we are working with the smallest possible starry matrices (without the zeros). Looking forward to hearing your ideas on this.

lgrcia, Sep 03 '24

I will close this PR for now, as we decided that the two implementations should remain separate.

lgrcia, Oct 08 '24