
Neural Tangent Kernel Adaptive Loss

Open zoemcc opened this issue 2 years ago • 6 comments

Implement the Neural Tangent Kernel (NTK) adaptive loss method proposed in the paper "When and Why PINNs Fail to Train: A Neural Tangent Kernel Perspective" by Sifan Wang, Xinling Yu, and Paris Perdikaris. There is a GitHub repo that should guide the implementation.

The algorithm is Algorithm 1 in the paper. The algorithm should be implemented as a concrete subtype of AbstractAdaptiveLoss so that it fits within our pre-existing code gen infrastructure in the discretize_inner_functions function. The definition of the K kernels is in Lemma 3.1.

(i.e.)

struct NeuralTangentKernelAdaptiveLoss <: AbstractAdaptiveLoss
...
end

This paper is slightly harder to implement in our system than some of the other adaptive loss methods, but not by much. The definition of K requires a selection of points from each domain, which could be generated via a grid, stochastic, or quasi-random strategy. The implementation provided on their GitHub seems to use a grid strategy, but I don't see why that must always be the case for this quantity (the choice seems arbitrary). Thus, most of the implementation difficulty is figuring out the best way for this type to maintain its own samples, which may differ from the main PDE domain samplers of the top-level PINN, and then calculating the kernel quantities using those points and the internally generated PDE functions. There is a ton of interesting theory in this paper, but implementing the algorithm mainly relies on understanding how to compute the K kernels.
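
For concreteness, something along these lines could work, mirroring the fields of the existing adaptive losses plus a sampler that the type owns (the field names here are just illustrative, not a final API):

struct NeuralTangentKernelAdaptiveLoss{T <: Real, S} <: AbstractAdaptiveLoss
    reweight_every::Int64      # recompute the NTK traces every N iterations
    kernel_sample_strategy::S  # grid / stochastic / quasi-random sampler owned by this type
    pde_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}
end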

zoemcc avatar Mar 21 '22 23:03 zoemcc

I have been going through the code in `discretize_inner_functions` and there seems to be some additional code for each adaptive loss method, specifically of the form:


if adaloss isa GradientScaleAdaptiveLoss
        weight_change_inertia = discretization.adaptive_loss.weight_change_inertia
        function run_loss_gradients_adaptive_loss(θ)
            if iteration[1] % adaloss.reweight_every == 0
                # the paper assumes a single pde loss function, so here we grab the maximum of the maximums of each pde loss function
                pde_grads_maxes = [maximum(abs.(Zygote.gradient(pde_loss_function, θ)[1])) for pde_loss_function in pde_loss_functions]
                pde_grads_max = maximum(pde_grads_maxes)
                bc_grads_mean = [mean(abs.(Zygote.gradient(bc_loss_function, θ)[1])) for bc_loss_function in bc_loss_functions]

                nonzero_divisor_eps =  adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
                bc_loss_weights_proposed = pde_grads_max ./ (bc_grads_mean .+ nonzero_divisor_eps)
                adaloss.bc_loss_weights .= weight_change_inertia .* adaloss.bc_loss_weights .+ (1 .- weight_change_inertia) .* bc_loss_weights_proposed
                logscalar(logger, pde_grads_max, "adaptive_loss/pde_grad_max", iteration[1])
                logvector(logger, pde_grads_maxes, "adaptive_loss/pde_grad_maxes", iteration[1])
                logvector(logger, bc_grads_mean, "adaptive_loss/bc_grad_mean", iteration[1])
                logvector(logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", iteration[1])
            end
            nothing
        end
end

For the implementation of the NTK loss, is it only the struct that needs to be defined, or does some of how the loss is propagated need to be defined as well? Or am I mistaken, and that code is purely for logging, with the calculations done entirely in the struct through lines of the form new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser), convert(BC_OPT, bc_max_optimiser), vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))?

I am going to use the adaptive losses that are already defined as a starting point, but the only approach that seems workable to me is to define another chain/function that computes the NTK loss and is called by the struct. Let me know if I am approaching this from the right perspective.
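
To make the question concrete, here is a rough sketch of the kind of branch I imagine would be needed in discretize_inner_functions, analogous to the GradientScaleAdaptiveLoss branch quoted above. pde_pointwise_functions and bc_pointwise_functions are made-up names for per-point closures of θ that would have to come out of the code-gen step, and the weight formula follows my reading of Algorithm 1 (weight = Tr(K) / Tr(K_component)):

if adaloss isa NeuralTangentKernelAdaptiveLoss
        function run_ntk_adaptive_loss(θ)
            if iteration[1] % adaloss.reweight_every == 0
                # Tr(K_rr) and Tr(K_uu): sum of squared parameter-gradient norms over the sampled
                # points, one trace per PDE / BC component (no full N x N kernel is ever formed)
                pde_traces = [sum(sum(abs2, Zygote.gradient(f, θ)[1]) for f in fs) for fs in pde_pointwise_functions]
                bc_traces = [sum(sum(abs2, Zygote.gradient(f, θ)[1]) for f in fs) for fs in bc_pointwise_functions]
                total_trace = sum(pde_traces) + sum(bc_traces)
                # Algorithm 1 (as I read it): weight each loss component by the total trace over its own trace
                adaloss.pde_loss_weights .= total_trace ./ pde_traces
                adaloss.bc_loss_weights .= total_trace ./ bc_traces
                logvector(logger, adaloss.pde_loss_weights, "adaptive_loss/pde_loss_weights", iteration[1])
                logvector(logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", iteration[1])
            end
            nothing
        end
end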

Parvfect avatar Mar 29 '22 16:03 Parvfect

The authors used Jacobians to compute the weights (see the attached figures, "Matrix K" and "Neural Tangent Kernel Loss algo"), but as you can see from the algorithm, since we only need the trace of the K matrix, a simple self-dot product over the gradients of the loss functions should suffice (in fact, it would be even faster). Can someone please confirm whether my understanding is correct or whether there is something I am missing? Thanks! @zoemcc @ChrisRackauckas

sphinx-tech avatar Apr 02 '23 06:04 sphinx-tech

Yeah, we only need the diagonal, and the formula for the diagonal is much easier since it's just the square of the derivative. There's no need to compute the whole matrix.
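
A minimal self-contained illustration of that identity (toy model, Zygote only, none of the NeuralPDE plumbing): the i-th diagonal entry of K is the squared norm of the parameter gradient at the i-th point, so the trace is just a sum of squared gradients.

using Zygote

# toy "network": u(x; θ) = θ[1] * sin(θ[2] * x)
u(x, θ) = θ[1] * sin(θ[2] * x)

θ = [0.5, 2.0]
xs = range(0.0, 1.0; length = 16)

# K_ii = ||∂u(x_i; θ)/∂θ||², so Tr(K) = Σ_i ||∂u(x_i; θ)/∂θ||²
trace_K = sum(sum(abs2, Zygote.gradient(p -> u(x, p), θ)[1]) for x in xs)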

ChrisRackauckas avatar Apr 02 '23 11:04 ChrisRackauckas

Hi, I want to work on this, but I want to make sure it is still open and that no one else is working on it.

ayushinav avatar Feb 18 '24 08:02 ayushinav

No one is working on it.

ChrisRackauckas avatar Feb 18 '24 08:02 ChrisRackauckas

There are a couple of PRs related to it: https://github.com/SciML/NeuralPDE.jl/pull/673 and https://github.com/SciML/NeuralPDE.jl/pull/506. Not sure if they are still relevant.

sathvikbhagavan avatar Feb 18 '24 09:02 sathvikbhagavan