PiPPy

Fix backward implementation and remove setting grad in forward()

Status: Open. Opened by H-Huang.

"""fwd_outputs are all forced to have 'requires_grad=True'. Why? What's our design here? freqs_cis may be passed from stage0 to stage1, but it is an input value from the dataloader and should not require grads.

backward isn't implemented correctly, afaiu. See the rewrite in the whc/pp branch, which fixes (a) .grad() won't set .grad on the W's, but .backward() will; (b) funny issues with requires-gradness on inputs, which disappeared after I simplified."""
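A minimal sketch of point (a), assuming a single nn.Linear stands in for one pipeline stage and a tensor of ones stands in for the gradient received from the next stage (both are illustrative, not PiPPy code). torch.autograd.grad returns the parameter gradients without writing them into .grad, while torch.autograd.backward accumulates them into .grad, which is where an optimizer looks.

```python
import torch

# Hypothetical single-stage model: one Linear layer stands in for a pipeline stage.
torch.manual_seed(0)
stage = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)
out = stage(x)
grad_out = torch.ones_like(out)  # stand-in for the gradient sent back by the next stage

# torch.autograd.grad returns the gradients but leaves .grad untouched.
# retain_graph=True keeps the graph alive so backward can run on it afterwards.
(grad_w,) = torch.autograd.grad(out, [stage.weight], grad_outputs=grad_out, retain_graph=True)
grad_before = stage.weight.grad
print(grad_before)  # None

# torch.autograd.backward accumulates into .grad of every leaf that requires grad.
torch.autograd.backward(out, grad_tensors=grad_out)
print(torch.allclose(stage.weight.grad, grad_w))  # True
print(stage.bias.grad)  # tensor([3., 3.])
```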

The backward implementation is incorrect because it does not populate the gradients (.grad) of the parameters. Furthermore, forward() should not explicitly set requires_grad to True on the inputs.
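One possible shape for the fix, sketched under stated assumptions: prepare_stage_inputs and stage_backward are hypothetical helpers (not PiPPy APIs, and not the code in the whc/pp branch), and the receiving stage is assumed to know whether each tensor required grad on the sending side. Received tensors keep that status instead of being forced to requires_grad=True, so a dataloader value such as freqs_cis gets no gradient, and backward uses torch.autograd.backward so the stage's parameter .grad fields are filled.

```python
import torch


def prepare_stage_inputs(received, sender_requires_grad):
    """Hypothetical helper: detach tensors received from the previous stage and
    re-enable requires_grad only where the sender's tensor required it."""
    inputs = []
    for t, needs_grad in zip(received, sender_requires_grad):
        t = t.detach()  # fresh leaf, requires_grad=False
        if needs_grad:
            t.requires_grad_(True)
        inputs.append(t)
    return inputs


def stage_backward(outputs, grad_outputs, inputs):
    """Hypothetical helper: run backward for one stage.

    torch.autograd.backward accumulates into .grad of the stage's parameters
    (unlike torch.autograd.grad) and of every input leaf that requires grad.
    Returns the input gradients to send upstream (None for inputs, such as
    dataloader values, that never required grad)."""
    pairs = [(o, g) for o, g in zip(outputs, grad_outputs) if o.requires_grad]
    if pairs:
        torch.autograd.backward([o for o, _ in pairs], grad_tensors=[g for _, g in pairs])
    return [t.grad if t.requires_grad else None for t in inputs]


# Usage with illustrative tensors.
torch.manual_seed(0)
stage = torch.nn.Linear(4, 4)
h = torch.randn(2, 4, requires_grad=True) * 2  # activation from stage0 (requires grad)
freqs_cis = torch.randn(2, 4)                  # dataloader value (no grad)

x, f = prepare_stage_inputs([h, freqs_cis], [h.requires_grad, freqs_cis.requires_grad])
out = stage(x) * f
input_grads = stage_backward([out], [torch.ones_like(out)], [x, f])
print(input_grads[1])  # None: freqs_cis never required grad
```

Reading the upstream gradients from the detached input leaves keeps the stage boundary explicit, and because torch.autograd.backward writes into param.grad, the optimizer step needs no extra bookkeeping.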

H-Huang, Feb 14 '24 19:02