
Training code implementation

HassanAbbas92 opened this issue 5 years ago · 8 comments

```python
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import cv2
import sys
import os
import shutil
import time
import copy
import math
import numpy as np
from PIL import Image, ImageDraw
from skimage import io
import torch
from torch import nn
from torch.autograd import Variable
from utils.utils import fan_NME, show_landmarks, get_preds_fromhm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class AdaptiveWingLoss(nn.Module):
    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        y = target
        y_hat = pred
        delta_y = (y - y_hat).abs()
        delta_y1 = delta_y[delta_y < self.theta]
        delta_y2 = delta_y[delta_y >= self.theta]
        y1 = y[delta_y < self.theta]
        y2 = y[delta_y >= self.theta]
        loss1 = self.omega * torch.log(1 + torch.pow(
            delta_y1 / self.omega, self.alpha - y1)) * weight_map[delta_y < self.theta]
        A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * (
            torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * \
            torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))
        loss2 = (A * delta_y2 - C) * weight_map[delta_y >= self.theta]
        return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))


def train_model(model, dataloaders, dataset_sizes, use_gpu=True, epoches=5,
                save_path='./', num_landmarks=68, start_epoch=0):
    best_acc = 100
    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=0.0000001, weight_decay=0)
    loss_AW = AdaptiveWingLoss()
    for epoch in range(start_epoch, epoches + start_epoch):
        running_loss = 0
        step = 0
        total_nme = 0
        total_count = 0
        fail_count = 0
        nmes = []
        # running_corrects = 0
        step_start = time.time()

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Iterate over data.
            # with torch.set_grad_enabled(True):
            for data in dataloaders[phase]:
                optimizer.zero_grad()
                total_runtime = 0
                run_count = 0

                step += 1
                # get the inputs
                inputs = data['image'].type(torch.FloatTensor)
                labels_heatmap = data['heatmap'].type(torch.FloatTensor)
                labels_boundary = data['boundary'].type(torch.FloatTensor)
                gt_landmarks = data['landmarks'].type(torch.FloatTensor)
                loss_weight_map = data['weight_map'].type(torch.FloatTensor)
                # wrap them in Variable
                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)
                else:
                    inputs, labels_heatmap = Variable(
                        inputs), Variable(labels_heatmap)
                    labels_boundary = Variable(labels_boundary)
                labels = torch.cat((labels_heatmap, labels_boundary), 1)
                single_start = time.time()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs, boundary_channels = model(inputs)
                    pred_labels = torch.cat(
                        (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
                    ###
                    loss_total = loss_AW(
                        pred_labels, loss_weight_map * 10 + 1, labels)
                    ###
                    #print("Batch Loss: {:.6f}".format(loss.item()))
                    if phase == 'train':
                        loss_total.backward()
                        optimizer.step()
                batch_nme = fan_NME(
                    outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
                #print("Batch NME: {:.6f}".format(batch_nme))
                # batch_nme = 0
                total_nme += batch_nme
            epoch_nme = total_nme / dataset_sizes[phase]
            step_end = time.time()
            print(phase + ' NME: {:.6f}'.format(epoch_nme))
            if phase == 'val' and epoch_nme < best_acc:
                state = {
                    'next_epoch': epoch+1,
                    'epoch_total_nme': epoch_nme,
                    'state_dict': model.state_dict(),
                    # 'scheduler' : scheduler.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, save_path+'{:02d}'.format(epoch)+'.pth')
        #nme_save_path = os.path.join(save_path, 'nme_log.npy')
        #np.save(nme_save_path, np.array(nmes))
        #print('NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}'.format(epoch_nme, fail_count/total_count, total_count, fail_count))

    #print('Average runtime for a single batch: {:.6f}'.format(total_runtime/run_count))
    return model

```

@protossw512 could you please check if my training implementation is correct?

HassanAbbas92 · Feb 18 '20

@HassanAbbas92 Thanks for posting this; I was looking for something to get started with training. I found a couple of fairly small issues:

  • In the Adaptive Wing Loss calculation, one term should be changed to match the paper: `loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.omega, self.alpha - y1)) * weight_map[delta_y < self.theta]` should be `loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.epsilon, self.alpha - y1)) * weight_map[delta_y < self.theta]`, i.e. divide by `self.epsilon`, not `self.omega`, inside the log (see the formula after this list).
  • In the main training loop, `total_nme` is reset at the start of the epoch, not at the start of each phase, so your train NME will be correct but your val NME will be off because it still includes the accumulated train error (see the sketch after this list).
  • Some of the hyperparameters differ from what is listed in the paper, but I think it will still train to some extent with what is present.
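
For reference, the piecewise loss from the paper, as I read it (with $\omega$, $\theta$, $\epsilon$, $\alpha$ as in the constructor):

$$
\mathrm{AWing}(y, \hat{y}) =
\begin{cases}
\omega \ln\left(1 + \left|\frac{y - \hat{y}}{\epsilon}\right|^{\alpha - y}\right) & \text{if } |y - \hat{y}| < \theta \\
A\,|y - \hat{y}| - C & \text{otherwise,}
\end{cases}
$$

where $A = \omega \left(1 + (\theta/\epsilon)^{\alpha - y}\right)^{-1} (\alpha - y)\, (\theta/\epsilon)^{\alpha - y - 1} / \epsilon$ and $C = \theta A - \omega \ln\left(1 + (\theta/\epsilon)^{\alpha - y}\right)$, which is what the code computes once the epsilon fix above is applied.

And a minimal sketch of the second fix, moving the NME reset inside the phase loop (`batch_nme_for` is a hypothetical stand-in for the forward pass plus the `fan_NME` call above):

```python
for epoch in range(start_epoch, epoches + start_epoch):
    for phase in ['train', 'val']:
        total_nme = 0  # reset per phase, not once per epoch
        for data in dataloaders[phase]:
            total_nme += batch_nme_for(data, phase)  # hypothetical helper wrapping fan_NME
        epoch_nme = total_nme / dataset_sizes[phase]
        print(phase + ' NME: {:.6f}'.format(epoch_nme))
```

This way the 'val' NME reflects only validation batches.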

And thanks @protossw512 for sharing this great project.

mustangchavez · Feb 27 '20

@mustangchavez Thanks for the correction. Did you start training from scratch and achieve the same results as in the paper on the WFLW dataset?

HassanAbbas92 · Mar 04 '20

Sorry, I haven't tried to reproduce the results as I am actually working on a different problem domain.

I found what seems to be another needed correction: I believe the correct boundary map to use is already part of the first output, not the "boundary_channels" output. So

```python
pred_labels = torch.cat(
    (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
```

should be

```python
pred_labels = outputs[-1]
```

mustangchavez · Mar 05 '20

@mustangchavez Are you sure? Because I thought @protossw512 talked about this in https://github.com/protossw512/AdaptiveWingLoss/issues/12

HassanAbbas92 · Mar 13 '20

I'm not sure, no! Haha.

But I think that each "boundary_channels" entry is intended to be passed into the next HourGlass module, whereas "outputs" contains the final boundary heatmap prediction. That's why the shape of the last tensor in the "outputs" list is [batch_size, num_landmarks + 1, heatmap_size, heatmap_size].
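
A quick way to sanity-check those shapes (a sketch, assuming the `model` and `device` set up in the code above, with 68 landmarks and 64x64 heatmaps):

```python
import torch

model.eval()
with torch.no_grad():
    # dummy 256x256 RGB batch of size 1
    outputs, boundary_channels = model(torch.randn(1, 3, 256, 256).to(device))

print(outputs[-1].shape)  # expected: torch.Size([1, 69, 64, 64]), i.e. 68 landmark maps + 1 boundary map
print(boundary_channels[-1].shape)  # the boundary map that is fed forward between hourglass modules
```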

mustangchavez · Mar 13 '20

You are correct, @mustangchavez. But I will also try to visualize both of them to see the difference.

HassanAbbas92 · Mar 15 '20

Hello @HassanAbbas92, thanks for sharing the training code implementation. How is it going? I am reading the paper, have only run inference so far, and am thinking about training from scratch. Have you managed to reproduce the results?

vuthede · Jun 17 '20

> (quoting the training code from the original post)

Hi. Thanks for your code and the idea about the loss function. I have a question: is the calculation of loss1 and loss2 correct? More specifically, is the `* weight_map[delta_y < self.theta]` indexing correct? I think torch.where would be better. Thanks.
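
For illustration, a masking-free variant along those lines might look like this (a sketch, untested, with the epsilon-in-the-log fix from earlier in the thread applied):

```python
import torch
from torch import nn


class AdaptiveWingLossWhere(nn.Module):
    """Same loss as above, but selecting between the two branches with
    torch.where instead of boolean-mask indexing (hypothetical rewrite)."""

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super().__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        delta_y = (target - pred).abs()
        # small-error branch (epsilon inside the log, per the earlier correction)
        loss_small = self.omega * torch.log(
            1 + torch.pow(delta_y / self.epsilon, self.alpha - target))
        # large-error (linear) branch
        A = self.omega \
            * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))) \
            * (self.alpha - target) \
            * torch.pow(self.theta / self.epsilon, self.alpha - target - 1) \
            / self.epsilon
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
        loss_large = A * delta_y - C
        # both branches are computed densely; torch.where picks elementwise
        loss = torch.where(delta_y < self.theta, loss_small, loss_large) * weight_map
        return loss.mean()
```

`loss.mean()` matches the original `(loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))`, since the two masks partition the elements.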

switch626 · Jun 10 '21