
"Expected all tensors to be on the same device" after loading and moving model

Open ej159 opened this issue 1 year ago • 4 comments

  • snntorch version: 0.6.4
  • Python version: 3.10.6
  • Operating System: Ubuntu 22.04

Description

I trained a model, saved it (using dill to pickle), loaded it, moved it from torch.device('cuda') to torch.device('cpu'), and got the error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!. I poked around and found that it is mem in Leaky that causes the problem when _base_state_function_hidden() is called. This doesn't feel like it is working as expected.
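Roughly, after moving the loaded model (using the script below, where the Leaky layer is index 1 of the Sequential), the parameters end up on the CPU but the cached membrane potential does not:

# after net = torch.load(...) and net = net.to(torch.device('cpu')):
print(next(net.parameters()).device)  # cpu
print(net.model[1].mem.device)        # cuda:0 -- the Leaky hidden state is not moved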

What I Did

import dill  # Needed to allow saving
import snntorch as snn
import torch
import torch.nn as nn
from snntorch import surrogate, utils

spike_grad = surrogate.fast_sigmoid()  # surrogate gradient

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # initialize layers
        self.model = torch.nn.Sequential(
            nn.Linear(10, 10),
            snn.Leaky(beta=0.9, init_hidden=True, output=True, reset_mechanism='zero', spike_grad=spike_grad)
        )

    def forward(self, x):
        spike_recording = []  # record spikes over time
        utils.reset(self.model)  # reset/initialize hidden states for all neurons
        # for module in self.model:
        #     if type(module) is snn.Leaky:
        #         module.init_leaky()
        #         module.mem = module.mem.to(device=torch.device('cpu'))

        for step in range(100):  # loop over time
            spike, state = self.model(x[...,step]) # one time step of forward-pass
            spike_recording.append(spike)  # record spikes in list

        return torch.sum(torch.stack(spike_recording), dim=0)


# Set up running on GPU
device = torch.device('cuda')

net = Net().to(device)

input_example = torch.rand((10, 10, 100)).to(device)

output = net(input_example)
print(output)

torch.save(net, 'example', pickle_module=dill)  # Use dill for pickling
net = torch.load('example', pickle_module=dill)

device = torch.device('cpu')

net.eval()
net = net.to(device)  # Should put everything "on the CPU"
input_example = input_example.to(device)

output= net(input_example) # This causes an error because the membrane voltage of the leak is on the GPU

print(output)
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 227, in _base_state_function_hidden
    base_fn = self.beta.clamp(0, 1) * self.mem + input_
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 238, in _build_state_function_hidden
    state_fn = self._base_state_function_hidden(input_)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 194, in forward
    self.mem = self._build_state_function_hidden(input_)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/bug_minimal_example.py", line 28, in forward
    spike, state = self.model(x[...,step]) # one time step of forward-pass
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/bug_minimal_example.py", line 53, in <module>
    output= net(input_example) # This causes an error because the membrane voltage of the leak is on the GPU
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

ej159 · Jul 12 '23 13:07

Thanks for catching this issue; it seems like the initialized state is missing a device check for each forward-pass. Will try to push a fix for this within the next week or so.

In the meantime, I managed to bypass this by calling utils.reset(net) just before the forward-pass. Are you able to test if this works?
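Something like this, using the variables from your script:

net = net.to(device)
input_example = input_example.to(device)

utils.reset(net)  # re-initialize hidden states just before the forward-pass
output = net(input_example)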

jeshraghian · Jul 28 '23 14:07

Any updates, @ej159?

ahenkes1 · Aug 14 '23 11:08

I've tried peppering the script with utils.reset(net) calls (there's already one in there too) and toggling between them, but I still get the same error. The commented-out code in the provided example is the only workaround I've found.
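For reference, that workaround uncommented (it sits in forward(), just before the time loop):

for module in self.model:
    if type(module) is snn.Leaky:
        module.init_leaky()
        module.mem = module.mem.to(device=torch.device('cpu'))  # move the stale hidden state by hand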

ej159 · Aug 14 '23 13:08

Ok, thank you for your feedback!

ahenkes1 · Aug 15 '23 11:08