andrewdipper
That makes sense. I looked into the initialization of `running_*`, but I don't think it was exactly Adam-style. I'll run some tests and get back.
I looked into using an Adam-style correction - it helps a bit, but there is still instability. I also looked into initializing at 0, doing a running average, and...
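For context, here's a minimal sketch of what I mean by an Adam-style correction: initialize the running statistic at 0 and divide by `1 - beta**t` so the early steps aren't biased toward the initial value (the names are just illustrative, not the actual `running_*` code):

```python
import jax.numpy as jnp

def ema_update(running, new_value, step, beta=0.9):
    """One exponential-moving-average step with Adam-style bias correction."""
    running = beta * running + (1.0 - beta) * new_value
    # step starts at 1; dividing by (1 - beta**step) removes the bias
    # toward the 0 initialization during the first few updates
    corrected = running / (1.0 - beta**step)
    return running, corrected

# toy usage: track a bias-corrected running mean of a short stream
running = jnp.zeros(())
for t, x in enumerate([4.0, 5.0, 6.0], start=1):
    running, estimate = ema_update(running, x, t)
```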
For sure, I'll check other implementations and come up with something. Multiple strategies seem like a good way to go.
All good - I got caught up in other things myself! From my tests, the replication is exact now. I added another approach that is very similar to `"ema"` but...
Looks like the test failures are due to an older version of jax combined with a recent blackjax PR that fixed an argument deprecation in `jnp.clip`: https://github.com/blackjax-devs/blackjax/pull/664 https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html
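For reference, my understanding of the deprecation is that `jnp.clip` moved from the `a_min`/`a_max` keyword names to `min`/`max`, so passing the bounds positionally works on both old and new jax (just an illustration, not the blackjax patch itself):

```python
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 5)

# old keyword names, deprecated in recent jax releases:
#   jnp.clip(x, a_min=-1.0, a_max=1.0)
# new keyword names, unavailable on older jax:
#   jnp.clip(x, min=-1.0, max=1.0)

# positional bounds are accepted by both:
clipped = jnp.clip(x, -1.0, 1.0)
```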
Also added a bugfix for `pm.sample` not respecting `compute_convergence_checks` with the numpyro/blackjax samplers.
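In other words, something like this should now behave as expected (a minimal usage sketch; the model is just a placeholder):

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    idata = pm.sample(
        draws=500,
        tune=500,
        nuts_sampler="numpyro",            # or "blackjax"
        compute_convergence_checks=False,  # previously ignored by the JAX-based samplers
    )
```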
I added the fix for reducing memory use in the blackjax `window_adaptation`. But that depends on https://github.com/blackjax-devs/blackjax/pull/674, which was just merged earlier today. So blackjax would have to be up to...
I made some more changes for my own experiments to offload to lighter hardware, as follows:
- Enable sampling in chunks to make memory independent of the number of samples (if...
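To illustrate the chunking idea (this is just a sketch of the approach, not the actual implementation; `one_step` stands in for whatever transition kernel is being used):

```python
import numpy as np
import jax
import jax.numpy as jnp

def sample_in_chunks(rng_key, init_state, one_step, total_draws, chunk_size=1000):
    """Draw samples chunk by chunk so device memory scales with chunk_size."""
    def scan_step(state, key):
        state = one_step(key, state)
        return state, state

    chunks = []
    state = init_state
    n_chunks = int(np.ceil(total_draws / chunk_size))
    for _ in range(n_chunks):
        rng_key, chunk_key = jax.random.split(rng_key)
        keys = jax.random.split(chunk_key, chunk_size)
        state, draws = jax.lax.scan(scan_step, state, keys)
        chunks.append(jax.device_get(draws))  # copy the chunk to host memory
    return np.concatenate(chunks)[:total_draws]

# toy usage with a scalar Gaussian random-walk "kernel"
walk = lambda key, s: s + jax.random.normal(key)
samples = sample_in_chunks(jax.random.PRNGKey(0), jnp.zeros(()), walk, total_draws=2500)
```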
Just removed the blackjax window adaptation memory fix, which is stalled by the lack of conda jaxlib updates. It's now covered by https://github.com/pymc-devs/pymc/pull/7407. The remaining changes help with the sampling memory and...