
torch.OutOfMemoryError: CUDA out of memory

Open · dwz92 opened this issue 5 months ago · 2 comments

Hi all,

I am trying to train a KAN model on a GPU server, and I am getting torch.OutOfMemoryError: CUDA out of memory.

So far, the largest model I can train on the GPU without running into this error is:

# Adapted from KAN Example 4
import torch
from kan import KAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.get_default_dtype()  # float dtype used by the accuracy metrics

# Training set size = 30000, test set size = 20000
# (x_train, y_train, x_test, y_test are the dataset tensors, defined elsewhere)
dataset = {}
dataset['train_input'] = x_train.to(device)
dataset['test_input'] = x_test.to(device)
dataset['train_label'] = y_train.to(device)
dataset['test_label'] = y_test.to(device)

def train_acc():
    # fraction of training samples whose rounded first output equals the label
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).type(dtype))

def test_acc():
    # same metric on the test set
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).type(dtype))

model = KAN(width=[12,84,42,2], grid=11, k=3, device=device)

# initialise the spline grids from the data (dataset must be built before this call)
model.update_grid_from_samples(torch.cat((dataset['train_input'], dataset['test_input'])))

results = model.fit(dataset, opt="Adam", steps=3000, metrics=(train_acc, test_acc))
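For a rough sense of how activation memory scales with this configuration, here is a back-of-the-envelope sketch. It assumes (an assumption, not a measurement of pykan internals) that each KAN layer materialises one per-edge tensor of shape (batch, n_in, n_out) and one B-spline basis tensor of shape (batch, n_in, grid + k); the helper name kan_activation_bytes is hypothetical. Real usage is a multiple of these numbers, since several such tensors are kept alive for the backward pass.

```python
# Hypothetical helper: rough size of ONE copy of the per-edge tensor and ONE
# copy of the B-spline basis tensor in a KAN forward pass. Assumes shapes
# (batch, n_in, n_out) and (batch, n_in, grid + k) per layer; actual GPU usage
# is a multiple of this because autograd keeps several such tensors.

def kan_activation_bytes(width, batch, grid, k, bytes_per_elem=4):
    """Return (edge_bytes, basis_bytes); bytes_per_elem=4 assumes float32."""
    edge_elems = 0
    basis_elems = 0
    for n_in, n_out in zip(width[:-1], width[1:]):
        edge_elems += batch * n_in * n_out        # one value per edge per sample
        basis_elems += batch * n_in * (grid + k)  # grid + k spline bases per input
    return edge_elems * bytes_per_elem, basis_elems * bytes_per_elem


if __name__ == "__main__":
    width = [12, 84, 42, 2]
    for batch in (30000, 1000):
        edges, bases = kan_activation_bytes(width, batch=batch, grid=11, k=3)
        print(f"batch={batch}: edges {edges / 2**20:.0f} MiB, bases {bases / 2**20:.0f} MiB")
```

Both terms grow linearly with the batch size, and the edge term grows with the product of adjacent widths, so doubling both hidden widths roughly quadruples the dominant 84×42 term. Note that train_acc and test_acc call the model on the full 30000- and 20000-sample sets, so those passes see the full-batch size regardless of any batch setting passed to fit.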

Any modification, including increasing the width or depth, decreasing the batch size, etc., results in torch.OutOfMemoryError: CUDA out of memory. Even this configuration takes about 33,000 MiB (roughly 32.4 GiB) of GPU memory to train:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   1447175      C   python                          33174MiB |
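For reference, the nvidia-smi figure converts to decimal and binary units as follows (1 MiB = 2^20 bytes):

```latex
33174\,\text{MiB} \times 2^{20}\,\tfrac{\text{B}}{\text{MiB}}
  = 34\,785\,460\,224\,\text{B}
  \approx 34.8\,\text{GB}
  \approx 32.4\,\text{GiB}
```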

Are there any tips for reducing the memory used during training so that I can train a larger model?

dwz92, Sep 10 '24 19:09