DynamicsLoss does not use `z_pred_loss` while training

Open famishedrover opened this issue 1 year ago • 0 comments

When `compute_dynamics_loss` is called in `world_model.py`, it receives an argument `preds`, which is returned by `dyn_model.predict()`. For z prediction, `preds` contains the keys `z_dist` and `z_hat_probs`; there is no key `z`.

When computing the dynamics loss, specifically the loss for z, the code checks whether `z` is among the keys of `preds`. It never is, because the keys are `z_dist` and `z_hat_probs`. As a result, `z_pred_loss` is never computed or used for training.

A fix should be made here, something like:

if 'z_dist' in preds: ...
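
To make the failure mode concrete, here is a minimal, self-contained sketch of the key mismatch. It is not the repository's code: the two helper functions and the placeholder values in `preds` are hypothetical, and only the key names `z_dist` and `z_hat_probs` and the loss name `z_pred_loss` are taken from the description above.

```python
# Hypothetical stand-in for the dict returned by dyn_model.predict().
# Only the key names come from the issue; the values are placeholders.
preds = {"z_dist": "<distribution>", "z_hat_probs": "<probabilities>"}


def loss_terms_with_old_check(preds):
    """Mimic the check described above: look for a key named 'z'."""
    terms = []
    if "z" in preds:  # never true, since predict() returns no 'z' key
        terms.append("z_pred_loss")
    return terms


def loss_terms_with_fixed_check(preds):
    """Mimic the proposed fix: look for 'z_dist' instead."""
    terms = []
    if "z_dist" in preds:
        terms.append("z_pred_loss")
    return terms


old_terms = loss_terms_with_old_check(preds)
new_terms = loss_terms_with_fixed_check(preds)
print("old check:", old_terms)
print("fixed check:", new_terms)
```

With the old check the list of loss terms stays empty, so the z loss is silently skipped; with the `z_dist` check the term is included.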

famishedrover, Jun 29 '23 15:06