c-swm

Discrepancy between the loss described in the paper and the code on GitHub

Open • bhattg opened this issue 4 years ago • 4 comments

According to the paper, the negative term of the contrastive loss is the distance between a negative state $\tilde{z}_t$ (randomly sampled from the embeddings at timestep $t$) and the ground-truth next state $z_{t+1}$.
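Written out, the loss described in the paper is roughly the following. This is a sketch: $d$ stands for the distance computed by `self.energy`, $T$ for the transition model, $\gamma$ for the margin `self.hinge`, and the per-object averaging is left implicit.

```latex
\mathcal{L} = \underbrace{d\big(z_t + T(z_t, a_t),\ z_{t+1}\big)}_{\text{positive term}}
  + \underbrace{\max\big(0,\ \gamma - d(\tilde{z}_t,\ z_{t+1})\big)}_{\text{negative term}}
```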

However, according to line 113 of modules.py, with `no_trans=True` the code effectively computes the distance between the randomly sampled negative $\tilde{z}_t$ and $z_t$ (rather than $z_{t+1}$).
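A minimal NumPy sketch of the two negative terms, assuming `self.energy` is a squared L2 distance summed over the embedding dimension and averaged over objects (any scaling constant the real implementation may apply is omitted). The `energy` function and its `transition` argument are hypothetical stand-ins that mirror the call `self.energy(state, action, neg_state, no_trans=True)`; embeddings have shape `[batch, num_objects, dim]`.

```python
import numpy as np

def energy(state, action, next_state, no_trans=False, transition=None):
    """Hypothetical stand-in for self.energy.

    Squared L2 distance summed over the embedding dimension and averaged
    over objects; returns one value per batch element.
    """
    if no_trans:
        # Compare the two embeddings directly (no transition model).
        diff = state - next_state
    else:
        # Predict z_{t+1} as state + T(state, action), then compare.
        diff = state + transition(state, action) - next_state
    return (diff ** 2).sum(axis=2).mean(axis=1)

rng = np.random.default_rng(0)
batch, num_objects, dim = 4, 3, 2
state = rng.normal(size=(batch, num_objects, dim))       # z_t
next_state = rng.normal(size=(batch, num_objects, dim))  # z_{t+1}
neg_state = state[rng.permutation(batch)]                # negatives z~_t

# Negative term as in the quoted modules.py code: d(z_t, z~_t)
neg_energy_code = energy(state, None, neg_state, no_trans=True)
# Negative term as described in the paper: d(z_{t+1}, z~_t)
neg_energy_paper = energy(next_state, None, neg_state, no_trans=True)
```

Because the action is ignored when `no_trans=True`, the only thing that distinguishes the two variants is which embedding the negatives are compared against.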

```python
def contrastive_loss(self, obs, action, next_obs):

    objs = self.obj_extractor(obs)
    next_objs = self.obj_extractor(next_obs)

    state = self.obj_encoder(objs)
    next_state = self.obj_encoder(next_objs)

    # Sample negative state across episodes at random
    batch_size = state.size(0)
    perm = np.random.permutation(batch_size)
    neg_state = state[perm]

    self.pos_loss = self.energy(state, action, next_state)
    zeros = torch.zeros_like(self.pos_loss)

    self.pos_loss = self.pos_loss.mean()
    # Negative term: neg_state is compared against state (z_t), not next_state (z_{t+1})
    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()

    loss = self.pos_loss + self.neg_loss

    return loss
```

Thus, I think `next_state`, rather than `state`, should be passed as the first argument of the energy function in the negative term. Please let me know if I am misreading anything.
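A minimal NumPy sketch of the hinge loss with the change proposed above, assuming `self.energy` is a squared L2 distance averaged over objects. The names `energy`, `contrastive_loss`, `transition` and the `neg_against_next` flag are hypothetical and for illustration only, not the repository's API; `neg_against_next=False` reproduces the negative term as quoted.

```python
import numpy as np

def energy(state, action, next_state, no_trans=False, transition=None):
    """Hypothetical stand-in for self.energy (squared L2, averaged over objects)."""
    if no_trans:
        diff = state - next_state
    else:
        diff = state + transition(state, action) - next_state
    return (diff ** 2).sum(axis=2).mean(axis=1)

def contrastive_loss(state, action, next_state, transition, hinge=1.0,
                     neg_against_next=True, rng=None):
    """Sketch of the hinge loss on embeddings of shape [batch, objects, dim].

    neg_against_next=True  -> negative term uses d(z_{t+1}, z~_t) (proposed fix)
    neg_against_next=False -> negative term uses d(z_t, z~_t)     (code as quoted)
    """
    rng = np.random.default_rng() if rng is None else rng
    # Negatives z~_t: embeddings at time t shuffled across the batch
    neg_state = state[rng.permutation(state.shape[0])]
    # Positive term: prediction z_t + T(z_t, a_t) versus z_{t+1}
    pos_loss = energy(state, action, next_state, transition=transition).mean()
    # Negative term: hinge on the distance between the anchor and the negatives
    anchor = next_state if neg_against_next else state
    neg_energy = energy(anchor, action, neg_state, no_trans=True)
    neg_loss = np.maximum(0.0, hinge - neg_energy).mean()
    return pos_loss + neg_loss
```

With `neg_against_next=True`, both terms are anchored at $z_{t+1}$: the positive term pulls the prediction toward it, and the negative term penalizes random embeddings whose energy to it is below the margin.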

Thanks.

bhattg • Apr 10 '20 18:04

Good catch. I wonder if this would fix the issue described in Figure 4b.

AugustKarlstedt • Apr 18 '20 05:04

Hey, do you mean this part of the Figure 4 caption: "One trajectory (in the center) strongly deviates from typical trajectories seen during training, and the model struggles to predict the correct transition."?

bhattg • Apr 18 '20 05:04

Yes, exactly.

AugustKarlstedt • Apr 18 '20 06:04

I also wonder why `transition_model` is not applied to the negative state.
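For reference, applying the transition model to the negative would change the negative energy roughly as follows. This is only a sketch of the question, not taken from the paper or the repository; reusing the positive sample's action $a_t$ for $\tilde{z}_t$ is an assumption.

```latex
\tilde{H}_{\text{paper}} = d(\tilde{z}_t,\ z_{t+1})
\qquad \text{vs.} \qquad
\tilde{H}_{\text{trans}} = d\big(\tilde{z}_t + T(\tilde{z}_t, a_t),\ z_{t+1}\big)
```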

BenchengY • Jul 16 '20 06:07