mobilehand
mobilehand copied to clipboard
Adding the hand axis-angle pose loss makes the results worse
I reproduced the hand training module and found that adding the loss on the hand axis-angle pose may make the results worse. I have verified that the data is correct. After the axis-angle loss is added, the predicted hand often flips between facing forward and backward.
import torch
import torch.nn as nn
class ManoLoss:
    """Weighted sum of MSE losses over pose, shape, 3D joints and 2D keypoints.

    Each term is an ``nn.MSELoss`` between a prediction and its target,
    multiplied by the matching ``lambda_*`` weight.  A weight of ``0`` (or any
    falsy value) disables that term entirely.
    """

    def __init__(
        self,
        lambda_pose=100.0,
        lambda_shape=100.0,
        lambda_joint3d=1.0,
        lambda_kp2d=1.0,
    ):
        """Store the per-term weights and build one MSE criterion per term.

        Args:
            lambda_pose: weight of the pose term.
            lambda_shape: weight of the shape term.
            lambda_joint3d: weight of the root-relative 3D joint term.
            lambda_kp2d: weight of the 2D keypoint term.
        """
        self.lambda_pose = lambda_pose
        self.lambda_shape = lambda_shape
        self.lambda_joint3d = lambda_joint3d
        self.lambda_kp2d = lambda_kp2d
        # nn.MSELoss holds no parameters or buffers, so it needs no device
        # placement; it runs on whatever device its inputs live on.
        self.criterion_pose = nn.MSELoss()
        self.criterion_shape = nn.MSELoss()
        self.criterion_joint3d = nn.MSELoss()
        self.criterion_kp2d = nn.MSELoss()

    def compute_loss(self, preds, targs, infos):
        """Compute the total loss over every prediction stage in ``preds``.

        Args:
            preds: iterable of dicts, each with ``'pose'``, ``'shape'``,
                ``'kp2d'`` and ``'joint'`` tensors.
            targs: dict with ``'pose'``, ``'shape'``, ``'kp2d'``, ``'joint'``
                and ``'flag_3d'`` tensors.
            infos: dict providing ``'root_id'``, the index of the joint that
                3D joints are made relative to.

        Returns:
            A tuple ``(total_loss, mano_losses, batch_3d_size)`` where
            ``total_loss`` is a tensor of shape ``(1,)`` holding the sum of
            all enabled terms over all stages, ``mano_losses`` maps
            ``'<term>_<stage index>'`` (plus ``'total'``) to loss tensors, and
            ``batch_3d_size`` is ``targs['flag_3d'].sum()``.
        """
        root_id = infos['root_id']
        # NOTE(review): flag_3d is only counted here; no loss term is masked
        # by it. If samples without 3D annotation carry placeholder
        # pose/shape/joint targets they are still penalised -- confirm
        # against the data loader.
        batch_3d_size = targs['flag_3d'].sum()

        # Cast every target to float32 so it matches float32 predictions.
        gt_pose = targs['pose'].float()
        gt_shape = targs['shape'].float()
        gt_kp2d = targs['kp2d'].float()
        # Presumably converts metres to millimetres -- TODO confirm units.
        gt_joint3d = targs['joint'].float() * 1000.0
        gt_joint3d = gt_joint3d - gt_joint3d[:, root_id:root_id + 1, :]

        # Accumulate on the targets' device instead of hard-coding CUDA, and
        # initialise once so that every stage contributes to the total.
        total_loss = torch.zeros(1, device=gt_kp2d.device)
        mano_losses = {}

        for idx, pred in enumerate(preds):
            pred_joint3d = pred['joint']
            pred_joint3d = pred_joint3d - pred_joint3d[:, root_id:root_id + 1, :]

            terms = (
                ('pose', self.lambda_pose, self.criterion_pose,
                 pred['pose'], gt_pose),
                ('shape', self.lambda_shape, self.criterion_shape,
                 pred['shape'], gt_shape),
                ('joint3d', self.lambda_joint3d, self.criterion_joint3d,
                 pred_joint3d, gt_joint3d),
                ('kp2d', self.lambda_kp2d, self.criterion_kp2d,
                 pred['kp2d'], gt_kp2d),
            )
            for name, weight, criterion, estimate, target in terms:
                if not weight:
                    continue
                term_loss = criterion(estimate, target) * weight
                mano_losses['%s_%d' % (name, idx)] = term_loss
                total_loss = total_loss + term_loss

        mano_losses["total"] = total_loss
        return total_loss, mano_losses, batch_3d_size
loginfo: (1000/1018) d: 0.03s | b: 0.31s | s: 72.8770745 | p: 66.2028805 | j: 255.2186516 | k: 137.7673337 | t: 406.3150635 | loginfo: (1001/1018) d: 0.03s | b: 0.31s | s: 72.8693185 | p: 66.2040951 | j: 255.2534850 | k: 137.7815070 | t: 574.5736694 | loginfo: (1002/1018) d: 0.03s | b: 0.31s | s: 72.8742214 | p: 66.2062452 | j: 255.1808178 | k: 137.7479424 | t: 432.7313232 | loginfo: (1003/1018) d: 0.03s | b: 0.31s | s: 72.8717776 | p: 66.2096950 | j: 255.2022860 | k: 137.7690050 | t: 575.6766357 | loginfo: (1004/1018) d: 0.03s | b: 0.31s | s: 72.8674182 | p: 66.2144529 | j: 255.2273901 | k: 137.7747007 | t: 563.3758545 | loginfo: (1005/1018) d: 0.03s | b: 0.31s | s: 72.8567293 | p: 66.2034177 | j: 255.2001193 | k: 137.7657703 | t: 473.8689270 | loginfo: (1006/1018) d: 0.03s | b: 0.31s | s: 72.8619979 | p: 66.2114522 | j: 255.1318979 | k: 137.7335525 | t: 444.3671875 | loginfo: (1007/1018) d: 0.03s | b: 0.31s | s: 72.8579450 | p: 66.2035864 | j: 255.2208746 | k: 137.7956344 | t: 672.0527344 | loginfo: (1008/1018) d: 0.03s | b: 0.31s | s: 72.8569219 | p: 66.2097032 | j: 255.2702999 | k: 137.8076349 | t: 599.1296997 | loginfo: (1009/1018) d: 0.03s | b: 0.31s | s: 72.8681490 | p: 66.2060013 | j: 255.2751612 | k: 137.7991216 | t: 536.0526733 | loginfo: (1010/1018) d: 0.03s | b: 0.31s | s: 72.8743189 | p: 66.2180742 | j: 255.2236392 | k: 137.7840679 | t: 483.3320618 | loginfo: (1011/1018) d: 0.03s | b: 0.31s | s: 72.8830146 | p: 66.2242202 | j: 255.2460789 | k: 137.8080345 | t: 594.0219727 | loginfo: (1012/1018) d: 0.03s | b: 0.31s | s: 72.8847830 | p: 66.2241230 | j: 255.2646785 | k: 137.7912989 | t: 535.7387695 | loginfo: (1013/1018) d: 0.03s | b: 0.31s | s: 72.8771316 | p: 66.2204524 | j: 255.2239141 | k: 137.7670858 | t: 454.8735657 | loginfo: (1014/1018) d: 0.03s | b: 0.31s | s: 72.8792146 | p: 66.2185865 | j: 255.1972219 | k: 137.7473550 | t: 485.2358704 | loginfo: (1015/1018) d: 0.03s | b: 0.31s | s: 72.8804192 | p: 66.2231351 | j: 255.2193977 | k: 137.7601394 | 
t: 573.3665161 | loginfo: (1016/1018) d: 0.03s | b: 0.31s | s: 72.8631759 | p: 66.2301331 | j: 255.1524250 | k: 137.7305885 | t: 423.6058655 | loginfo: (1017/1018) d: 0.03s | b: 0.31s | s: 72.8613670 | p: 66.2278808 | j: 255.2338836 | k: 137.7320333 | t: 612.1587524 | loginfo: (1018/1018) d: 0.03s | b: 0.31s | s: 72.8622015 | p: 66.2263912 | j: 255.2558645 | k: 137.7465916 | t: 605.0795898 |
Could I get your reproduced code to study? Thanks a lot! @youngstu
你好 能分享下你复现的训练代码吗 感谢~