kfac-jax
Using K-FAC with physics-based losses
Hey,
Thank you for the implementation.
From the guide, I saw that I have to register loss functions to be able to use K-FAC. In my case, the loss is computed by running a FEM simulation on the network's outputs, followed by some other functions (post-processing, filtering, etc.).
Is it possible to use K-FAC in this setting?
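For concreteness, here is a minimal pure-JAX sketch of the kind of pipeline described above. `network`, `fem_solve`, `postprocess`, `residual` and `gauss_newton_vp` are all hypothetical names made up for this illustration. The FEM step is reduced to a single linear solve with a fixed stiffness matrix, and none of this uses the kfac-jax API. It also computes a Gauss-Newton matrix-vector product, the kind of curvature matrix K-FAC approximates, under the assumption that everything up to the final residual is treated as the "model" and a plain squared error is the loss. With that framing, the FEM and post-processing Jacobians enter `J` through ordinary autodiff.

```python
import jax
import jax.numpy as jnp


def network(params, x):
    # Tiny MLP standing in for the real network (hypothetical).
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def fem_solve(u):
    # Hypothetical differentiable stand-in for the FEM simulation: solve
    # K y = u for every sample, with K a fixed 1D Laplacian-like stiffness matrix.
    n = u.shape[-1]
    K = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
    return jnp.linalg.solve(K, u.T).T


def postprocess(y):
    # Hypothetical smooth post-processing / filtering step.
    return jnp.tanh(y)


def residual(params, x, target):
    # Network -> FEM -> post-processing, compared against a target field.
    return postprocess(fem_solve(network(params, x))) - target


def loss(params, x, target):
    # Squared error on the post-processed outputs, averaged over the batch.
    r = residual(params, x, target)
    return 0.5 * jnp.mean(jnp.sum(r ** 2, axis=-1))


def gauss_newton_vp(params, x, target, v):
    # G v = J^T J v / B, where J = d residual / d params. The FEM and
    # post-processing Jacobians are included in J automatically by autodiff.
    f = lambda p: residual(p, x, target)
    _, Jv = jax.jvp(f, (params,), (v,))
    _, vjp_fn = jax.vjp(f, params)
    (JtJv,) = vjp_fn(Jv)
    batch = x.shape[0]
    return jax.tree_util.tree_map(lambda g: g / batch, JtJv)


# Example usage with small, arbitrary shapes.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w1": jax.random.normal(k1, (3, 8)),
    "b1": jnp.zeros(8),
    "w2": jax.random.normal(k2, (8, 4)),
    "b2": jnp.zeros(4),
}
x = jax.random.normal(k3, (5, 3))
target = 0.1 * jax.random.normal(k4, (5, 4))
print(loss(params, x, target))
```

Because the final loss here is a squared error on the residual, the resulting matrix is positive semi-definite even though the FEM and post-processing steps are nonlinear. Folding those steps into the loss instead would generally give an indefinite Hessian.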