recurrent-interface-network-pytorch

Inconsistent interpretation of model output between self-conditioning step and prediction step

Open jsternabsci opened this issue 1 year ago • 1 comment

I think there is a bug in the interface self-conditioning in rin_pytorch.py.

The model output is interpreted differently during the self-conditioning stage compared to the prediction stage.

Currently we have (pseudocode):

self_cond = x0_to_target_modification(model_output)           # Treat the model prediction as x0 and convert it to x0, eps, or v
pred = self.model(..., self_cond, ...)                        # Self-condition on prediction for x0, eps, or v and predict x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)

In the current implementation, the interface prediction is interpreted as x0 during self-conditioning, but as the target (x0, eps, or v) at the prediction step.
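
For concreteness, here is a minimal sketch of the conversion that the placeholder x0_to_target_modification in the pseudocode above stands for, assuming the usual parameterization noised_img = alpha * x0 + sigma * eps and v = alpha * eps - sigma * x0 (the helper and its signature are mine, not functions in the repo):

def x0_to_target_modification(x0, noised_img, alpha, sigma, objective):
    # recover the noise estimate from the x0 estimate: eps = (noised_img - alpha * x0) / sigma
    eps = (noised_img - alpha * x0) / sigma

    if objective == 'x0':
        return x0
    elif objective == 'eps':
        return eps
    elif objective == 'v':
        # v-parameterization: v = alpha * eps - sigma * x0
        return alpha * eps - sigma * x0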

I see two ways we could do interface self-conditioning consistently.

We could either:

  • make the model always predict x0 (see the sketch after this list). Then we would have
self_cond = model_output                                      # Where self_cond is a prediction for x0
pred = self.model(..., self_cond, ...)                        # Self-condition on x0 prediction and predict x0
pred = x0_to_target_modification(pred)                        # Convert x0 prediction into prediction for x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)

or

  • make the model always predict what it is intended to predict (x0, eps, or v). Then we would have
self_cond = model_output                                      # Where self_cond is a prediction for the target (x0, eps, or v)
pred = self.model(..., self_cond, ...)                        # Self-condition on prediction for x0, eps, or v and predict x0, eps, or v
target = x0_to_target_modification(x0)
loss = F.mse_loss(pred, target)
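
To make the first proposal concrete, here is a rough sketch of how it could be wired into the training step quoted below, reusing the hypothetical x0_to_target_modification helper sketched earlier (an illustration of the idea, not a tested patch):

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        # first pass: the model always predicts x0, so its output can be fed back directly
        x0_estimate, self_latents = self.model(noised_img, times, return_latents = True)
        self_cond = x0_estimate.clamp(-1., 1.).detach()
        self_latents = self_latents.detach()

# second pass: again interpret the output as x0, then convert it for the loss
pred_x0 = self.model(noised_img, times, self_cond, self_latents)
pred = x0_to_target_modification(pred_x0, noised_img, alpha, sigma, self.objective)

# target = x0_to_target_modification(clean_img, noised_img, alpha, sigma, self.objective), as in the proposal
loss = F.mse_loss(pred, target, reduction = 'none')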

In contrast to the current implementation, in my two proposals, the interpretation of the interface prediction is the same between the self-conditioning step and the prediction step. Would you agree that there is inconsistency here and that either of these proposals solves it?

Here is the current code:

if random() < self.train_prob_self_cond:
    with torch.no_grad():
        model_output, self_latents = self.model(noised_img, times, return_latents = True)
        self_latents = self_latents.detach()

        if self.objective == 'x0':
            self_cond = model_output

        elif self.objective == 'eps':
            self_cond = safe_div(noised_img - sigma * model_output, alpha)

        elif self.objective == 'v':
            self_cond = alpha * noised_img - sigma * model_output

        self_cond.clamp_(-1., 1.)
        self_cond = self_cond.detach()

# predict and take gradient step

pred = self.model(noised_img, times, self_cond, self_latents)

...

loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')

jsternabsci commented Jul 28, 2023

@jsternabsci i have only seen self conditioning done with the predicted x0 (correct me if i'm wrong)

there's nothing inconsistent as during inference, self conditioning is also done with the predicted x0
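
for context, a self conditioned sampling step usually looks roughly like this (just a schematic, not the actual sampling code here; model, to_x0, ddim_step and time_pairs are placeholders):

x0_pred = None
latents = None

for time, time_next in time_pairs:
    # condition on the x0 estimate (and latents) from the previous sampling step
    model_output, latents = model(img, time, x0_pred, latents, return_latents = True)

    # interpret the output according to the training objective and convert it to an x0 estimate
    x0_pred = to_x0(model_output, img, time).clamp(-1., 1.)

    # take the update to the next timestep using that x0 estimate
    img = ddim_step(img, x0_pred, time, time_next)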

however, i get what you are saying. i could offer both options, if you are running experiments to see which way is better?

lucidrains commented Aug 1, 2023