DESC
A few small improvements to reduce compile time and memory usage:
- Integer and bool arrays are no longer treated as static values during JAX transformations (`jit`, etc.). This keeps them from being baked into the jitted function, which appears to avoid some of the "constant folding" issues.
- The `Objective.compute` method now calls the lower-level `desc.compute.utils._compute` directly instead of the outer wrapper, which skips some largely unnecessary runtime checks in cases where we know we're already passing in the correct transforms, etc.
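The general JAX behavior behind the first change can be illustrated with a minimal sketch (not DESC's actual code; `sum_static` and `sum_traced` are hypothetical names). A value marked static via `static_argnums` is embedded in the compiled program as a constant, so every new value forces a retrace and recompile. An array passed as an ordinary argument is traced instead, and the compiled function is reused for any input with the same shape and dtype. NumPy arrays are not hashable, so the static variant uses a tuple as a stand-in.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Counts how many times JAX traces each function (Python side effects
# inside a jitted function only run during tracing).
trace_counts = {"static": 0, "traced": 0}


@partial(jax.jit, static_argnums=1)
def sum_static(x, idx):
    trace_counts["static"] += 1
    # idx is a hashable tuple; np.asarray turns it into a concrete array
    # that gets baked into the compiled program as a constant.
    return jnp.sum(x[np.asarray(idx)])


@jax.jit
def sum_traced(x, idx):
    trace_counts["traced"] += 1
    # idx is an ordinary (traced) integer array input.
    return jnp.sum(x[idx])


x = jnp.arange(10.0)
static_results, traced_results = [], []
for idx in ([0, 1, 2], [3, 4, 5], [6, 7, 8]):
    static_results.append(float(sum_static(x, tuple(idx))))
    traced_results.append(float(sum_traced(x, np.asarray(idx))))

print(trace_counts)  # {'static': 3, 'traced': 1}
```

Both variants return the same values, but the static one compiles once per distinct index set, while the traced one compiles once in total.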