spikingjelly
使用自定义激活函数后，将 ANN 转换为 SNN 失败
PyTorch 版本：2.0.0+cpu
# Custom operators.
# torch / nn must be bound before the classes below subclass nn.Module; in this
# snippet the imports otherwise only appear further down, after first use.
import torch
import torch.nn as nn


class CRELU(nn.Module):
    """Clipped ReLU: clamp the input element-wise to the range [0, 1].

    No custom backward is defined. Autograd differentiates ``torch.clamp``
    directly: the gradient passes through where 0 <= x <= 1 and is zero
    outside. This is what a hand-written backward would compute. An
    ``nn.Module.backward`` method is never called by autograd. The previous
    version could not have run anyway: ``nn.Module`` has no ``saved_tensors``,
    and ``if`` on a multi-element tensor raises.

    NOTE(review): per the maintainer's reply in this issue, ann2snn.Converter
    only replaces ``nn.ReLU`` modules, so this activation is not converted to
    a spiking neuron.
    """

    @staticmethod
    def forward(x):
        # Kept as a staticmethod so existing ``CRELU.forward(x)`` calls still work;
        # ``module(x)`` also works because nn.Module dispatches to self.forward(x).
        return torch.clamp(x, min=0, max=1)
class Qt(nn.Module):
    """Clamp the input element-wise to the range [0, 1].

    NOTE: the ``backward`` method below is not used by autograd. Autograd never
    calls an ``nn.Module.backward``; gradients are derived from the ops in
    ``forward``. The effective gradient is therefore that of ``torch.clamp``:
    passed through for 0 <= x <= 1, zero outside. It is NOT a pass-through
    (straight-through) gradient. If a straight-through estimator is intended,
    it must be implemented with ``torch.autograd.Function``.
    """

    def forward(self, x):
        # Same values and gradient as the nested torch.where formulation, without
        # allocating full zeros_like / ones_like tensors on every call.
        return torch.clamp(x, min=0, max=1)

    def backward(self, grad_output):
        """Return ``grad_output`` unchanged (kept for API compatibility; not invoked by autograd)."""
        return grad_output
#模型结构
import torch
import torch.nn as nn
class CNNforArousal(nn.Module):
    """1-D CNN producing two output values per sample.

    Architecture:
      * Three conv blocks, each ``Conv1d -> BatchNorm1d -> CRELU -> Qt``.
      * The second and third blocks are each followed by ``MaxPool1d(2, 2)``.
      * Then ``fc1 (2432 -> 128) -> Qt -> fc2 (128 -> 2)``.

    Input is expected as ``(N, 32, L)``. ``fc1`` takes 2432 = 32 * 76 features.
    With the fixed kernel and pool sizes this corresponds to L in 320..323.
    NOTE(review): confirm L against the data pipeline.

    Output has shape ``(N, 2)`` for any batch size N.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm1d(16)
        self.crelu1 = CRELU()
        self.qt1 = Qt()
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.crelu2 = CRELU()
        self.qt2 = Qt()
        self.maxpool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=1)
        self.bn3 = nn.BatchNorm1d(32)
        self.crelu3 = CRELU()
        self.qt3 = Qt()
        self.maxpool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(2432, 128)
        # Separate Qt for the FC layer (qt3 belongs to the conv3 block).
        # Qt has no parameters, so state_dict is unaffected.
        self.qt4 = Qt()
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.crelu1(x)
        x = self.qt1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.crelu2(x)
        x = self.qt2(x)
        x = self.maxpool1(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.crelu3(x)
        x = self.qt3(x)
        x = self.maxpool2(x)
        # Flatten per sample and keep the batch dimension. A bare x.flatten()
        # would merge the batch axis too: fc1 would then only accept batch size 1,
        # and the batch axis would have to be re-added with unsqueeze afterwards.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.qt4(x)
        x = self.fc2(x)
        return x
# Conversion code
import warnings

from spikingjelly.activation_based import ann2snn

# NOTE: per the maintainer's reply in this issue, ann2snn.Converter only
# replaces nn.ReLU modules with spiking neurons. Custom activations such as
# CRELU / Qt are left as they are, so the "converted" model would still be a
# plain ANN. Warn instead of silently returning an unconverted model.
# NOTE(review): `model` and `train_dataloader` are assumed to be defined earlier
# (not visible in this snippet).
if not any(isinstance(m, nn.ReLU) for m in model.modules()):
    warnings.warn(
        "model contains no nn.ReLU modules; ann2snn.Converter will not insert "
        "any spiking neurons (custom activations are not converted)",
        stacklevel=1,
    )

model_converter = ann2snn.Converter(mode='max', dataloader=train_dataloader)
snn_model = model_converter(model)
# Display the converted model (its printed output is shown below)
snn_model
#输出以下结果
CNNforArousal(
(conv1): Conv1d(32, 16, kernel_size=(5,), stride=(1,))
(conv2): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
(maxpool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): Conv1d(32, 32, kernel_size=(5,), stride=(1,))
(maxpool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=2432, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=2, bias=True)
)
你好，目前 ANN2SNN 转换的代码逻辑是把 ReLU 替换为 SNN tailor，所以如果您使用了自定义的激活函数，就需要手动修改代码：把 spikingjelly/activation_based/ann2snn/converter.py 里涉及 nn.ReLU 的类型判断，都改为您自定义的激活函数的类型。@populustremble