
LwF gradient backpropagation

Open matkowski-voy opened this issue 1 year ago • 4 comments

Hi,

I think that in the LwF plugin the gradient does not flow correctly through the penalty computation in some cases:


with torch.no_grad():
    if isinstance(self.prev_model, MultiTaskModule):
        # output from previous output heads.
        y_prev = avalanche_forward(self.prev_model, x, None)
        # in a multitask scenario we need to compute the output
        # from all the heads, so we need to call forward again.
        # TODO: can we avoid this?
        y_curr = avalanche_forward(curr_model, x, None)
    else:  # no task labels
        y_prev = {"0": self.prev_model(x)}
        y_curr = {"0": out}


  • the else:  # no task labels branch seems fine, because out is already computed outside the no_grad context; but if it were written as y_curr = {"0": curr_model(x)}, the gradient would not flow.
  • the problem does occur in the if isinstance(self.prev_model, MultiTaskModule): branch, where y_curr['0'].requires_grad returns False. A minimal repro is below.
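For a minimal standalone repro of the second point (a toy nn.Linear stands in for the strategy's model; this is illustrative code, not Avalanche code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)

# computed outside no_grad, like `out` in the plugin
out = model(x)
with torch.no_grad():
    # same forward call, but the graph is not recorded here
    y_curr = {"0": model(x)}

print(out.requires_grad)          # True
print(y_curr["0"].requires_grad)  # False -> a penalty built on this is a constant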

Please let me know if I am missing something here and the computation is actually correct.

Cheers, Woj

matkowski-voy avatar Sep 02 '22 07:09 matkowski-voy

you are correct, this is a bug

AntonioCarta avatar Sep 02 '22 07:09 AntonioCarta

I replaced that part of the code with the following, but the final accuracy is still lower than expected.

if isinstance(self.prev_model, MultiTaskModule):
    # output from previous output heads.
    with torch.no_grad():
        y_prev = avalanche_forward(self.prev_model, x, None)
    # in a multitask scenario we need to compute the output
    # from all the heads, so we need to call forward again.
    # TODO: can we avoid this?
    y_curr = avalanche_forward(curr_model, x, None)
else:  # no task labels
    with torch.no_grad():
        y_prev = {"0": self.prev_model(x)}
        y_curr = {"0": out}
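Note that in the else branch, y_curr = {"0": out} sits under no_grad, but out was computed outside of it, so it still carries its graph; only the MultiTaskModule branch actually changes behavior. For context, here is a sketch of how a distillation penalty over these dicts could then be computed (illustrative only, not the plugin's exact code; temperature is an assumed hyperparameter):

import torch
import torch.nn.functional as F

def distillation_penalty(y_curr, y_prev, temperature=2.0):
    """KL-style LwF penalty summed over the heads of y_curr."""
    loss = torch.zeros(())
    for head in y_curr:
        # kl_div expects log-probabilities as input and probabilities as target
        log_p = F.log_softmax(y_curr[head] / temperature, dim=1)
        q = F.softmax(y_prev[head] / temperature, dim=1)
        loss = loss + F.kl_div(log_p, q, reduction="batchmean")
    return loss

With the fix above, y_curr carries gradients while y_prev is detached, so the penalty updates only the current model, which is the intended behavior.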

AndreaCossu avatar Sep 02 '22 07:09 AndreaCossu

Oh right, the change above applies only to MultiTaskModule, so I guess that is still a useful fix. There may be another problem for single-headed models, too.

AndreaCossu avatar Sep 02 '22 07:09 AndreaCossu

Another thing is whether self.prev_model should be in .train() or .eval() mode: currently it stays in training mode, which would affect batch norm and dropout, for example, so the distilled knowledge would be different.
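A possible guard (a sketch only, reusing the names from the snippets above) would be to force the frozen model into eval mode for the distillation forward pass and restore its previous mode afterwards:

was_training = self.prev_model.training
self.prev_model.eval()
try:
    with torch.no_grad():
        y_prev = avalanche_forward(self.prev_model, x, None)
finally:
    if was_training:
        self.prev_model.train()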

matkowski-voy avatar Sep 02 '22 09:09 matkowski-voy