mobilehand icon indicating copy to clipboard operation
mobilehand copied to clipboard

Adding the hand axis-angle pose loss makes the results worse

Open youngstu opened this issue 3 years ago • 2 comments

I reproduced the hand training module and found that adding a loss on the hand axis-angle pose may make the results worse. I have verified that the data is correct. After the axis-angle loss is added, the predicted hand often flips between forward- and backward-facing orientations.

import torch
import torch.nn as nn

class ManoLoss:
    """Weighted sum of MSE losses over pose, shape, 3D joints and 2D keypoints.

    Each term is scaled by its ``lambda_*`` weight; a weight of ``0`` (or any
    falsy value) disables that term entirely.
    """

    def __init__(
            self,
            lambda_pose=100.0,
            lambda_shape=100.0,
            lambda_joint3d=1.0,
            lambda_kp2d=1.0,
    ):
        """Store the per-term weights and build one MSE criterion per term.

        Args:
            lambda_pose: weight of the pose term.
            lambda_shape: weight of the shape term.
            lambda_joint3d: weight of the root-relative 3D joint term.
            lambda_kp2d: weight of the 2D keypoint term.
        """
        self.lambda_pose = lambda_pose
        self.lambda_shape = lambda_shape
        self.lambda_joint3d = lambda_joint3d
        self.lambda_kp2d = lambda_kp2d

        # nn.MSELoss holds no parameters or buffers, so it needs no device
        # placement; it runs on whatever device its inputs live on.
        self.criterion_pose = nn.MSELoss()
        self.criterion_shape = nn.MSELoss()
        self.criterion_joint3d = nn.MSELoss()
        self.criterion_kp2d = nn.MSELoss()

    def compute_loss(self, preds, targs, infos):
        """Compute the weighted loss summed over every prediction in ``preds``.

        Args:
            preds: iterable of dicts, one per prediction stage, each with the
                keys ``'pose'``, ``'shape'``, ``'kp2d'`` and ``'joint'``.
            targs: dict with the keys ``'flag_3d'``, ``'pose'``, ``'shape'``,
                ``'kp2d'`` and ``'joint'``.
            infos: dict providing ``'root_id'``, the index (along dim 1) of the
                joint that 3D joints are made relative to.

        Returns:
            A tuple ``(total_loss, mano_losses, batch_3d_size)`` where
            ``total_loss`` is a shape-``(1,)`` tensor holding the sum of all
            enabled terms over all stages, ``mano_losses`` maps
            ``'<term>_<stage idx>'`` to each weighted term plus ``'total'``,
            and ``batch_3d_size`` is ``targs['flag_3d'].sum()``.
        """
        root_id = infos['root_id']

        # NOTE(review): 'flag_3d' only feeds the returned count; it is not
        # applied as a mask, so every sample contributes to every term.
        # Confirm whether samples without 3D annotation should be excluded.
        batch_3d_size = targs['flag_3d'].sum()

        gt_pose = targs['pose'].float()
        gt_shape = targs['shape'].float()
        gt_kp2d = targs['kp2d'].float()
        # Target joints are scaled by 1000 (presumably metres -> millimetres,
        # TODO confirm) and expressed relative to the root joint.
        gt_joint3d = targs['joint'].float() * 1000.0
        gt_joint3d = gt_joint3d - gt_joint3d[:, root_id:root_id + 1, :]

        # Accumulate on the targets' device instead of hard-coding CUDA.
        total_loss = torch.zeros(1, device=gt_kp2d.device)
        mano_losses = {}

        for idx, pred in enumerate(preds):
            pred_joint3d = pred['joint']
            pred_joint3d = pred_joint3d - pred_joint3d[:, root_id:root_id + 1, :]

            terms = (
                ('pose', self.lambda_pose, self.criterion_pose,
                 pred['pose'], gt_pose),
                ('shape', self.lambda_shape, self.criterion_shape,
                 pred['shape'], gt_shape),
                ('joint3d', self.lambda_joint3d, self.criterion_joint3d,
                 pred_joint3d, gt_joint3d),
                ('kp2d', self.lambda_kp2d, self.criterion_kp2d,
                 pred['kp2d'], gt_kp2d),
            )

            for name, weight, criterion, pred_value, gt_value in terms:
                if not weight:
                    continue
                term_loss = criterion(pred_value, gt_value) * weight
                mano_losses['%s_%d' % (name, idx)] = term_loss
                # Sum across stages: the total must include every stage's
                # terms, not only those of the last prediction.
                total_loss = total_loss + term_loss

        mano_losses["total"] = total_loss

        return total_loss, mano_losses, batch_3d_size

loginfo: (1000/1018) d: 0.03s | b: 0.31s | s: 72.8770745 | p: 66.2028805 | j: 255.2186516 | k: 137.7673337 | t: 406.3150635 | loginfo: (1001/1018) d: 0.03s | b: 0.31s | s: 72.8693185 | p: 66.2040951 | j: 255.2534850 | k: 137.7815070 | t: 574.5736694 | loginfo: (1002/1018) d: 0.03s | b: 0.31s | s: 72.8742214 | p: 66.2062452 | j: 255.1808178 | k: 137.7479424 | t: 432.7313232 | loginfo: (1003/1018) d: 0.03s | b: 0.31s | s: 72.8717776 | p: 66.2096950 | j: 255.2022860 | k: 137.7690050 | t: 575.6766357 | loginfo: (1004/1018) d: 0.03s | b: 0.31s | s: 72.8674182 | p: 66.2144529 | j: 255.2273901 | k: 137.7747007 | t: 563.3758545 | loginfo: (1005/1018) d: 0.03s | b: 0.31s | s: 72.8567293 | p: 66.2034177 | j: 255.2001193 | k: 137.7657703 | t: 473.8689270 | loginfo: (1006/1018) d: 0.03s | b: 0.31s | s: 72.8619979 | p: 66.2114522 | j: 255.1318979 | k: 137.7335525 | t: 444.3671875 | loginfo: (1007/1018) d: 0.03s | b: 0.31s | s: 72.8579450 | p: 66.2035864 | j: 255.2208746 | k: 137.7956344 | t: 672.0527344 | loginfo: (1008/1018) d: 0.03s | b: 0.31s | s: 72.8569219 | p: 66.2097032 | j: 255.2702999 | k: 137.8076349 | t: 599.1296997 | loginfo: (1009/1018) d: 0.03s | b: 0.31s | s: 72.8681490 | p: 66.2060013 | j: 255.2751612 | k: 137.7991216 | t: 536.0526733 | loginfo: (1010/1018) d: 0.03s | b: 0.31s | s: 72.8743189 | p: 66.2180742 | j: 255.2236392 | k: 137.7840679 | t: 483.3320618 | loginfo: (1011/1018) d: 0.03s | b: 0.31s | s: 72.8830146 | p: 66.2242202 | j: 255.2460789 | k: 137.8080345 | t: 594.0219727 | loginfo: (1012/1018) d: 0.03s | b: 0.31s | s: 72.8847830 | p: 66.2241230 | j: 255.2646785 | k: 137.7912989 | t: 535.7387695 | loginfo: (1013/1018) d: 0.03s | b: 0.31s | s: 72.8771316 | p: 66.2204524 | j: 255.2239141 | k: 137.7670858 | t: 454.8735657 | loginfo: (1014/1018) d: 0.03s | b: 0.31s | s: 72.8792146 | p: 66.2185865 | j: 255.1972219 | k: 137.7473550 | t: 485.2358704 | loginfo: (1015/1018) d: 0.03s | b: 0.31s | s: 72.8804192 | p: 66.2231351 | j: 255.2193977 | k: 137.7601394 | 
t: 573.3665161 | loginfo: (1016/1018) d: 0.03s | b: 0.31s | s: 72.8631759 | p: 66.2301331 | j: 255.1524250 | k: 137.7305885 | t: 423.6058655 | loginfo: (1017/1018) d: 0.03s | b: 0.31s | s: 72.8613670 | p: 66.2278808 | j: 255.2338836 | k: 137.7320333 | t: 612.1587524 | loginfo: (1018/1018) d: 0.03s | b: 0.31s | s: 72.8622015 | p: 66.2263912 | j: 255.2558645 | k: 137.7465916 | t: 605.0795898 |

youngstu avatar Jun 17 '21 03:06 youngstu

Could I get your reproduced code to study? Thanks a lot! @youngstu

Rookienovice avatar Dec 13 '21 14:12 Rookienovice

你好 能分享下你复现的训练代码吗 感谢~

lvZic avatar Aug 03 '22 01:08 lvZic