twm
DynamicsLoss does not use `z_pred_loss` while training
When `compute_dynamics_loss` is called in `world_model.py`, it receives an argument called `preds`, which is returned from `dyn_model.predict()`. For `z` prediction, `preds` has the keys `z_dist` and `z_hat_probs`; there is no key `z`.
When computing the dynamics loss, specifically the loss for `z`, the code checks whether `z` is one of the keys of `preds`. It never is, because the keys are `z_dist` and `z_hat_probs`. As a result, `z_pred_loss` is never computed and never contributes to training.
A fix should be made here by checking for a key that `predict()` actually returns, something like:
`if 'z_dist' in preds: ...`
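To illustrate the failure mode, here is a minimal, self-contained sketch. It is not the repository's code. The helper `z_pred_loss` (a cross-entropy over `z_hat_probs`), the `targets` dict, and the plain-list data layout are all hypothetical stand-ins. Only the key names `z`, `z_dist` and `z_hat_probs` come from the issue. The sketch shows that the original key check silently skips the loss, while checking for `z_dist` includes it.

```python
import math

def z_pred_loss(z_hat_probs, z_target):
    """Hypothetical z loss: mean cross-entropy between predicted
    categorical probabilities and integer target indices."""
    eps = 1e-8  # avoid log(0)
    total = sum(-math.log(probs[t] + eps) for probs, t in zip(z_hat_probs, z_target))
    return total / len(z_target)

def compute_dynamics_loss_buggy(preds, targets):
    losses = {}
    # Mirrors the reported bug: predict() never returns a 'z' key,
    # so this branch is never taken and z_pred_loss is dropped.
    if 'z' in preds:
        losses['z_pred_loss'] = z_pred_loss(preds['z_hat_probs'], targets['z'])
    return losses

def compute_dynamics_loss_fixed(preds, targets):
    losses = {}
    # Check for a key that predict() actually returns.
    if 'z_dist' in preds:
        losses['z_pred_loss'] = z_pred_loss(preds['z_hat_probs'], targets['z'])
    return losses

# Hypothetical predictions shaped like the keys described in the issue.
preds = {'z_dist': None, 'z_hat_probs': [[0.25, 0.75], [0.5, 0.5]]}
targets = {'z': [1, 0]}

buggy = compute_dynamics_loss_buggy(preds, targets)
fixed = compute_dynamics_loss_fixed(preds, targets)
print(buggy)   # {}
print(fixed)   # contains 'z_pred_loss'
```

In the real code, the loss is most likely computed from the `z_dist` distribution object rather than from plain lists. The point of the sketch is only that the membership test must name a key that `preds` actually contains.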