snntorch
"Expected all tensors to be on the same device" after loading and moving model
- snntorch version: 0.6.4
- Python version: 3.10.6
- Operating System: Ubuntu 22.04
Description
I trained a model, saved it (using dill to pickle), loaded it, moved it from torch.device('cuda') to torch.device('cpu'), and got the error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
I poked around and found that it is `mem` in `Leaky` that causes the problem when `_base_state_function_hidden()` is called. This doesn't feel like it is working as expected.
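For context, `Module.to()` only moves registered parameters and buffers; a tensor stored as a plain Python attribute (which is what the hidden state appears to be here) is left where it was. A minimal sketch of that behavior with a toy module (the names `Toy`, `buf`, and `mem` are just for illustration) — a dtype move is used so the demo runs without a GPU, but a device move behaves the same way:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer: tracked by Module.to()
        self.register_buffer("buf", torch.zeros(3))
        # plain tensor attribute: NOT tracked by Module.to()
        self.mem = torch.zeros(3)

toy = Toy()
toy = toy.to(torch.float64)
print(toy.buf.dtype)  # torch.float64 -- buffer was converted
print(toy.mem.dtype)  # torch.float32 -- plain attribute was left behind
```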
What I Did
import dill  # needed to allow saving
import snntorch as snn
import torch
import torch.nn as nn
from snntorch import surrogate, utils

spike_grad = surrogate.fast_sigmoid()  # surrogate gradient

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # initialize layers
        self.model = torch.nn.Sequential(
            nn.Linear(10, 10),
            snn.Leaky(beta=0.9, init_hidden=True, output=True,
                      reset_mechanism='zero', spike_grad=spike_grad)
        )

    def forward(self, x):
        spike_recording = []  # record spikes over time
        utils.reset(self.model)  # reset/initialize hidden states for all neurons
        # for module in self.model:
        #     if type(module) is snn.Leaky:
        #         module.init_leaky()
        #         module.mem = module.mem.to(device=torch.device('cpu'))
        for step in range(100):  # loop over time
            spike, state = self.model(x[..., step])  # one time step of forward-pass
            spike_recording.append(spike)  # record spikes in list
        return torch.sum(torch.stack(spike_recording), dim=0)

# Set up running on GPU
device = torch.device('cuda')
net = Net().to(device)
input_example = torch.rand((10, 10, 100)).to(device)
output = net(input_example)
print(output)

torch.save(net, 'example', pickle_module=dill)  # use dill for pickling
net = torch.load('example', pickle_module=dill)

device = torch.device('cpu')
net.eval()
net = net.to(device)  # should put everything "on the CPU"
input_example = input_example.to(device)
output = net(input_example)  # errors: the membrane voltage of the Leaky is still on the GPU
print(output)
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/git/RIID_PyTorch/bug_minimal_example.py", line 53, in <module>
    output = net(input_example)  # This causes an error because the membrane voltage of the leak is on the GPU
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/bug_minimal_example.py", line 28, in forward
    spike, state = self.model(x[...,step])  # one time step of forward-pass
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 194, in forward
    self.mem = self._build_state_function_hidden(input_)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 238, in _build_state_function_hidden
    state_fn = self._base_state_function_hidden(input_)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 227, in _base_state_function_hidden
    base_fn = self.beta.clamp(0, 1) * self.mem + input_
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Thanks for catching this issue; it seems like the initialized state is missing a device check for each forward-pass. Will try to push a fix for this within the next week or so.
In the meantime, I managed to bypass this by calling utils.reset(net) just before the forward-pass.
Are you able to test if this works?
Any updates, @ej159 ?
I've tried peppering the script with utils.reset(net) calls (there's already one in there too) and toggling between them, but I still get the same error. The commented-out code in the provided example is the only workaround I've found.
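For anyone hitting the same thing before a fix lands, that commented-out loop can be generalized into a small helper. This is only a sketch under the assumption that the stale hidden state lives in a plain tensor attribute named `mem` (as in `Leaky`); the helper name `move_hidden_states` is made up here:

```python
import torch
import torch.nn as nn

def move_hidden_states(net: nn.Module, device):
    """Move any plain-tensor `mem` attribute that net.to(device) leaves behind."""
    for module in net.modules():
        mem = getattr(module, "mem", None)
        if torch.is_tensor(mem):
            module.mem = mem.to(device)

# demo with a toy module (no GPU required)
class _Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.mem = torch.zeros(2)  # plain attribute, not a registered buffer

demo = _Demo()
move_hidden_states(demo, torch.device("cpu"))
```

Calling it right after `net = net.to(device)` should keep the hidden state in step with the rest of the model.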
Ok, thank you for your feedback!