DESC

Use "scan trick" to reduce compilation times for loops over functions

Open f0uriest opened this issue 2 years ago • 2 comments

Something like https://github.com/patrick-kidger/equinox/blob/6846a2b86ff9f049ea332da543ae0b73b04daf52/equinox/internal/_misc.py#L49

f0uriest, Nov 17 '23 21:11

I use it here

Can you give me some examples of where we can use this in DESC?

rahulgaur104, Jan 22 '24 18:01

That looks more like a regular scan of a function over an array. The scan trick I was referring to combines scan + select to apply a sequence of functions to a given input, like f1(f2(f3(x))), without having to unroll the loop, which is slow to compile in JAX. The main place I was thinking of using it is the logic in desc.compute.compute, where we loop over dependencies to compute stuff, but I'm not sure it will work there since the outputs can be different sizes.
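For illustration, here is a minimal sketch of the scan + select idea in plain JAX. It is not the equinox implementation linked above. Here the "select" step is done with `jax.lax.switch`, which is an assumption of this sketch, and `apply_sequence` is a hypothetical helper name. The order in which the functions are applied is passed as an integer array, so the loop body is traced once and each function is traced once, instead of once per step as in an unrolled Python loop.

```python
import jax
import jax.numpy as jnp


def apply_sequence(funcs, order, x):
    """Hypothetical helper: apply funcs[order[0]], then funcs[order[1]], ...

    funcs: sequence of callables, each mapping an array to an array of the
        SAME shape and dtype (jax.lax.switch and the scan carry require this).
    order: integer array of indices into funcs, applied first to last.
    """
    branches = list(funcs)

    def body(carry, idx):
        # lax.switch traces every branch once and picks one at run time, so
        # the scan body is compiled once no matter how long `order` is.
        return jax.lax.switch(idx, branches, carry), None

    out, _ = jax.lax.scan(body, x, order)
    return out


f1 = lambda x: x + 1.0
f2 = lambda x: 2.0 * x
f3 = lambda x: x**2
funcs = [f1, f2, f3]

x = jnp.array([1.0, 2.0])
order = jnp.array([2, 1, 0])  # f3 first, then f2, then f1: f1(f2(f3(x)))

# `order` is traced data, so re-running with a different order of the same
# length reuses the compiled function instead of recompiling.
run = jax.jit(lambda order, x: apply_sequence(funcs, order, x))
y = run(order, x)  # y == [3., 9.]
```

Because `jax.lax.switch` requires every branch to return the same shape and dtype, and the scan carry must keep a fixed shape, this sketch cannot directly chain functions whose outputs have different sizes, which is the concern raised above for `desc.compute.compute`.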

f0uriest, Jan 23 '24 01:01

Tested in #1147; it was too slow/wasteful.

dpanici, Aug 20 '24 19:08