DESC
Dp/jacobian batched vmap
This works well. For an LMN18 equilibrium solve with a 1.5x oversampled grid and `maxiter=10`, the GPU memory trace vs. time shows a 4x memory decrease with negligible runtime increase.
Currently this uses the `netket` package for its `chunked_vmap` function. We don't want that as a dependency, though, so we will try to implement a lighter-weight version ourselves.
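A lighter-weight chunked vmap needs little more than `jax.vmap` and `jax.lax.map`. The sketch below is hypothetical. It is neither netket's implementation nor DESC code. It assumes `f` takes a single array mapped over axis 0 and returns a single array, so pytree inputs/outputs and `in_axes` would need more work. Recent JAX releases also accept a `batch_size` argument to `jax.lax.map`, which may cover much of the same use case.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, chunk_size=None):
    """Hypothetical lightweight chunked vmap.

    Vectorizes ``f`` over axis 0 of a single array argument, evaluating at most
    ``chunk_size`` elements per step so only one chunk's intermediates are live.
    Assumes ``f`` returns a single array (no pytrees).
    """
    vf = jax.vmap(f)
    if chunk_size is None:
        return vf  # plain vmap: every element batched at once

    def wrapped(x):
        n = x.shape[0]
        n_chunks = n // chunk_size
        n_full = n_chunks * chunk_size
        outs = []
        if n_full > 0:
            # Reshape the evenly divisible part to (n_chunks, chunk_size, ...).
            chunks = x[:n_full].reshape((n_chunks, chunk_size) + x.shape[1:])
            # lax.map loops over chunks sequentially; vmap batches within a chunk.
            out = jax.lax.map(vf, chunks)
            # Merge the (n_chunks, chunk_size) leading axes back into one.
            outs.append(out.reshape((n_full,) + out.shape[2:]))
        if n_full < n:
            # Remaining elements that do not fill a whole chunk.
            outs.append(vf(x[n_full:]))
        return jnp.concatenate(outs, axis=0)

    return wrapped
```

For example, `chunked_vmap(lambda v: jnp.sum(v**2), chunk_size=4)` applied to a `(10, 3)` array evaluates two chunks of 4 via `lax.map` plus a remainder of 2, and matches `jax.vmap` on the same input. Handling the remainder separately avoids padding `x`, at the cost of compiling `vf` for a second batch shape.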
TODO
- [ ] Re-implement without relying on `netket`
- [ ] Change the default `chunk_size` to a better value (something like 100 would be fine; maybe choose it dynamically based on the size of `dim_x`)
- [ ] Add a `chunk_size` argument to every `Objective` class
- [ ] Add `"chunked"` as a `deriv_mode` to `Derivative` (or just as an argument to `Derivative`, used when `"batched"` is selected)
- [ ] Add chunking to the singular integral calculation as well
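The chunking items above can be illustrated with a forward-mode Jacobian built column by column from JVPs against unit vectors, batching `chunk_size` columns at a time. This is a hypothetical sketch. `jacfwd_chunked` and `default_chunk_size` are illustrative names, not DESC's `Derivative` API. The "cap at 100, never more than `dim_x`" rule is only one reading of the default suggested in the list.

```python
import jax
import jax.numpy as jnp


def default_chunk_size(dim_x, max_chunk=100):
    # Hypothetical heuristic: ~100 columns per chunk, but never more than dim_x.
    return max(1, min(max_chunk, dim_x))


def jacfwd_chunked(fun, x, chunk_size=None):
    """Hypothetical chunked forward-mode Jacobian of ``fun`` at 1-D ``x``.

    Returns an array of shape (dim_f, dim_x) (or (dim_x,) for scalar ``fun``).
    """
    dim_x = x.shape[0]
    if chunk_size is None:
        chunk_size = default_chunk_size(dim_x)

    # One JVP against a unit tangent vector yields one Jacobian column.
    def jvp_col(v):
        return jax.jvp(fun, (x,), (v,))[1]

    cols = []
    for start in range(0, dim_x, chunk_size):
        stop = min(start + chunk_size, dim_x)
        # Unit vectors e_start..e_{stop-1}, built per chunk (no full identity).
        basis = jax.nn.one_hot(jnp.arange(start, stop), dim_x, dtype=x.dtype)
        # vmap batches at most chunk_size tangent vectors at a time.
        cols.append(jax.vmap(jvp_col)(basis))
    # Stacked blocks have shape (dim_x, dim_f); transpose to (dim_f, dim_x).
    return jnp.concatenate(cols, axis=0).T
```

A Python loop keeps the sketch short. Under `jax.jit` it unrolls into one chunk per iteration, so running the chunks through `jax.lax.map` would keep the traced program size constant as `dim_x` grows.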
Resolves #826