
How do I compute the gradient of the loss function in a KAN model?

Open Chao-Chen-2004 opened this issue 6 months ago • 0 comments

Hi,

I was wondering whether there is a built-in way to compute the gradient of the loss function with backward(), as in other neural networks, and whether we can use it to update the weights for further training. Something similar to this:

import torch

# Example sizes (left undefined above)
batch_size, input_size, output_size = 32, 4, 1

# Define a model, loss function, and optimizer
model = torch.nn.Linear(input_size, output_size)  # stand-in for any nn.Module, e.g. MyNeuralNetwork()
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Input data and ground truth
input_data = torch.randn((batch_size, input_size))
ground_truth = torch.randn((batch_size, output_size))

# Forward pass
predictions = model(input_data)

# Loss calculation
loss = loss_function(predictions, ground_truth)

# Backward pass (compute gradients)
loss.backward()

# Update parameters
optimizer.step()

# Clear the gradients for the next iteration
optimizer.zero_grad()
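
In case it helps clarify what I'm after, this is the kind of loop I'd hope to write with pykan; a minimal sketch assuming KAN subclasses torch.nn.Module (so model.parameters(), backward(), and step() behave as usual), with the width/grid/k values being placeholders:

import torch
from kan import KAN

# A small 2-input, 1-output KAN (width/grid/k are example values)
model = KAN(width=[2, 5, 1], grid=5, k=3)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Toy data: 32 samples with 2 features each
input_data = torch.randn((32, 2))
ground_truth = torch.randn((32, 1))

# Standard PyTorch training step
optimizer.zero_grad()            # clear stale gradients
predictions = model(input_data)  # forward pass
loss = loss_function(predictions, ground_truth)
loss.backward()                  # gradients of the loss w.r.t. the KAN parameters
optimizer.step()                 # update the weights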

Chao-Chen-2004 · Aug 08 '24 09:08