DTPP
Questions about the ego regularization
Hello Zhiyu, what is the meaning of the ego regularization, i.e. the output of `ego_traj_decoder`? Here is the relevant part of the forward pass:
```python
agents_trajecotries = torch.stack(agents_trajecotries, dim=2)
scores, weights = self.scorer(ego_traj_inputs, encoding[:, 0], agents_trajecotries, current_states, timesteps)
ego_traj_regularization = self.ego_traj_decoder(encoding[:, 0])
return agents_trajecotries, scores, ego_traj_regularization, weights
```
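One way to picture a head like `ego_traj_decoder` is a small MLP that maps the ego token of the encoder output (`encoding[:, 0]`) to a single flattened future trajectory, which `calc_loss` later reshapes with `.view(ego_gt.shape)`. The sketch below is hypothetical: the class name `EgoTrajDecoder`, the hidden size, the horizon `future_steps=80` and the state dimension `state_dim=3` are assumptions for illustration, not taken from the DTPP code.

```python
import torch
import torch.nn as nn


class EgoTrajDecoder(nn.Module):
    """Hypothetical sketch: decode one ego trajectory from the ego encoding.

    Assumed shapes: input (B, dim), output (B, future_steps * state_dim),
    flattened so the caller can reshape it with .view(B, future_steps, state_dim).
    """

    def __init__(self, dim: int = 256, future_steps: int = 80, state_dim: int = 3):
        super().__init__()
        self.future_steps = future_steps
        self.state_dim = state_dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ELU(),
            nn.Linear(dim, future_steps * state_dim),
        )

    def forward(self, ego_encoding: torch.Tensor) -> torch.Tensor:
        # ego_encoding plays the role of encoding[:, 0] in the snippet above
        return self.mlp(ego_encoding)


if __name__ == "__main__":
    batch, dim = 4, 256
    encoding = torch.randn(batch, 10, dim)  # (B, tokens, dim); token 0 assumed to be the ego
    decoder = EgoTrajDecoder(dim=dim)
    ego_traj_regularization = decoder(encoding[:, 0])
    ego_gt = torch.zeros(batch, 80, 3)  # dummy ground truth with the assumed shape
    print(ego_traj_regularization.view(ego_gt.shape).shape)  # torch.Size([4, 80, 3])
```

In the quoted `calc_loss`, this output only enters the objective through `regularization_loss` (weighted by 0.1), i.e. as a direct imitation target against `ego_gt`, not through the scorer that ranks the tree candidates.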
And in the `calc_loss` function it is regressed against `ego_gt`:
```python
import torch
import torch.nn.functional as F

def calc_loss(neighbors, ego, ego_regularization, scores, weights, ego_gt, neighbors_gt, neighbors_valid):
    mask = torch.ne(ego.sum(-1), 0)
    neighbors = neighbors[:, 0] * neighbors_valid
    cmp_loss = F.smooth_l1_loss(neighbors, neighbors_gt, reduction='none')  # beta == 1.0 ?
    cmp_loss = cmp_loss * mask[:, 0, None, :, None]
    cmp_loss = cmp_loss.sum() / mask[:, 0].sum()

    regularization_loss = F.smooth_l1_loss(ego_regularization.view(ego_gt.shape), ego_gt, reduction='none')
    regularization_loss = regularization_loss * mask[:, 0, :, None]
    regularization_loss = regularization_loss.sum() / mask[:, 0].sum()

    label = torch.zeros(scores.shape[0], dtype=torch.long).to(scores.device)  # why zeros?
    irl_loss = F.cross_entropy(scores, label)

    weights_regularization = torch.square(weights).mean()

    loss = cmp_loss + irl_loss + 0.1 * regularization_loss + 0.01 * weights_regularization
    return loss
```
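Two of the inline questions in `calc_loss` can be checked against PyTorch's documented behavior. `F.smooth_l1_loss` uses `beta=1.0` by default, and `F.cross_entropy(scores, label)` with all-zero labels is the mean of `-log_softmax(scores)[:, 0]`, so it trains the scorer to rank candidate index 0 highest (the masks in `calc_loss` likewise index `mask[:, 0]`). Why candidate 0 is the target, for example because the data pipeline places the candidate closest to the expert at index 0, is an assumption here, not something the quoted code shows. The sketch below uses random tensors with made-up shapes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Dummy scores for B=4 samples, each with N=6 candidate ego trajectories (shapes assumed).
scores = torch.randn(4, 6)
label = torch.zeros(scores.shape[0], dtype=torch.long)

# Cross-entropy with all-zero labels equals the mean negative log-probability of candidate 0.
irl_loss = F.cross_entropy(scores, label)
manual = -F.log_softmax(scores, dim=-1)[:, 0].mean()
print(torch.allclose(irl_loss, manual))  # True

# smooth_l1_loss defaults to beta=1.0:
#   0.5 * x**2 / beta  if |x| < beta,  else  |x| - 0.5 * beta
pred = torch.tensor([0.2, 3.0])
target = torch.zeros(2)
default = F.smooth_l1_loss(pred, target, reduction='none')
explicit = F.smooth_l1_loss(pred, target, reduction='none', beta=1.0)
print(default)  # tensor([0.0200, 2.5000])
print(torch.equal(default, explicit))  # True
```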
Is it the ego planned trajectory decoded by the model? I thought the ego planned trajectory was the second-stage trajectory selected from the tree. So what does `ego_traj_regularization` from `ego_traj_decoder` represent? Thanks.