PiPPy
Fix backward implementation and remove setting grad in forward()
""" fwd_outputs are all forced to have 'requires_grad=True'. Why? What's our design here? freqs_cis could be passed from stage0 to stage1, but it is an input value from the dataloader and should not require grads.
backward isn't implemented correctly, afaiu. See the rewrite in the whc/pp branch, which fixes (a) autograd.grad() won't set .grad on the W's but .backward() will; (b) funny issues with requires-gradness on inputs, which disappeared after I simplified. """
The backward implementation is incorrect because it does not accumulate gradients into the parameters' .grad fields. Furthermore, forward should not explicitly set requires_grad=True on its inputs.
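A minimal sketch of the distinction, assuming a single toy stage built from nn.Linear. The helper name stage_backward is hypothetical and is not PiPPy's API. freqs_cis is modeled here as a plain float tensor that never requires grad. The sketch shows that torch.autograd.grad returns input gradients but leaves parameter .grad unset, whereas torch.autograd.backward accumulates into parameter .grad. It also shows that inputs which don't require grad can simply be skipped instead of being forced to requires_grad=True.

```python
import torch
import torch.nn as nn


def stage_backward(stage_outputs, output_grads, stage_inputs):
    """Hypothetical helper: backward pass for one pipeline stage."""
    # Only outputs that are part of the autograd graph take part in backward.
    pairs = [
        (out, g)
        for out, g in zip(stage_outputs, output_grads)
        if isinstance(out, torch.Tensor) and out.requires_grad
    ]
    outs = [out for out, _ in pairs]
    grads = [g for _, g in pairs]
    # Unlike torch.autograd.grad, backward() accumulates into .grad of every
    # leaf in the graph, including the stage's parameters.
    torch.autograd.backward(outs, grad_tensors=grads)
    # Gradients to send upstream; None for inputs (e.g. freqs_cis) that never
    # required grad, so nothing has to be forced to requires_grad=True.
    return [inp.grad if inp.requires_grad else None for inp in stage_inputs]


# Toy stage (assumption): an activation from the previous stage plus a
# dataloader tensor standing in for freqs_cis.
torch.manual_seed(0)
stage = nn.Linear(4, 4)
act = torch.randn(2, 4).requires_grad_()  # received activation, needs grad
freqs_cis = torch.randn(2, 4)             # dataloader input, no grad
out = stage(act) * freqs_cis
upstream_grad = torch.ones_like(out)

# torch.autograd.grad returns d(out)/d(act) but does not populate stage.weight.grad.
(d_act,) = torch.autograd.grad(
    out, [act], grad_outputs=upstream_grad, retain_graph=True
)
weight_grad_after_grad = stage.weight.grad  # still None at this point

input_grads = stage_backward([out], [upstream_grad], [act, freqs_cis])
```

After stage_backward, stage.weight.grad is populated, input_grads[0] matches d_act, and input_grads[1] is None because freqs_cis was never marked as requiring grad.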