
Questions about the ego regularization

Open AlonZhao opened this issue 10 months ago • 0 comments

Hello Zhiyu, what is the meaning of the ego regularization that is output by ego_traj_decoder:

    agents_trajecotries = torch.stack(agents_trajecotries, dim=2)
    scores, weights = self.scorer(ego_traj_inputs, encoding[:, 0], agents_trajecotries, current_states, timesteps)
    ego_traj_regularization = self.ego_traj_decoder(encoding[:, 0])
    return agents_trajecotries, scores, ego_traj_regularization, weights
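
For context, my current reading is that ego_traj_decoder is a plain regression head applied to the ego token of the encoding (encoding[:, 0]) and produces a single trajectory. A minimal sketch of that reading, where the class name, hidden size, horizon, and output dimension are my own assumptions and not the repo's code:

    import torch
    import torch.nn as nn

    class EgoTrajDecoderSketch(nn.Module):
        # Hypothetical regression head: ego encoding -> one ego trajectory
        def __init__(self, enc_dim=256, timesteps=80, out_dim=3):
            super().__init__()
            self.timesteps = timesteps
            self.out_dim = out_dim
            self.mlp = nn.Sequential(
                nn.Linear(enc_dim, 256),
                nn.ReLU(),
                nn.Linear(256, timesteps * out_dim),
            )

        def forward(self, ego_encoding):
            # ego_encoding: (batch, enc_dim), e.g. encoding[:, 0]
            # returns: (batch, timesteps, out_dim), e.g. (x, y, heading) per step
            out = self.mlp(ego_encoding)
            return out.view(-1, self.timesteps, self.out_dim)

Is that roughly what the real decoder does?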

and in the calc_loss function it is regressed toward ego_gt:

    def calc_loss(neighbors, ego, ego_regularization, scores, weights, ego_gt, neighbors_gt, neighbors_valid):
        mask = torch.ne(ego.sum(-1), 0)
        neighbors = neighbors[:, 0] * neighbors_valid
        cmp_loss = F.smooth_l1_loss(neighbors, neighbors_gt, reduction='none') # beta == 1.0 ?
        cmp_loss = cmp_loss * mask[:, 0, None, :, None]
        cmp_loss = cmp_loss.sum() / mask[:, 0].sum()
        regularization_loss = F.smooth_l1_loss(ego_regularization.view(ego_gt.shape), ego_gt, reduction='none')
        regularization_loss = regularization_loss * mask[:, 0, :, None]
        regularization_loss = regularization_loss.sum() / mask[:, 0].sum()
        label = torch.zeros(scores.shape[0], dtype=torch.long).to(scores.device) # why zeros?
        irl_loss = F.cross_entropy(scores, label)
        weights_regularization = torch.square(weights).mean()
        loss = cmp_loss + irl_loss + 0.1 * regularization_loss + 0.01 * weights_regularization
        return loss
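
Also, to make my question about the zero labels concrete: my assumption (it is not stated in the code, so please correct me) is that index 0 of scores corresponds to the candidate branch aligned with the ground-truth ego trajectory, so the cross-entropy with all-zero labels simply trains the scorer to rank that branch highest. A small toy example of what I mean, with made-up numbers:

    import torch
    import torch.nn.functional as F

    # Hypothetical scores: 2 samples, 4 candidate branches each.
    # Assumption: column 0 is the GT-aligned candidate.
    scores = torch.tensor([[2.0, 0.5, -1.0, 0.3],
                           [0.1, 1.5,  0.2, 0.0]])
    label = torch.zeros(scores.shape[0], dtype=torch.long)

    # Cross-entropy with all-zero labels maximizes the softmax probability of
    # column 0, i.e. it pushes the scorer to rank the assumed GT branch first.
    irl_loss = F.cross_entropy(scores, label)
    print(irl_loss)  # smaller when scores[:, 0] dominates each row

Is that the intended convention, or do the zero labels mean something else here?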

Is it the ego planned trajectory decoded from the model? But I think the ego planned trajectory is the "second_stage trajectory" from the tree, so what is the meaning of ego_traj_regularization from ego_traj_decoder? Thanks.

AlonZhao · Feb 25 '25