Dario Panici

Results 260 comments of Dario Panici

@dpanici make separate branch with the implementation using JAX's version, and in this PR implement the one based off of `netket`

@kianorr @YigitElma

> If you don't care about jax's native multi-GPU sharding support it should be easy to just vendor our implementation. In that case, you can just vendor our `netket/jax/_chunk_utils.py`, `netket/jax/_scanmap.py`...

Some laptop benchmarking on the ATF jacobian benchmark example ![image](https://github.com/user-attachments/assets/301ace72-481a-4a35-9694-03e87dff94d0) ![image](https://github.com/user-attachments/assets/d76a3ae5-0692-4b5f-8424-d5da13dfb9fe) The effect on compute time is larger than expected. I think I will implement something that only automatically chooses...

> > I will re-do these tests on the gpu if I can as well, this was just on cpu on mac > > Can you also add the benchmarks...

Based on these scalings of the jacobian's memory usage on GPU and CPU for the ATF benchmark, I will have "auto" estimate memory usage from dim_x * dim_f...

a conservative estimate for the actual peak memory usage over the estimated peak memory usage according to the above formula, as a function of normalized chunk_size

Putting it together, we have something like `chunk_size < (device_mem / estimated_mem - b) / a * dim_x`, where `a ~ 0.8` and `b ~ 0.15`, and `device_mem` is the...
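
As a rough illustration of the bound above, here is a minimal Python sketch that turns it into an integer chunk size. The function name `max_chunk_size` is hypothetical. Clamping the result to `[1, dim_x]` is an assumption, and so is requiring `device_mem` and `estimated_mem` to be in the same units. How `estimated_mem` is computed is not shown here, so it is taken as an input.

```python
import math


def max_chunk_size(device_mem, estimated_mem, dim_x, a=0.8, b=0.15):
    """Largest integer chunk_size with
    chunk_size < (device_mem / estimated_mem - b) / a * dim_x.

    device_mem and estimated_mem must share the same units (assumption).
    The result is clamped to [1, dim_x] (assumption, not from the thread).
    """
    if estimated_mem <= 0:
        raise ValueError("estimated_mem must be positive")
    if dim_x < 1:
        raise ValueError("dim_x must be at least 1")
    bound = (device_mem / estimated_mem - b) / a * dim_x
    # Strict inequality: take the largest integer strictly below the bound.
    largest = math.ceil(bound) - 1
    # Never go below one column per chunk or above the full Jacobian width.
    return max(1, min(dim_x, largest))


if __name__ == "__main__":
    # Device holds half of the estimated full-Jacobian memory, 1000 columns.
    print(max_chunk_size(device_mem=8.0, estimated_mem=16.0, dim_x=1000))
```

With `a` and `b` at the values quoted above, a device that holds half of the estimated memory gives a chunk size a bit under half of `dim_x`. If the bound drops below one column, the clamp returns 1, which may still not fit in memory.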

Under the hood, `get_profiles` computes `iota`, so if you run the benchmark code you posted on an equilibrium that has a current profile assigned instead of an iota profile (ATF...

just use current Bt scale as |B| scale