
MemTorch with CUDA support cannot be run on CPU

YBRua opened this issue 3 years ago · 4 comments

I followed the README and installed MemTorch with CUDA support using the pip package manager.

pip install memtorch # Supports CUDA and normal operation

The README says that this version supports both CUDA and normal (CPU) operation, so I assumed the patched model could run on both my GPU and my CPU.

It does work fine with CUDA. However, it raised an error when I tried to patch a torch model and evaluate it on my CPU. A simplified but reproducible version of my code looks like this:

import copy
import torch
import torch.nn as nn

import memtorch
from memtorch.mn.Module import patch_model
from memtorch.map.Input import naive_scale
from memtorch.map.Parameter import naive_map


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor):
        return self.fc1(x)


# seed for reproducibility
torch.manual_seed(1337)

# expects the evaluation to run on CPU
device = torch.device('cpu')

# create a model and its input, both on CPU
model = MyNet().to(device)
net_input = torch.ones((10, 10)).to(device)  # (batch_size, num_features)

# patching the model with memtorch
reference_memristor = memtorch.bh.memristor.VTEAM
reference_memristor_params = {'time_series_resolution': 1e-10}
patched_model = patch_model(
    copy.deepcopy(model),
    memristor_model=reference_memristor,
    memristor_model_params=reference_memristor_params,
    module_parameters_to_patch=[nn.Linear],
    mapping_routine=naive_map,
    transistor=True,
    programming_routine=None,
    tile_shape=(128, 128),
    max_input_voltage=0.3,
    scaling_routine=naive_scale,
    ADC_resolution=8,
    ADC_overflow_rate=0.,
    quant_method='linear')
patched_model.tune_()

# memtorch_output would be on CUDA
memtorch_output = patched_model(net_input)
print(memtorch_output.device)  # cuda:0
print(memtorch_output.max())  # tensor(1.0530, device='cuda:0')

# error: memtorch_output is on CUDA while net_input is on CPU
print(memtorch_output + net_input)  # ERROR here
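
For completeness, explicitly moving the output back to the CPU avoids this particular error (a workaround only; it does not change where the computation actually runs):

# workaround: move the CUDA output back to the CPU before mixing it with CPU tensors
print(memtorch_output.cpu() + net_input)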

After checking the source code, I believe the error is caused by this line:

# line 94 in memtorch/mn/Linear.py
self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")

This line determines the device of the mn.Linear module by checking whether the installed memtorch build is the CPU-only variant, regardless of which device the original nn.Linear module is on.
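
A possible fix (just a sketch from my side, not MemTorch's actual code; the helper name and the wrapped-layer argument are illustrative) would be to infer the device from the wrapped layer's parameters instead of the version string:

import torch
import torch.nn as nn

def infer_device(module: nn.Module) -> torch.device:
    # use the device of the module's first parameter; fall back to CPU
    # for parameter-less modules
    try:
        return next(module.parameters()).device
    except StopIteration:
        return torch.device("cpu")

# e.g. in memtorch/mn/Linear.py, instead of checking memtorch.__version__:
# self.device = infer_device(linear_layer)  # linear_layer: the wrapped nn.Linear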

I am not sure whether this is intended. I also tried changing this line to self.device = torch.device('cpu') and forcing the patched model to run on device('cpu'). That suppresses the error, but the evaluation results on CPU differ greatly from the results on CUDA. (Sorry, I am unable to provide a small enough reproducer for this; I am using an MLP on the MNIST dataset. The patched model has a test accuracy of 0.8849 on CUDA, but only 0.7907 when it is forced to run on the CPU in the way described above.)

I am using PyTorch 1.9.0+cu102 and MemTorch 1.1.6.

YBRua · May 26 '22 14:05

MemTorch is no longer actively developed. I am also facing a bug where changing the number of finite conductance states does not affect anything.

matifali · Jun 08 '22 09:06

Hi @YBRua and @matifali,

My apologies for my late reply. I'll look into this later this week and implement a fix for this behavior. Would a static method that can be used to change the "operation mode" between CPU and GPU, i.e., memtorch.set_device('cpu'), suffice for your use cases?
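
For illustration, a rough sketch of what I have in mind (hypothetical; the final names and behaviour may differ):

import torch

# module-level default device, e.g. set in memtorch/__init__.py
_default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_device(device):
    # switch the "operation mode" between CPU and GPU for patched modules
    global _default_device
    _default_device = torch.device(device)

def get_device():
    return _default_device

Patched modules would then read get_device() instead of parsing memtorch.__version__.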

Kind Regards,

Corey.

coreylammie · Jun 15 '22 06:06

Hi @coreylammie ,

Thanks for your reply! And yes, I think the static method would fix my problem. It would be of great help. Thanks in advance!

YBRua · Jun 25 '22 04:06

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] · Aug 31 '22 01:08