
"Expected all tensors to be on the same device" after loading and moving model #225

Open
ej159 opened this issue Jul 12, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@ej159

ej159 commented Jul 12, 2023

  • snntorch version: 0.6.4
  • Python version: 3.10.6
  • Operating System: Ubuntu 22.04

Description

I trained a model, saved it (using dill to pickle), loaded it, moved it from torch.device('cuda') to torch.device('cpu'), and got the error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!. I poked around and found that it is the mem attribute of Leaky that causes the problem when _base_state_function_hidden() is called. This does not feel like expected behaviour.
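For context, Module.to() only migrates parameters and registered buffers, so a tensor stored as a plain Python attribute is left on its original device. A minimal sketch in plain PyTorch of this general behaviour (Demo and its attribute names are made up for illustration; a dtype move is used so it runs without a GPU, but device moves follow the same rules):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("tracked", torch.zeros(3))  # moved by .to()
        self.untracked = torch.zeros(3)  # plain attribute: ignored by .to()

demo = Demo().to(torch.float64)

print(demo.tracked.dtype)    # torch.float64
print(demo.untracked.dtype)  # torch.float32 -- still the old dtype
```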

What I Did

import dill  # Needed to allow saving
import snntorch as snn
import torch
import torch.nn as nn
from snntorch import surrogate, utils

spike_grad = surrogate.fast_sigmoid() # surrogate gradient

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # initialize layers
        self.model = torch.nn.Sequential(
            nn.Linear(10, 10),
            snn.Leaky(beta=0.9, init_hidden=True, output=True,
                      reset_mechanism='zero', spike_grad=spike_grad)
        )

    def forward(self, x):
        spike_recording = []  # record spikes over time
        utils.reset(self.model)  # reset/initialize hidden states for all neurons
        # for module in self.model:
        #     if type(module) is snn.Leaky:
        #         module.init_leaky()
        #         module.mem = module.mem.to(device=torch.device('cpu'))

        for step in range(100):  # loop over time
            spike, state = self.model(x[..., step])  # one time step of forward-pass
            spike_recording.append(spike)  # record spikes in list

        return torch.sum(torch.stack(spike_recording), dim=0)


# Set up running on GPU
device = torch.device('cuda')

net = Net().to(device)

input_example = torch.rand((10,10,100)).to(device)

output = net(input_example)
print(output)

torch.save(net, 'example', pickle_module=dill) # Use dill for pickling
net = torch.load('example', pickle_module=dill)

device = torch.device('cpu')

net.eval()
net = net.to(device)  # Should put everything "on the CPU"
input_example = input_example.to(device)

output = net(input_example)  # This raises an error because the Leaky layer's membrane potential is still on the GPU

print(output)
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/git/RIID_PyTorch/bug_minimal_example.py", line 53, in <module>
    output = net(input_example)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/bug_minimal_example.py", line 28, in forward
    spike, state = self.model(x[..., step])  # one time step of forward-pass
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 194, in forward
    self.mem = self._build_state_function_hidden(input_)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 238, in _build_state_function_hidden
    state_fn = self._base_state_function_hidden(input_)
  File "/home/user/git/RIID_PyTorch/.venv/lib/python3.10/site-packages/snntorch/_neurons/leaky.py", line 227, in _base_state_function_hidden
    base_fn = self.beta.clamp(0, 1) * self.mem + input_
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
@jeshraghian
Owner

Thanks for catching this issue; it seems like the initialized state is missing a device check for each forward-pass.
Will try to push a fix for this within the next week or so.

In the meantime, I managed to bypass this by calling utils.reset(net) just before the forward-pass.
Are you able to test if this works?
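The kind of per-forward device check described above might look something like this toy stand-in (LeakyLike is purely illustrative, not snntorch's actual Leaky implementation):

```python
import torch
import torch.nn as nn

class LeakyLike(nn.Module):
    """Toy stand-in for a stateful neuron with a per-forward device check."""
    def __init__(self, beta=0.9):
        super().__init__()
        self.beta = beta
        self.mem = torch.zeros(1)  # hidden state held as a plain attribute

    def forward(self, input_):
        # Device check: realign the stored state with the incoming tensor
        # before using it, so a stale cuda:0 state cannot collide with CPU input.
        if self.mem.device != input_.device:
            self.mem = self.mem.to(input_.device)
        self.mem = self.beta * self.mem + input_
        return self.mem

neuron = LeakyLike()
out = neuron(torch.ones(1))  # first step: 0.9 * 0 + 1 = 1.0
```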

@ahenkes1
Collaborator

Any updates, @ej159?

@ej159
Author

ej159 commented Aug 14, 2023

I've tried peppering the script with utils.reset(net) calls (there is already one in there too) and toggling between them, but I still get the same error. The commented-out code in the provided example is the only workaround I've found.
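The commented-out workaround in the example can be generalized into a small helper that walks the module tree and moves any plain tensor attribute (such as mem) to the target device. Note that move_hidden_states is a hypothetical name, not part of snntorch:

```python
import torch
import torch.nn as nn

def move_hidden_states(module: nn.Module, device: torch.device) -> None:
    """Move tensors stored as plain attributes (missed by .to()) to device."""
    for sub in module.modules():
        # Snapshot the attribute dict so we can reassign while iterating
        for name, value in list(vars(sub).items()):
            if torch.is_tensor(value):
                setattr(sub, name, value.to(device))
```

Usage after loading would then be: net = net.to(device) followed by move_hidden_states(net, device).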

ahenkes1 added the bug (Something isn't working) label on Aug 15, 2023
@ahenkes1
Copy link
Collaborator

Ok, thank you for your feedback!
