Dario Panici
@dpanici make separate branch with the implementation using JAX's version, and in this PR implement the one based off of `netket`
@kianorr @YigitElma
> If you don't care about jax's native multi-GPU sharding support it should be easy to just vendor our implementation. In that case, you can just vendor our `netket/jax/_chunk_utils.py`, `netket/jax/_scanmap.py`...
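The vendored utilities implement the usual chunking trick: split the mapped-over axis into pieces of `chunk_size`, evaluate each piece, and stitch the results back together, trading peak memory for a loop. A minimal sketch of that idea (plain numpy; `apply_chunked` is a hypothetical helper, not netket's actual `_chunk_utils`/`_scanmap` API):

```python
import numpy as np

def apply_chunked(f, x, chunk_size):
    """Apply f over the leading axis of x in chunks of chunk_size.

    Peak memory scales with chunk_size instead of x.shape[0];
    this is the core idea behind netket's chunked vmap/scanmap,
    sketched here without JAX's scan machinery.
    """
    n = x.shape[0]
    outs = [f(x[i:i + chunk_size]) for i in range(0, n, chunk_size)]
    return np.concatenate(outs, axis=0)

# Chunked evaluation matches the unchunked result.
x = np.arange(12.0).reshape(6, 2)
full = x ** 2
chunked = apply_chunked(lambda c: c ** 2, x, chunk_size=4)
assert np.allclose(full, chunked)
```

The real netket versions do this inside `jax.lax.scan` so the loop stays on-device and jit-compatible; the sketch above only shows the memory/compute trade-off.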
Some laptop benchmarking on the ATF jacobian benchmark example. The effect on compute time is larger than expected, so I think I will implement something that only automatically chooses...
> > I will re-do these tests on the gpu if I can as well, this was just on cpu on mac
>
> Can you also add the benchmarks...
Based on these scalings for the memory usage of the jacobian on GPU and CPU for the ATF benchmark, I will have "auto" estimate memory usage based on `dim_x * dim_f`...
A conservative estimate of the actual peak memory usage over the estimated peak memory usage according to the above formula, as a function of normalized `chunk_size`:
Putting it together, we have something like `chunk_size < (device_mem / estimated_mem - b) / a * dim_x`, where `a ~ 0.8` and `b ~ 0.15`, and `device_mem` is the...
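The heuristic above can be turned into a small helper that picks the largest safe `chunk_size` from the device memory budget and the `dim_x * dim_f` estimate. A sketch under stated assumptions (the function name, the bytes-per-element factor, and the default fit constants are illustrative, not DESC's actual code; `device_mem` and the estimate must be in the same units):

```python
def auto_chunk_size(device_mem, dim_x, dim_f,
                    bytes_per_elem=8, a=0.8, b=0.15):
    """Pick a chunk_size from the inverted memory-fit formula:

        chunk_size < (device_mem / estimated_mem - b) / a * dim_x

    where estimated_mem ~ dim_x * dim_f * bytes_per_elem is the
    estimated full-jacobian memory. a and b are the empirical fit
    constants from the benchmarks (a ~ 0.8, b ~ 0.15); bytes_per_elem
    assumes float64. All of these defaults are assumptions.
    """
    estimated_mem = dim_x * dim_f * bytes_per_elem
    max_chunk = (device_mem / estimated_mem - b) / a * dim_x
    # Clamp to a valid chunk size: at least 1, at most dim_x.
    return max(1, min(dim_x, int(max_chunk)))

# Example: 2 GB budget, a 1000 x 5000 jacobian in float64 (~40 MB)
# leaves plenty of headroom, so the full dim_x fits in one chunk.
print(auto_chunk_size(2e9, dim_x=1000, dim_f=5000))  # → 1000
```

If `device_mem / estimated_mem` is close to (or below) `b`, the formula yields a tiny or non-positive chunk, which the clamp turns into the minimum of 1; in practice that case means the jacobian simply does not fit and should fail loudly instead.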
Under the hood, `get_profiles` computes `iota`, so if you run the benchmark code you posted on an equilibrium that has a current profile assigned instead of an iota profile (ATF...
Just use the current `Bt` scale as the `|B|` scale.