DESC
Use "scan trick" to reduce compilation times for loops over functions
Something like https://github.com/patrick-kidger/equinox/blob/6846a2b86ff9f049ea332da543ae0b73b04daf52/equinox/internal/_misc.py#L49
That looks more like a regular scan of a function over an array. The scan trick I was referring to combines scan + select to apply a sequence of functions to a given input, like f1(f2(f3(x))), without having to unroll the loop, which is slow to compile in JAX. The main place I was thinking of using it is the logic in desc.compute.compute, where we loop over dependencies to compute stuff, but I'm not sure it will work there since the outputs can be different sizes.
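A minimal sketch of the idea, assuming every function in the chain maps an array to an array of the same shape and dtype. The functions `f1`, `f2`, `f3` and the helper `apply_chain` are hypothetical stand-ins, not DESC code. The comment above says "select"; this sketch uses `jax.lax.switch` to pick the function by index inside the scan body, so the body is traced once instead of being unrolled once per function.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for a sequence of functions. All of them must map
# an array to an array with the same shape/dtype for the scan carry to work.
def f1(x):
    return x + 1.0

def f2(x):
    return 2.0 * x

def f3(x):
    return x ** 2

def apply_chain(fns, x):
    """Compute fns[0](fns[1](...fns[-1](x))) with one scan, no Python unrolling."""
    # The innermost function is applied first, so walk the indices in reverse.
    order = jnp.arange(len(fns) - 1, -1, -1)

    def body(carry, i):
        # lax.switch traces every branch once and selects branch i at run time,
        # so the compiled program does not grow with the number of loop steps.
        return jax.lax.switch(i, fns, carry), None

    out, _ = jax.lax.scan(body, x, order)
    return out

x = jnp.array([1.0, 2.0, 3.0])
out = jax.jit(lambda x: apply_chain([f1, f2, f3], x))(x)
print(out)  # f1(f2(f3(x))) = [3., 9., 19.]
```

Because the scan carry must keep a fixed shape and every `switch` branch must return the same type, this only works directly when all outputs have the same size, which is the concern raised above for `desc.compute.compute`.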
Tested in #1147; it was too slow/wasteful.