recurrent-interface-network-pytorch
Inconsistent interpretation of model output between self-conditioning step and prediction step
I think there is a bug in the interface self-conditioning in rin_pytorch.py.
The model output is interpreted differently during the self-conditioning stage compared to the prediction stage.
Currently we have (pseudocode):

```python
self_cond = target_to_x0(model_output)  # interpret the model output as the target (x0, eps, or v) and convert it to an x0 estimate
pred = self.model(..., self_cond, ...)  # self-condition on the x0 estimate, then predict x0, eps, or v
target = x0_to_target_modification(x0)  # convert x0 into the target (x0, eps, or v)
loss = F.mse_loss(pred, target)
```
In the current implementation, the interface prediction is converted to an x0 estimate during self-conditioning, but supervised as the target (x0, eps, or v) at the prediction step.
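For reference, the three parameterizations are interconvertible under the forward process `x_t = alpha * x0 + sigma * eps` (with `alpha**2 + sigma**2 == 1` and `v = alpha * eps - sigma * x0`). A minimal sketch of the two conversion directions; the helper names are illustrative, not from the repo:

```python
import torch

def pred_to_x0(pred, noised_img, alpha, sigma, objective):
    # Convert a model prediction (in x0, eps, or v space) into an x0 estimate,
    # assuming noised_img = alpha * x0 + sigma * eps and v = alpha * eps - sigma * x0.
    if objective == 'x0':
        return pred
    if objective == 'eps':
        return (noised_img - sigma * pred) / alpha
    if objective == 'v':
        return alpha * noised_img - sigma * pred
    raise ValueError(f'unknown objective {objective}')

def x0_to_pred(x0, noised_img, alpha, sigma, objective):
    # Inverse direction: express a clean-image estimate in x0, eps, or v space.
    if objective == 'x0':
        return x0
    if objective == 'eps':
        return (noised_img - alpha * x0) / sigma
    if objective == 'v':
        return (alpha * noised_img - x0) / sigma
    raise ValueError(f'unknown objective {objective}')
```

The two functions are exact inverses of each other at fixed `noised_img`, `alpha`, `sigma`, which is why the choice of space for the self-conditioning signal is purely a convention, and only needs to be the same convention in both passes.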
I see two ways that we could do interface self-conditioning that would be consistent.
We could either:
- make the model always predict x0. Then we would have:

```python
self_cond = model_output                # self_cond is a prediction for x0
pred = self.model(..., self_cond, ...)  # self-condition on the x0 prediction and predict x0
pred = x0_to_target_modification(pred)  # convert the x0 prediction into a prediction for x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)
```
or
- make the model always predict what it is intended to predict (x0, eps, or v). Then we would have:

```python
self_cond = model_output                # self_cond is a prediction for the target (x0, eps, or v)
pred = self.model(..., self_cond, ...)  # self-condition on the prediction for x0, eps, or v and predict x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)
```
In contrast to the current implementation, in my two proposals, the interpretation of the interface prediction is the same between the self-conditioning step and the prediction step. Would you agree that there is inconsistency here and that either of these proposals solves it?
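For concreteness, the two proposals can be sketched as self-contained training steps. This is illustrative code under assumed names and signatures (`model(x, t, self_cond)`, a standalone `x0_to_target` helper), not the repo's actual API:

```python
import torch
import torch.nn.functional as F

def x0_to_target(x0, noised_img, alpha, sigma, objective):
    # Convert a clean-image estimate into the chosen target space,
    # assuming noised_img = alpha * x0 + sigma * eps and alpha**2 + sigma**2 == 1.
    if objective == 'x0':
        return x0
    if objective == 'eps':
        return (noised_img - alpha * x0) / sigma
    return (alpha * noised_img - x0) / sigma  # 'v'

def step_always_predict_x0(model, noised_img, times, x0, alpha, sigma, objective):
    # Proposal 1: the network output is always read as an x0 estimate;
    # conversion to the target space happens only at the loss.
    with torch.no_grad():
        self_cond = model(noised_img, times, None).clamp(-1., 1.)  # x0 space
    pred_x0 = model(noised_img, times, self_cond)                  # x0 space
    pred = x0_to_target(pred_x0, noised_img, alpha, sigma, objective)
    target = x0_to_target(x0, noised_img, alpha, sigma, objective)
    return F.mse_loss(pred, target)

def step_always_predict_target(model, noised_img, times, x0, alpha, sigma, objective):
    # Proposal 2: the network output is always read in target space,
    # including the quantity fed back for self-conditioning.
    with torch.no_grad():
        self_cond = model(noised_img, times, None)  # target space
    pred = model(noised_img, times, self_cond)      # target space
    target = x0_to_target(x0, noised_img, alpha, sigma, objective)
    return F.mse_loss(pred, target)
```

In both sketches the space of `self_cond` matches the space the same tensor occupies in the other pass, which is the consistency property being proposed.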
Here is the current code:

```python
if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        if self.objective == 'x0':
            self_cond = model_output
        elif self.objective == 'eps':
            self_cond = safe_div(noised_img - sigma * model_output, alpha)
        elif self.objective == 'v':
            self_cond = alpha * noised_img - sigma * model_output

        self_cond.clamp_(-1., 1.)
        self_cond = self_cond.detach()

# predict and take gradient step

pred = self.model(noised_img, times, self_cond, self_latents)
...
loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
```
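To make the space mismatch concrete, here is a small numeric check (the values are illustrative): under the `'eps'` objective, the quantity the current code feeds back as `self_cond` lands in x0 space, not in the space the network's output is supervised in.

```python
import torch

alpha, sigma = torch.tensor(0.6), torch.tensor(0.8)  # alpha**2 + sigma**2 == 1
x0  = torch.tensor([0.5])
eps = torch.tensor([-1.0])
noised_img = alpha * x0 + sigma * eps

model_output = eps  # suppose the model predicts eps exactly (objective == 'eps')
self_cond = (noised_img - sigma * model_output) / alpha  # the conversion in the current code

# self_cond recovers x0, so the feedback lives in x0 space,
# while the supervised output lives in eps space.
assert torch.allclose(self_cond, x0, atol=1e-6)
assert not torch.allclose(self_cond, model_output)
```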
@jsternabsci i have only seen self conditioning done with the predicted x0 (correct me if i'm wrong)
there's nothing inconsistent as during inference, self conditioning is also done with the predicted x0
however, i get what you are saying. i could offer both options, if you are running experiments to see which way is better?
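For context on the inference-time remark above, a sketch of the usual sampling loop with x0 self-conditioning, where the feedback fed to the model at each step is the previous step's x0 estimate. Everything here is illustrative (assumed `model(img, self_cond)` signature, a simple DDIM-style update, a hand-rolled `(alpha, sigma)` schedule), not the repo's sampler:

```python
import torch

def sample(model, shape, schedule):
    # schedule: list of (alpha, sigma) pairs from high noise to low noise;
    # the model's output is read as an x0 estimate throughout.
    img = torch.randn(shape)
    self_cond = None
    for (alpha, sigma), (alpha_next, sigma_next) in zip(schedule[:-1], schedule[1:]):
        x0_hat = model(img, self_cond).clamp(-1., 1.)     # network output read as x0
        eps_hat = (img - alpha * x0_hat) / sigma          # implied noise estimate
        img = alpha_next * x0_hat + sigma_next * eps_hat  # DDIM-style step
        self_cond = x0_hat                                # feed back the x0 estimate
    return img
```

The point of the sketch is that, at inference, the self-conditioning input is an x0 estimate by construction, so whichever training convention is adopted has to produce the same interpretation here.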