pykan
improve the compatibility with non-CUDA environments
Many thanks to the author for proposing this amazing KAN. I modified the KANLayer.py file for better compatibility in non-CUDA environments. BTW, it is my first pull request. I would appreciate it if you could accept this request. @KindXiaoming
Before:
line 126 in KANLayer.py
self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base)).requires_grad_(sb_trainable)
After:
if torch.cuda.is_available():
    self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).cuda()).requires_grad_(sb_trainable)
else:
    self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base)).requires_grad_(sb_trainable)
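For comparison, a more device-agnostic variant of the same line might pick the device once and build the parameter directly on it, rather than duplicating the constructor in an if/else. This is only a sketch: `make_scale_base` and its `device` argument are hypothetical and not part of pykan or this PR. By default it falls back to CUDA when available, else CPU, which mirrors the proposed change.

```python
import torch

def make_scale_base(scale_base, sb_trainable=True, device=None):
    """Hypothetical helper: build a scale_base parameter on a chosen device.

    `device` is an assumed argument (not in pykan); when None it mirrors the
    PR's behavior: CUDA if available, otherwise CPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # as_tensor accepts lists, arrays, or tensors; float32 matches FloatTensor
    data = torch.as_tensor(scale_base, dtype=torch.float32, device=device)
    return torch.nn.Parameter(data).requires_grad_(sb_trainable)

p = make_scale_base([1.0, 2.0], sb_trainable=False, device="cpu")
print(p.device, p.requires_grad)  # cpu False
```

Passing `device` explicitly lets the caller decide, which is the approach the review comments below lean toward.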
I agree with @Jim137. BTW, if you already set the device in the calling function, there's no need to do this. If you want device handling to be transparent (even though I don't agree with this approach), you should modify EVERY method that deals with CUDA vs. CPU.
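To illustrate the "set the device in the calling function" approach, here is a minimal sketch using a hypothetical `ToyLayer` (a stand-in, not pykan's `KANLayer`). Its parameters are created with no device logic at all; the caller chooses the device once, and `Module.to()` moves every registered parameter.

```python
import torch

class ToyLayer(torch.nn.Module):
    """Hypothetical stand-in for a layer like KANLayer; not pykan's code."""

    def __init__(self, n=3, sb_trainable=True):
        super().__init__()
        # Parameters are created on the CPU; no cuda checks here
        self.scale_base = torch.nn.Parameter(torch.ones(n)).requires_grad_(sb_trainable)
        self.scale_sp = torch.nn.Parameter(torch.ones(n))

    def forward(self, x):
        # Broadcasts (batch, n) * (n,) + (n,) -> (batch, n)
        return x * self.scale_base + self.scale_sp

# The caller decides the device once; .to() moves all registered parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = ToyLayer().to(device)
x = torch.randn(2, 3, device=device)
y = layer(x)
print(y.shape)  # torch.Size([2, 3])
```

Note that `.to()` only moves registered parameters and buffers. Tensors created inside methods (e.g. with `torch.zeros(...)`) are not moved, so those need `device=x.device` or similar, which is why a "transparent" approach ends up touching every method that creates tensors.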
Was gonna make a pull request about this and then saw this - 100% agree!