pytorch-tutorial
I have a problem with a similar method based on your Neural Style Transfer example
I used the same method as your Neural Style Transfer example to train a distilled network. I use PyTorch's `register_forward_hook` to get the feature maps output by the intermediate layers, and the loss is computed on those feature maps. Here is my loss function:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StudentLoss(nn.Module):
    def __init__(self):
        super(StudentLoss, self).__init__()

    def forward(self, tfmaps_list, sfmaps_list, open_conv):
        tfmaps_list_new = []
        # get specific feature maps: map each student channel to its
        # corresponding teacher channel via open_conv
        for idl in range(0, len(sfmaps_list)):
            sfmaps = sfmaps_list[idl]
            tfmaps = tfmaps_list[idl]
            tfmaps_new = torch.Tensor(sfmaps_list[idl].shape[0], sfmaps_list[idl].shape[1],
                                      sfmaps_list[idl].shape[2], sfmaps_list[idl].shape[3]).zero_()
            for idx in range(0, sfmaps.shape[0]):
                for idy in range(0, sfmaps.shape[1]):
                    tfmaps_new[idx][idy] = tfmaps[idx][open_conv[2][idy]]
            tfmaps_list_new.append(tfmaps_new.detach())

        # sum the per-layer MSE between student and detached teacher feature maps
        fmaps_losses = None
        for idl in range(0, len(sfmaps_list)):
            sfmaps = sfmaps_list[idl]
            tfmaps = tfmaps_list_new[idl]
            fmaps_loss = F.mse_loss(sfmaps.view(sfmaps.shape[0], -1), tfmaps.view(tfmaps.shape[0], -1).cuda())[0]
            if fmaps_losses is None:
                fmaps_losses = fmaps_loss
            else:
                fmaps_losses += fmaps_loss
        return fmaps_losses
```
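The feature-map lists passed to this loss are collected by the forward hooks registered in the main function below. My hook functions essentially look like this (shown in simplified form; the exact bodies are a sketch, but they append each hooked layer's output to the global lists `result_tfeature` and `result_sfeature`):

```python
# Simplified sketch of the hook functions: each forward hook appends the
# hooked layer's output to a global list that the loss later consumes.
result_tfeature = []  # teacher feature maps collected during forward passes
result_sfeature = []  # student feature maps collected during forward passes

def get_tfeature_hook(module, input, output):
    result_tfeature.append(output)

def get_sfeature_hook(module, input, output):
    result_sfeature.append(output)
```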
And here is the relevant part of my main function:
```python
# Init model #
tmodel = MDNet(opts['model_path'], K)
smodel = SMDNet(opts['init_model_path'], K)
if opts['use_gpu']:
    tmodel = tmodel.cuda()
    smodel = smodel.cuda()
smodel.set_learnable_params(['conv'])
tmodel.set_learnable_params([])

# Set teacher and student network hooks to get feature maps #
tConv = torch.nn.ModuleList()
for lname, layer in tmodel.named_children():
    if lname in ['layers']:
        for name, module in layer.named_children():
            if name in ['conv1', 'conv2', 'conv3']:
                for cname, cmodule in module.named_children():
                    if cname == '0':
                        tConv.append(cmodule)
thandle_feat_conv3 = tConv[2].register_forward_hook(get_tfeature_hook)  # conv3

sConv = torch.nn.ModuleList()
for lname, layer in smodel.named_children():
    if lname in ['layers']:
        for name, module in layer.named_children():
            if name in ['conv1', 'conv2', 'conv3']:
                for cname, cmodule in module.named_children():
                    if cname == '0':
                        sConv.append(cmodule)
shandle_feat_conv3 = sConv[2].register_forward_hook(get_sfeature_hook)  # conv3

# Init criterion and optimizer #
criterion = StudentLoss()
optimizer = set_optimizer(smodel, opts['lr'])
best_loss = 9999

for i in range(opts['n_cycles']):
    print "==== Start Cycle %d ====" % (i)
    k_list = np.random.permutation(K)
    losses = np.zeros(K)
    for j, k in enumerate(k_list):
        tic = time.time()
        pos_regions, neg_regions = dataset[k].next()
        pos_regions = Variable(pos_regions)
        neg_regions = Variable(neg_regions)
        if opts['use_gpu']:
            pos_regions = pos_regions.cuda()
            neg_regions = neg_regions.cuda()
        # forward passes; the hooks collect the conv3 feature maps
        smodel(pos_regions, k, out_layer='conv3')
        smodel(neg_regions, k, out_layer='conv3')
        tmodel(pos_regions, k, out_layer='conv3')
        tmodel(neg_regions, k, out_layer='conv3')
        # remove hooks
        thandle_feat_conv3.remove()
        shandle_feat_conv3.remove()
        loss = criterion(result_tfeature, result_sfeature, smodel.open_conv)
        smodel.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(smodel.parameters(), opts['grad_clip'])
        optimizer.step()
```
But I encountered an error during `loss.backward()`:
```
Traceback (most recent call last):
  File "/home/liuchenfeng/code/py-SMDNet/pretrain/train_smdnet.py", line 181, in <module>
    train_smdnet()
  File "/home/liuchenfeng/code/py-SMDNet/pretrain/train_smdnet.py", line 155, in train_smdnet
    loss.backward()
  File "/usr/local/lib/python2.7/dist-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
```
I don't understand why this error is raised; I only call `loss.backward()` once per iteration.
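For reference, I know this error is normally produced by calling `backward()` twice through the same graph, as in this minimal sketch (unrelated to my model):

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
y = (x * 2).sum()
y.backward()  # the first backward pass frees the graph's intermediate buffers
y.backward()  # RuntimeError: Trying to backward through the graph a second time
```

But in my training loop I build a fresh `loss` every iteration, so I don't see where the second backward pass through the graph could come from.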