Implement checkpointing

Open Mikolaj opened this issue 1 year ago • 0 comments

Try to implement checkpointing (inserting recomputation to trade off computation against memory use) and then automatic checkpointing, which PyTorch/JAX users reportedly need now and can't get.
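To make the trade-off concrete, here is a toy sketch in Haskell of manual (non-automatic) checkpointing for reverse-mode differentiation of a chain of scalar functions. It is not horde-ad code: `Layer`, `segmentsOf`, `backpropSegment`, `gradStoreAll` and `gradCheckpointed` are hypothetical names made up for illustration. The plain reverse pass keeps every intermediate input. The checkpointed one keeps only the input of every `k`-th layer during the forward pass and recomputes each segment's intermediates from its checkpoint during the backward pass. In lazy Haskell the real memory profile also depends on evaluation order, so this shows the data flow only, not measured memory behavior.

```haskell
import Data.List (foldl')

-- Hypothetical toy "layer": a scalar function and its derivative.
-- Not part of horde-ad; used only to illustrate checkpointing.
data Layer = Layer
  { fwd  :: Double -> Double
  , dfwd :: Double -> Double
  }

-- Split the layers into consecutive segments of length k (the last may be
-- shorter). A non-positive k is treated as 1.
segmentsOf :: Int -> [a] -> [[a]]
segmentsOf _ [] = []
segmentsOf k xs = let (seg, rest) = splitAt (max 1 k) xs
                  in seg : segmentsOf k rest

-- Run a segment forward from its input, materializing all of the segment's
-- intermediate inputs, then push the incoming cotangent ct back through it.
backpropSegment :: [Layer] -> Double -> Double -> Double
backpropSegment seg x ct =
  let inputs = scanl (\acc l -> fwd l acc) x seg  -- x, then each layer's output
  in foldr (\(l, xi) acc -> acc * dfwd l xi) ct (zip seg inputs)

-- Plain reverse mode: the whole chain is one segment, so every intermediate
-- input is kept for the backward pass.
gradStoreAll :: [Layer] -> Double -> Double
gradStoreAll layers x0 = backpropSegment layers x0 1

-- Checkpointed reverse mode: the forward pass keeps only the input of each
-- segment (the checkpoints); the backward pass visits segments last-to-first
-- and recomputes each one's intermediates from its checkpoint.
gradCheckpointed :: Int -> [Layer] -> Double -> Double
gradCheckpointed k layers x0 =
  let segs        = segmentsOf k layers
      runSeg x s  = foldl' (\acc l -> fwd l acc) x s  -- forward only, no storage
      checkpoints = scanl runSeg x0 segs              -- x0, x_k, x_2k, ...
  in foldr (\(s, cp) ct -> backpropSegment s cp ct) 1 (zip segs checkpoints)

main :: IO ()
main = do
  -- derivative of sin (sin (... (sin x))) with 10 nested sin calls, at 0.5
  let layers = replicate 10 (Layer sin cos)
  print (gradStoreAll layers 0.5)
  print (gradCheckpointed 3 layers 0.5)
```

Both functions return the same gradient, because the recomputed intermediates are identical to the stored ones. With n layers and segments of length k, roughly n/k checkpoints plus k segment intermediates are live at once, at the cost of about one extra forward pass; k ≈ √n is a common choice. Automatic checkpointing would, roughly, mean choosing the checkpoint positions (here just a fixed `k`) for the user.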

We have an old discussion that starts with @tomjaguarpaw sketching an extension of the POPL paper with checkpointing: https://github.com/Mikolaj/mostly-harmless/discussions/20. We also had two variants of checkpointing (or things related to it) implemented at some point, due to a peak of popular interest. However, they bit-rotted before anybody found them interesting again and before any benchmarks for them were written, and they were removed when horde-ad got simplified.

I wonder whether, in the current mode of operation, where we do reverse differentiation symbolically instead of on the real inputs, the memory-leak problems raised in the discussion are gone. More generally, I wonder how checkpointing in the current mode would differ from what Tom describes, and whether PyTorch/JAX do checkpointing in both modes of operation.

I'd advise against implementing it again until interest is demonstrated by tests and benchmarks written by the interested parties.

Mikolaj, Sep 02 '23 19:09