AdaptiveWingLoss
Training code implementation
```python
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import cv2
import sys
import os
from PIL import Image, ImageDraw
from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
import numpy as np
from skimage import io
import shutil
from torch.autograd import Variable
import time
import copy
from torch import nn
import torch
import math

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class AdaptiveWingLoss(nn.Module):
    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        y = target
        y_hat = pred
        delta_y = (y - y_hat).abs()
        # split into the small-error (nonlinear) and large-error (linear) regimes
        delta_y1 = delta_y[delta_y < self.theta]
        delta_y2 = delta_y[delta_y >= self.theta]
        y1 = y[delta_y < self.theta]
        y2 = y[delta_y >= self.theta]
        loss1 = self.omega * torch.log(1 + torch.pow(
            delta_y1 / self.omega, self.alpha - y1)) * weight_map[delta_y < self.theta]
        # A and C make the two pieces continuous and smooth at delta_y == theta
        A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * (
            torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * \
            torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))
        loss2 = (A * delta_y2 - C) * weight_map[delta_y >= self.theta]
        return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))


def train_model(model, dataloaders, dataset_sizes, use_gpu=True, epoches=5,
                save_path='./', num_landmarks=68, start_epoch=0):
    best_acc = 100
    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=0.0000001, weight_decay=0)
    loss_AW = AdaptiveWingLoss()
    for epoch in range(start_epoch, epoches + start_epoch):
        running_loss = 0
        step = 0
        total_nme = 0
        total_count = 0
        fail_count = 0
        nmes = []
        step_start = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
            # Iterate over data.
            for data in dataloaders[phase]:
                optimizer.zero_grad()
                total_runtime = 0
                run_count = 0
                step += 1
                # get the inputs
                inputs = data['image'].type(torch.FloatTensor)
                labels_heatmap = data['heatmap'].type(torch.FloatTensor)
                labels_boundary = data['boundary'].type(torch.FloatTensor)
                gt_landmarks = data['landmarks'].type(torch.FloatTensor)
                loss_weight_map = data['weight_map'].type(torch.FloatTensor)
                # wrap them in Variable
                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)
                else:
                    inputs, labels_heatmap = Variable(
                        inputs), Variable(labels_heatmap)
                    labels_boundary = Variable(labels_boundary)
                labels = torch.cat((labels_heatmap, labels_boundary), 1)
                single_start = time.time()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs, boundary_channels = model(inputs)
                    pred_labels = torch.cat(
                        (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
                    loss_total = loss_AW(
                        pred_labels, loss_weight_map * 10 + 1, labels)
                    if phase == 'train':
                        loss_total.backward()
                        optimizer.step()
                batch_nme = fan_NME(
                    outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
                total_nme += batch_nme
            epoch_nme = total_nme / dataset_sizes[phase]
            step_end = time.time()
            print(phase + ' NME: {:.6f}'.format(epoch_nme))
            if phase == 'val' and epoch_nme < best_acc:
                state = {
                    'next_epoch': epoch + 1,
                    'epoch_total_nme': epoch_nme,
                    'state_dict': model.state_dict(),
                    # 'scheduler': scheduler.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, save_path + '{:02d}'.format(epoch) + '.pth')
    return model
```
@protossw512 Could you please check whether my training implementation is correct?
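For reference, I call it roughly like this (a rough sketch of my setup; `LandmarkDataset` is a placeholder for my own dataset class, which returns the dict keys used in `train_model`):

```python
from torch.utils.data import DataLoader

# 'train' and 'val' splits; each sample is a dict with the keys 'image',
# 'heatmap', 'boundary', 'landmarks' and 'weight_map' used in train_model.
datasets = {x: LandmarkDataset(split=x) for x in ['train', 'val']}
dataloaders = {x: DataLoader(datasets[x], batch_size=8,
                             shuffle=(x == 'train'), num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val']}

model = FAN(4)  # the repo's 4-stack hourglass model (constructor may differ)
model = model.to(device)
model = train_model(model, dataloaders, dataset_sizes,
                    use_gpu=torch.cuda.is_available(),
                    epoches=5, save_path='./awing_', num_landmarks=68)
```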
@HassanAbbas92 Thanks for posting this, I was looking for something to get started with training. I found a couple of pretty small issues:
- In the Adaptive Wing Loss calculation, one equation should be changed to match the paper:

  ```python
  loss1 = self.omega * torch.log(1 + torch.pow(
      delta_y1 / self.omega, self.alpha - y1)) * weight_map[delta_y < self.theta]
  ```

  should be

  ```python
  loss1 = self.omega * torch.log(1 + torch.pow(
      delta_y1 / self.epsilon, self.alpha - y1)) * weight_map[delta_y < self.theta]
  ```

- In the main training loop, `total_nme` is reset at the start of the epoch, not at the start of each phase, so your train NME will be correct but your val NME will be off because the train error was never zeroed out (see the sketch after this list).
- Some of the hyperparameters differ from what is listed in the paper, but I think it will still train somewhat with what is present.
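A minimal sketch of both fixes (untested, just to illustrate what I mean):

```python
# Fix 1: divide by epsilon rather than omega inside the log, matching the
# paper's definition of the nonlinear piece of the Adaptive Wing Loss.
loss1 = self.omega * torch.log(1 + torch.pow(
    delta_y1 / self.epsilon, self.alpha - y1)) * weight_map[delta_y < self.theta]

# Fix 2: move the accumulator reset inside the phase loop so the val NME
# is not contaminated by the accumulated train error.
for phase in ['train', 'val']:
    total_nme = 0
    if phase == 'train':
        model.train()
    else:
        model.eval()
    # ... rest of the phase loop unchanged ...
```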
And thanks @protossw512 for sharing this great project.
@mustangchavez Thanks for the correction. Did you start training from scratch and achieve the same results as the paper on the WFLW dataset?
Sorry, I haven't tried to reproduce the results as I am actually working on a different problem domain.
I found what seems to be another correction: I believe the correct boundary map to use is already part of the first output, not the "boundary_channels" output. So

```python
pred_labels = torch.cat(
    (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
```

should be

```python
pred_labels = outputs[-1]
```
@mustangchavez Are you sure? Because I thought @protossw512 talked about it in https://github.com/protossw512/AdaptiveWingLoss/issues/12
I'm not sure, no! Haha.
But I think that each `boundary_channels` entry is intended to be passed into the next HourGlass module, whereas `outputs` contains the final boundary heatmap prediction. That's why the shape of the last tensor in the `outputs` list is `[batch_size, num_landmarks + 1, heatmap_size, heatmap_size]`.
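You can check this by printing the shapes, something like (a rough, untested sketch, run on a batch that is already on `device`):

```python
# Compare the per-stack `outputs` with the corresponding `boundary_channels`.
with torch.no_grad():
    outputs, boundary_channels = model(inputs)
for i, (o, b) in enumerate(zip(outputs, boundary_channels)):
    print(i, tuple(o.shape), tuple(b.shape))
# For 68 landmarks and 64x64 heatmaps I would expect each entry of `outputs`
# to have shape (batch_size, 69, 64, 64): 68 landmark channels plus the
# final boundary heatmap as the last channel.
```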
You are correct, @mustangchavez. I will also try to visualize both of them to see the difference.
Hello @HassanAbbas92, thanks for sharing the training code implementation. How is it going? I am reading the paper and have only run inference so far, but I am thinking about training from scratch. Have you managed to reproduce the results?
Hi. Thanks for your code and the idea about the loss function. I have a question: is the calculation of `loss1` and `loss2` correct? More specifically, is `* weight_map[delta_y < self.theta]` correct? I think `torch.where` would be better.
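For example, something like this (a rough, untested sketch of an equivalent `forward`, with the epsilon fix from above applied):

```python
def forward(self, pred, weight_map, target):
    delta_y = (target - pred).abs()
    # Compute both pieces densely, then select elementwise with torch.where;
    # this keeps the tensor shape intact instead of boolean-indexing it away.
    loss_small = self.omega * torch.log(
        1 + torch.pow(delta_y / self.epsilon, self.alpha - target))
    A = self.omega \
        * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))) \
        * (self.alpha - target) \
        * torch.pow(self.theta / self.epsilon, self.alpha - target - 1) \
        / self.epsilon
    C = self.theta * A - self.omega * torch.log(
        1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
    loss = torch.where(delta_y < self.theta, loss_small, A * delta_y - C)
    # weighting and mean reproduce (loss1.sum() + loss2.sum()) / N
    return (loss * weight_map).mean()
```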