ibot
Debugging the use of torch.utils.checkpoint.checkpoint
When I use torch.utils.checkpoint.checkpoint as shown below and train the model with apex, the loss drops to about 0.4, whereas the normal loss is around 2.x.
Do you have any idea what might cause this?
```python
import torch.utils.checkpoint

for blk in self.blocks:
    # x = blk(x)  # original, non-checkpointed call
    x = torch.utils.checkpoint.checkpoint(blk, x)
```
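One way to narrow this down is to check, without apex, that checkpointing by itself changes neither the forward output nor the gradients. The sketch below does this on a small toy encoder; `ToyEncoder` and `compare` are hypothetical names for illustration, not iBOT code, and it assumes PyTorch 1.11 or newer for the `use_reentrant` argument. The toy blocks have no dropout or DropPath, so the forward pass is deterministic.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    # Hypothetical stand-in for a ViT: a stack of deterministic blocks.
    def __init__(self, dim=16, depth=4, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())
             for _ in range(depth)]
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Do not store blk's activations; recompute them during backward.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x


def compare(dim=16, depth=4):
    torch.manual_seed(0)
    plain = ToyEncoder(dim, depth, use_checkpoint=False)
    ckpt = ToyEncoder(dim, depth, use_checkpoint=True)
    ckpt.load_state_dict(plain.state_dict())  # identical weights in both models

    x = torch.randn(8, dim)
    out_plain = plain(x)
    out_ckpt = ckpt(x)
    out_plain.sum().backward()
    out_ckpt.sum().backward()

    # Largest absolute difference in outputs and in parameter gradients.
    max_out_diff = (out_plain - out_ckpt).abs().max().item()
    max_grad_diff = max(
        (p.grad - q.grad).abs().max().item()
        for p, q in zip(plain.parameters(), ckpt.parameters())
    )
    return max_out_diff, max_grad_diff


if __name__ == "__main__":
    print(compare())
```

If both differences are at floating-point noise level, checkpointing alone is probably not what lowers the loss, and the interaction with the rest of the setup (for example apex mixed precision, or stochastic layers in the real blocks) is worth examining next. Two documented details may be relevant: with the reentrant variant (the older default), the checkpointed segment gets no parameter gradients unless at least one input has `requires_grad=True`, and `preserve_rng_state=True` (the default) makes the recomputation replay the same dropout masks.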