kfac-jax
Second Order Optimization and Curvature Estimation with K-FAC in JAX.
Silence type errors generated by new pytype features.
In my application, I need to jointly optimize two probabilistic models. They contribute to two different terms in the final loss function. I am wondering what would be the recommended...
Correct buffer donation. Buffer donation is only valid if the shape and dtype of an input buffer match those of an output. Buffer donation only works with positional arguments, not keyword arguments.
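As a rough illustration of the two rules above, the sketch below donates the first positional argument of a jitted update through `jax.jit(..., donate_argnums=...)`. The function `sgd_step` and its learning rate are hypothetical and do not come from kfac-jax. The output has the same shape and dtype as the donated input, and the donated array is passed positionally, so both conditions from the commit message are met.

```python
import jax
import jax.numpy as jnp


def sgd_step(params, grads):
    # Hypothetical update: the result has the same shape and dtype as
    # `params`, so XLA can reuse the donated `params` buffer for the output.
    return params - 0.1 * grads


# Request donation of argument 0 by position. Donation is keyed on
# positional indices, so `params` must not be passed as a keyword.
sgd_step_jit = jax.jit(sgd_step, donate_argnums=(0,))

params = jnp.ones((3,), dtype=jnp.float32)
grads = jnp.ones((3,), dtype=jnp.float32)

# Passed positionally, so donation applies. Backends that cannot donate
# may ignore the request and emit a warning. Either way, `params` should
# be treated as consumed and not read after this call.
new_params = sgd_step_jit(params, grads)
```

If the output shape or dtype differed from the donated input (for example, returning a scalar loss instead of updated parameters), XLA would have no matching output to alias and the donation would not be usable.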
* Simplifying type annotations in curvature_blocks.py
* Adding more general functions for Kronecker product multiplications.
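The general functions added to curvature_blocks.py are not shown here. As a hedged sketch of the underlying technique, a Kronecker-factored matrix can be applied to a vector without ever forming the full product. This relies on the row-major identity (A ⊗ B) vec(X) = vec(A X Bᵀ). The helper name `kron_matvec` is hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def kron_matvec(a, b, v):
    """Computes jnp.kron(a, b) @ v without materializing the Kronecker product.

    Uses the row-major identity (A ⊗ B) vec(X) = vec(A X B^T), where v is the
    row-major flattening of X with shape (a.shape[1], b.shape[1]).
    """
    x = v.reshape(a.shape[1], b.shape[1])
    return (a @ x @ b.T).reshape(-1)


rng = np.random.default_rng(0)
a = jnp.asarray(rng.normal(size=(2, 3)), dtype=jnp.float32)
b = jnp.asarray(rng.normal(size=(4, 5)), dtype=jnp.float32)
v = jnp.asarray(rng.normal(size=(3 * 5,)), dtype=jnp.float32)

# Factored product versus the explicit (8 x 15) Kronecker matrix.
fast = kron_matvec(a, b, v)
slow = jnp.kron(a, b) @ v
```

The explicit `jnp.kron(a, b)` needs memory proportional to the product of all four dimensions. The factored form only touches the two factors and the reshaped vector, which is what makes Kronecker-factored curvature blocks cheap to apply.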
Silence some pytype errors.
Convert estimation modes to enums.
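A minimal sketch of what converting string-valued options to an enum can look like. The class `EstimationMode`, its member names, and `parse_estimation_mode` are illustrative assumptions and may not match the actual kfac-jax code. Mixing in `str` keeps members comparable to the old string values, which helps backward compatibility.

```python
import enum


class EstimationMode(str, enum.Enum):
    """Hypothetical enum of curvature estimation modes (names are illustrative)."""

    FISHER_GRADIENTS = "fisher_gradients"
    FISHER_EMPIRICAL = "fisher_empirical"
    GGN_CURVATURE_PROP = "ggn_curvature_prop"


def parse_estimation_mode(mode):
    """Accepts either an EstimationMode member or its string value."""
    # EstimationMode(...) returns a member unchanged, maps a valid string to
    # its member, and raises ValueError for unknown strings, so typos in a
    # mode name fail early instead of silently falling through.
    return EstimationMode(mode)
```

Callers that still pass plain strings keep working. Because of the `str` mixin, `EstimationMode.FISHER_GRADIENTS == "fisher_gradients"` is true.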
Bumps [numpy](https://github.com/numpy/numpy) from 1.21 to 1.22.0. Release notes Sourced from numpy's releases. v1.22.0 NumPy 1.22.0 Release Notes NumPy 1.22.0 is a big release featuring the work of 153 contributors spread...
Bumps [ipython](https://github.com/ipython/ipython) from 7.16.1 to 8.10.0. Release notes Sourced from ipython's releases. See https://pypi.org/project/ipython/ We do not use GitHub release anymore. Please see PyPI https://pypi.org/project/ipython/ Commits 15ea1ed release 8.10.0 560ad10...
Added pytype None checks to accumulators.py.
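The exact changes to accumulators.py are not shown here. The pattern such checks usually follow is sketched below: an attribute starts as `None`, so its type is `Optional[...]`, and an explicit `is None` check lets a type checker like pytype narrow it before arithmetic. The class `MeanAccumulator` is hypothetical and is not the kfac-jax accumulator API.

```python
from typing import Optional

import jax.numpy as jnp


class MeanAccumulator:
    """Hypothetical running-mean accumulator illustrating explicit None checks."""

    def __init__(self) -> None:
        # Nothing has been accumulated yet, so the running sum is Optional.
        self._sum: Optional[jnp.ndarray] = None
        self._count = 0

    def add(self, value: jnp.ndarray) -> None:
        # The `is None` check narrows Optional[jnp.ndarray] to jnp.ndarray
        # in the else-branch, so `self._sum + value` type-checks.
        if self._sum is None:
            self._sum = value
        else:
            self._sum = self._sum + value
        self._count += 1

    @property
    def mean(self) -> jnp.ndarray:
        # Without this check, a type checker would flag dividing a
        # possibly-None value; at runtime it gives a clear error instead.
        if self._sum is None:
            raise ValueError("mean requested before any value was added.")
        return self._sum / self._count
```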
Hey, thank you for the implementation. From the guide, I saw that I have to register loss functions to be able to use K-FAC. For my specific case, the loss...