c-swm
Discrepancy between the loss in the paper and the GitHub implementation
According to the paper, the negative term of the contrastive loss is the distance between a negative state z̃_t (randomly sampled from the embeddings at time step t) and the ground-truth next state z_{t+1}.
However, in line 113 of modules.py the energy is called with no_trans=True, so the code effectively computes the distance between the randomly sampled state z̃_t and z_t (rather than z_{t+1}).
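Written out side by side, with d the distance used by the energy function (squared Euclidean in the paper), T the transition model, and γ the hinge margin (`self.hinge` in the code), the positive term, the two negative terms being compared, and the hinge loss are:

$$
\begin{aligned}
H &= d\big(z_t + T(z_t, a_t),\ z_{t+1}\big) \\
\tilde{H}_{\text{paper}} &= d\big(\tilde{z}_t,\ z_{t+1}\big) \\
\tilde{H}_{\text{code}} &= d\big(\tilde{z}_t,\ z_t\big) \\
\mathcal{L} &= H + \max\big(0,\ \gamma - \tilde{H}\big)
\end{aligned}
$$

Since d is symmetric, the argument order does not matter; what differs is whether the negative is compared against z_t or z_{t+1}.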
```python
def contrastive_loss(self, obs, action, next_obs):
    # Encode current and next observations into object-based states
    objs = self.obj_extractor(obs)
    next_objs = self.obj_extractor(next_obs)
    state = self.obj_encoder(objs)
    next_state = self.obj_encoder(next_objs)

    # Sample negative state across episodes at random
    batch_size = state.size(0)
    perm = np.random.permutation(batch_size)
    neg_state = state[perm]

    # Positive term (paper's H): compares z_t + T(z_t, a_t) with z_{t+1}
    self.pos_loss = self.energy(state, action, next_state)
    zeros = torch.zeros_like(self.pos_loss)
    self.pos_loss = self.pos_loss.mean()

    # Negative term: with no_trans=True this compares z_t with the
    # negative z~_t, not z_{t+1} as in the paper
    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()

    loss = self.pos_loss + self.neg_loss
    return loss
```

Therefore, I think `next_state` rather than `state` should be passed as the first argument of the energy function in the negative term. Please let me know if I am misunderstanding anything.
Thanks.
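A minimal NumPy sketch of the proposed change, assuming a hypothetical `energy` helper (squared Euclidean distance summed over object slots and features) in place of the model's own energy function, and a fixed permutation instead of `np.random.permutation` so the result is reproducible. `use_next_state=True` follows the paper; `use_next_state=False` mirrors the quoted code:

```python
import numpy as np


def energy(a, b):
    """Hypothetical stand-in energy: squared L2 distance per sample,
    summed over object slots and feature dims. Shapes: (batch, objects, dim)."""
    return ((a - b) ** 2).sum(axis=(1, 2))


def negative_loss(state, next_state, perm, hinge=1.0, use_next_state=True):
    """Hinge term max(0, hinge - d(z~_t, target)), averaged over the batch.

    use_next_state=True  -> target is z_{t+1} (paper, proposed fix)
    use_next_state=False -> target is z_t     (quoted code)
    """
    neg_state = state[perm]  # negatives: states shuffled across the batch
    target = next_state if use_next_state else state
    return np.maximum(0.0, hinge - energy(target, neg_state)).mean()


# Toy batch: 2 samples, 1 object slot, 1 feature
state = np.array([[[0.0]], [[1.0]]])
next_state = np.array([[[0.2]], [[0.9]]])
perm = np.array([1, 0])  # swap the two samples to form negatives

print(negative_loss(state, next_state, perm, use_next_state=False))  # code: 0.0
print(negative_loss(state, next_state, perm, use_next_state=True))   # paper: ~0.275
```

In this toy batch each shuffled state is exactly at the margin from the current state, so the code's version contributes no gradient, while comparing against z_{t+1} still yields a nonzero hinge term.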
Good catch. I wonder if this would fix the issue described in Figure 4b.
Hey, do you mean this part of the Figure 4 caption: "One trajectory (in the center) strongly deviates from typical trajectories seen during training, and the model struggles to predict the correct transition."?
Yes, exactly.
I also wonder why `transition_model` is not applied to the negative state.
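To make the question concrete: the variant would pass the negative through the transition model before comparing it with z_{t+1}. A minimal NumPy sketch, assuming a hypothetical `toy_transition` (each object shifted by the action vector) in place of the repository's `transition_model`, a hypothetical squared-distance `energy`, and that each negative is paired with the positive sample's action a_t:

```python
import numpy as np


def energy(a, b):
    """Hypothetical stand-in energy: squared L2 distance per sample,
    summed over object slots and feature dims. Shapes: (batch, objects, dim)."""
    return ((a - b) ** 2).sum(axis=(1, 2))


def toy_transition(state, action):
    """Hypothetical transition model returning the predicted update:
    every object moves by the action vector.
    Shapes: state (batch, objects, dim), action (batch, dim)."""
    return np.broadcast_to(action[:, None, :], state.shape)


def negative_energy(state, action, next_state, perm, apply_transition):
    neg_state = state[perm]  # negatives: states shuffled across the batch
    if apply_transition:
        # Variant asked about: z~_t + T(z~_t, a_t) compared with z_{t+1}
        pred = neg_state + toy_transition(neg_state, action)
    else:
        # Paper's formulation: z~_t compared directly with z_{t+1}
        pred = neg_state
    return energy(pred, next_state)


state = np.array([[[0.0]], [[1.0]]])
action = np.array([[0.5], [0.5]])
next_state = state + 0.5  # toy dynamics consistent with toy_transition
perm = np.array([1, 0])

print(negative_energy(state, action, next_state, perm, apply_transition=False))
print(negative_energy(state, action, next_state, perm, apply_transition=True))
```

The sketch only shows what each option computes; it does not indicate which one trains better.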