kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.

12 kfac-jax issues, sorted by recently updated.

Silence type errors generated by new pytype features.

In my application, I need to jointly optimize two probabilistic models. They contribute to two different terms in the final loss function. I am wondering what would be the recommended...

Correct buffer donation. Buffer donation is only valid if the shape and type of an input buffer match those of an output. Buffer donation only works with positional arguments, not keyword arguments.
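As a rough illustration of the two constraints above, the sketch below requests donation through `jax.jit`'s `donate_argnums`. The update function `sgd_update` is hypothetical and is not kfac-jax code. Its output has the same shape and dtype as the donated `params` argument, so XLA can reuse that buffer, and `params` is passed positionally because `donate_argnums` refers to positional indices.

```python
import jax
import jax.numpy as jnp


def sgd_update(params, grads, lr):
    # Hypothetical update step: the result has the same shape and dtype
    # as `params`, which is what makes donating `params` valid.
    return params - lr * grads


# Donate positional argument 0 (`params`). Some backends (e.g. older CPU
# builds) may ignore the request and emit a warning instead.
sgd_update_donated = jax.jit(sgd_update, donate_argnums=(0,))

params = jnp.ones((4,), dtype=jnp.float32)
grads = jnp.full((4,), 0.5, dtype=jnp.float32)

# Pass `params` positionally so the donation applies to it.
new_params = sgd_update_donated(params, grads, 0.1)
```

After the call, `params` should not be read again: if the backend honored the donation, its buffer has been handed over to `new_params` and is no longer valid.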

* Simplifying type annotations in curvature_blocks.py
* Adding more general functions for Kronecker product multiplications.
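K-FAC approximates curvature blocks as Kronecker products, so multiplying by `kron(a, b)` without materializing it is a recurring operation. The sketch below is a minimal example of the standard identity behind such helpers, not the functions added in this change; `kron_mat_vec` is a hypothetical name. With row-major reshaping, `kron(a, b) @ x` equals `(a @ X @ b.T).ravel()`, where `X = x.reshape(a.shape[1], b.shape[1])`.

```python
import jax.numpy as jnp
from jax import random


def kron_mat_vec(a, b, x):
    """Compute kron(a, b) @ x without forming kron(a, b).

    a: (m, n), b: (p, q), x: (n * q,). Returns a vector of length m * p.
    Relies on row-major reshapes, matching jnp.kron's index layout.
    """
    x_mat = x.reshape(a.shape[1], b.shape[1])  # (n, q)
    return (a @ x_mat @ b.T).reshape(-1)       # (m, p) -> (m * p,)


key_a, key_b, key_x = random.split(random.PRNGKey(0), 3)
a = random.normal(key_a, (3, 4))
b = random.normal(key_b, (2, 5))
x = random.normal(key_x, (4 * 5,))

fast = kron_mat_vec(a, b, x)    # never builds the 6 x 20 matrix
dense = jnp.kron(a, b) @ x      # reference result for comparison
```

The factored form costs two small matrix products instead of one product with an `(m * p) x (n * q)` matrix, which is what keeps Kronecker-factored curvature cheap to apply.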

Bumps [numpy](https://github.com/numpy/numpy) from 1.21 to 1.22.0. Release notes Sourced from numpy's releases. v1.22.0 NumPy 1.22.0 Release Notes NumPy 1.22.0 is a big release featuring the work of 153 contributors spread...

dependencies

Bumps [ipython](https://github.com/ipython/ipython) from 7.16.1 to 8.10.0. Release notes Sourced from ipython's releases. See https://pypi.org/project/ipython/ We do not use GitHub release anymore. Please see PyPI https://pypi.org/project/ipython/ Commits 15ea1ed release 8.10.0 560ad10...

dependencies

Added pytype None checks to accumulators.py.
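For readers unfamiliar with this kind of change, here is a hypothetical sketch of what a "None check" for a type checker such as pytype typically looks like. `MeanAccumulator` is an illustrative class, not the code in accumulators.py. An attribute annotated `Optional[...]` is checked against `None` before use, which both narrows its type for the checker and turns a confusing runtime error into an explicit one.

```python
from typing import Optional

import jax.numpy as jnp


class MeanAccumulator:
    """Hypothetical running-mean accumulator (illustration only)."""

    def __init__(self) -> None:
        # Nothing has been accumulated yet, so the sum starts as None.
        self._sum: Optional[jnp.ndarray] = None
        self._count: int = 0

    def add(self, value: jnp.ndarray) -> None:
        # The None check narrows Optional[ndarray] to ndarray in each branch.
        if self._sum is None:
            self._sum = value
        else:
            self._sum = self._sum + value
        self._count += 1

    @property
    def value(self) -> jnp.ndarray:
        # Explicit check instead of letting `None / int` fail obscurely.
        if self._sum is None:
            raise ValueError("Accumulator is empty; call add() first.")
        return self._sum / self._count
```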

Hey, thank you for the implementation. From the guide, I saw that I have to register loss functions to be able to use K-FAC. For my specific case, the loss...