pykan
CUDA device for training
Hello, I'm trying to train a KAN net for OCR. In the process I had to make a few tweaks to be able to use the GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print(device)
print(torch.cuda.is_available())
I also had to add torch.set_default_dtype(torch.float64) inside __init__.py, otherwise I get
RuntimeError: expected scalar type Double but found Float
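(As a side note, calling torch.set_default_dtype at the top of the training script, before the model is built, should have the same effect as editing the library's __init__.py; a minimal sketch:)
import torch

# Set the default floating-point dtype once, before any parameters or tensors are
# created, so the KAN weights and the float64 image tensors end up with the same dtype.
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')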
I also did this:
model = KAN(width=[16*16, 128, 1], grid=3, device=device, k=3)
model.to(device)
results = model.train(dataset, opt="LBFGS", steps=20, device=device)
and moved all the data to the device:
dataset['train_input'] = torch.from_numpy(train_images).to(device)
dataset['test_input'] = torch.from_numpy(test_images).to(device)
dataset['train_label'] = torch.from_numpy(train_labels).to(device)
dataset['test_label'] = torch.from_numpy(test_labels).to(device)
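(The four assignments above could also be wrapped in a small helper; just a sketch, with to_device_dataset being a hypothetical name and every value assumed to be a NumPy array:)
def to_device_dataset(arrays, device):
    # Convert each NumPy array to a torch tensor and move it to the chosen device.
    return {name: torch.from_numpy(arr).to(device) for name, arr in arrays.items()}

dataset = to_device_dataset({'train_input': train_images, 'test_input': test_images,
                             'train_label': train_labels, 'test_label': test_labels}, device)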
This is the entire script:
from kan import KAN, create_dataset
import matplotlib.pyplot as plt
import matplotlib
from sklearn.datasets import make_moons
import torch
import numpy as np
import os
import utilities as ut
import zipfile
matplotlib.use('TkAgg')
TARGET_WIDTH = 16
TARGET_HEIGHT = 16
TARGET_DEPTH = 1
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
dataset = {}
#read label.txt file each row is a label
labels = ut.readLabels("ocr_mix/labels.txt")
print("labels: ", labels)
#read zip file and read each image from each subfolder named as label index
train_images = []
train_labels = []
test_images = []
test_labels = []
#insert path to zip file
found = False
archive = zipfile.ZipFile("./ocr_mix/ocr_mix.zip", 'r')
print("Reading images...")
(train_images, train_labels, test_images, test_labels) = ut.readImagesArchive(archive, len(labels), 500, 500, 20)
print("Preparing dataset...")
train_images = ut.prepareDataset(train_images, TARGET_WIDTH, TARGET_HEIGHT, TARGET_DEPTH)
test_images = ut.prepareDataset(test_images, TARGET_WIDTH, TARGET_HEIGHT, TARGET_DEPTH)
#flatten dimensions
train_images = train_images.reshape(train_images.shape[0], -1)
test_images = test_images.reshape(test_images.shape[0], -1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print(device)
print(torch.cuda.is_available())
#add second dimension
train_labels = np.expand_dims(train_labels, axis=1)
test_labels = np.expand_dims(test_labels, axis=1)
train_labels = np.asarray(train_labels)
test_labels = np.asarray(test_labels)
dataset['train_input'] = torch.from_numpy(train_images).to(device)
dataset['test_input'] = torch.from_numpy(test_images).to(device)
dataset['train_label'] = torch.from_numpy(train_labels).to(device)
dataset['test_label'] = torch.from_numpy(test_labels).to(device)
model = KAN(width=[16*16, 128, 1], grid=3, device=device, k=3)
model.to(device)
results = model.train(dataset, opt="LBFGS", steps=20, device=device)# metrics=(train_acc, test_acc))
model.save_ckpt('ckpt1')
model.plot()
After that the GPU does get used, even though training is still very slow...
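(For what it's worth, a quick check like this, using only standard torch calls, confirms that the model and the data really ended up on the GPU; just a sketch:)
# Sanity check: where do the parameters and the data actually live?
print({p.device for p in model.parameters()})   # expected: {device(type='cuda', index=0)}
print(dataset['train_input'].device)            # expected: cuda:0
if torch.cuda.is_available():
    print(torch.cuda.memory_allocated() / 1e6, "MB allocated on the GPU")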
This should be fixed with https://github.com/KindXiaoming/pykan/pull/98 (please try)
Some news: the RuntimeError: expected scalar type Double but found Float was my fault, because the normalized images were saved with dtype=float64.
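(Casting the images while loading them avoids the mismatch altogether; a minimal sketch, assuming the arrays come out of prepareDataset as float64:)
# Cast the normalized images to float32 so they match torch's default dtype,
# which makes the torch.set_default_dtype(torch.float64) workaround unnecessary.
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)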
As for the device choice for training: CUDA now works fine without model.to(device), but forcing the CPU gives an error:
Traceback (most recent call last):
File "/root/Progetti_GIT/OCR/kanocr.py", line 86, in <module>
results = model.train(dataset, opt="Adam", steps=3, save_fig_freq=0, batch=16, device=device)# metrics=(train_acc, test_acc))
File "/root/miniconda3/envs/tf/lib/python3.9/site-packages/kan/KAN.py", line 898, in train
self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
File "/root/miniconda3/envs/tf/lib/python3.9/site-packages/kan/KAN.py", line 243, in update_grid_from_samples
self.forward(x)
File "/root/miniconda3/envs/tf/lib/python3.9/site-packages/kan/KAN.py", line 311, in forward
x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
File "/root/miniconda3/envs/tf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/tf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/tf/lib/python3.9/site-packages/kan/KANLayer.py", line 176, in forward
y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
The problem is at line 126 of KANLayer.py:
if isinstance(scale_base, float):
self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) # make scale trainable
else:
self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).cuda()).requires_grad_(sb_trainable)
It forces .cuda() even when you are running on the CPU. A temporary solution could be this:
if isinstance(scale_base, float):
self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) # make scale trainable
else:
self.scale_base = torch.nn.Parameter(torch.tensor(scale_base, device=device)).requires_grad_(sb_trainable)
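(With that change, a quick check like the following should show that no parameter is left on the wrong device; just a sketch, model_cpu being an arbitrary name:)
model_cpu = KAN(width=[16*16, 128, 1], grid=3, k=3, device='cpu')
# Every parameter, including scale_base, should now report device 'cpu'.
print({p.device for p in model_cpu.parameters()})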
Could you please take the time to review https://github.com/KindXiaoming/pykan/pull/98? It does exactly what you mentioned, and more. It probably would have saved you some migraine!
Wooops, sorry XD. Yes, it does work!
Side note: the CPU somehow outperforms CUDA here; maybe that's because I'm using a small dataset?
CPU:
train loss: 1.85e-01 | test loss: 1.90e+04 | reg: 1.41e+05 : 33%|█▎ | 1/3 [00:18<00:36, 18.19s/it]
CUDA:
train loss: 1.83e-01 | test loss: 2.18e+04 | reg: 1.42e+05 : 33%|█▎ | 1/3 [00:44<01:28, 44.34s/it]
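(To compare these timings a bit more fairly, something along these lines could be used; just a sketch, with timed_train as a hypothetical helper, remembering to synchronize before reading the clock on CUDA:)
import time

def timed_train(model, dataset, device, steps=3):
    # Synchronize so the timer doesn't start or stop while CUDA kernels are still queued.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    model.train(dataset, opt="Adam", steps=steps, batch=16, device=device)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    return time.perf_counter() - start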
Other than this, it works just fine.
Yeah, we already noticed the slowness of CUDA on small datasets. I haven't had time to test it on big datasets yet, but I think it's actually due to a part of the code where data is forcibly moved to the CPU to perform lstsq. I will open another issue/PR as soon as I have more info and time to dedicate to it :)
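(For the curious, the round-trip I mean looks roughly like this; this is only an illustration of the pattern, not the actual pykan code:)
# Hypothetical illustration: solving a least-squares fit on the CPU and moving the
# result back to the GPU. Each round-trip like this stalls the CUDA pipeline.
A = torch.randn(1024, 16, device='cuda')
b = torch.randn(1024, 1, device='cuda')
coef = torch.linalg.lstsq(A.cpu(), b.cpu()).solution.to('cuda')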