Extend `GradStore` public functionality

Open agerasev opened this issue 1 year ago • 5 comments

This PR allows the user to merge two GradStores together (and also to create an empty one). It is helpful for collecting gradients from several backward passes (e.g. with different graphs that share some weights) before making an optimizer step.
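
Roughly, the merge semantics I have in mind look like the sketch below (simplified, not the exact code in this PR; it assumes the GradStore(HashMap<TensorId, Tensor>) layout from candle-core's backprop module): gradients recorded for the same tensor id are summed, the rest are moved over as-is.

// Sketch only; the real implementation may differ.
use std::collections::hash_map::Entry;

impl GradStore {
    /// Merge `other` into `self`, summing gradients stored for the same tensor id.
    pub fn extend(&mut self, other: GradStore) -> Result<()> {
        for (id, grad) in other.0 {
            match self.0.entry(id) {
                Entry::Occupied(mut entry) => {
                    // Both stores hold a gradient for this tensor: add them up.
                    let summed = (entry.get() + &grad)?;
                    entry.insert(summed);
                }
                Entry::Vacant(entry) => {
                    entry.insert(grad);
                }
            }
        }
        Ok(())
    }
}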

agerasev avatar Dec 26 '23 04:12 agerasev

Could you provide a minimal example of how you would want to use this? Maybe it's something that could be done via some helper functions rather than having to modify GradStore itself.

LaurentMazare avatar Dec 26 '23 08:12 LaurentMazare

I'm trying to implement a form of checkpointing by splitting the whole network into segments and running the forward and backward passes separately for each segment, to avoid storing the activations of the whole network in the backprop graph at the same time.

let segments: Vec<Box<dyn Module>> = ...;
let mut opt = ...;

// Forward pass: remember the input of each segment and detach between
// segments so one segment's graph is dropped before the next one runs.
let mut xs = ...;
let mut checkpoints: Vec<Tensor> = vec![];
for seg in segments.iter() {
    checkpoints.push(xs.clone());
    xs = seg.forward(&xs)?.detach()?;
}
let ys = xs;

// Backward pass: walk the segments in reverse, re-running the forward pass
// of each one and backpropagating the incoming gradient through it.
let mut grad = ys.ones_like()?;
let mut grad_store = GradStore::default();
for (seg, xs) in segments.iter().zip(checkpoints).rev() {
    let ys = seg.forward(&xs)?;
    // Seed the backward pass with the gradient coming from the next segment.
    let mut gs = ys.mul(&grad)?.backward()?;
    // Drop the gradient recorded for the seed tensor itself, if any.
    gs.remove(&grad);
    // The gradient w.r.t. this segment's input becomes the seed for the previous segment.
    grad = gs.remove(&xs).unwrap().detach()?;
    // Accumulate this segment's weight gradients (the point of this PR).
    grad_store.extend(gs)?;
}

// Optimizer step on the accumulated gradients.
opt.step(&grad_store)?;

It could also be done without this PR by storing a Vec<GradStore> instead of a single grad_store and performing an optimizer step for each of them. The problem with that approach is that when segments share the same weights (e.g. in a recurrent network), we end up performing multiple optimizer steps on a single weight. Some optimizers are not linear (i.e. opt.step(a); opt.step(b); is not the same as opt.step(a + b);), so this would give a different result.
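
To make the non-linearity concrete, here is a toy example with an Adagrad-style update (just for illustration, not any specific candle optimizer): the update is delta_w = -lr * g / sqrt(G), where G accumulates the squared gradients. With lr = 1 and two gradient contributions of 1 each, a single step on the summed gradient gives delta_w = -2 / sqrt(4) = -1, while two consecutive steps give -1 / sqrt(1) - 1 / sqrt(2) ≈ -1.71.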

agerasev avatar Dec 27 '23 03:12 agerasev

Thanks for the details. Maybe in that case you could have some specific optimizer that would perform the aggregation and whose step would take multiple grad-stores? The optimizer implementations are very straightforward and you can probably get this to work alongside the existing optimizers. This way the complexity could be externalised to a specific crate for checkpointing.
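
Something along these lines would be a very rough sketch of what I mean (the name MultiStepSgd and the exact signatures are made up here; it sums the per-variable gradients across the stores via GradStore::get and applies a single plain SGD update per variable through Var::set):

// Very rough sketch of an aggregating optimizer living outside candle-nn.
use candle_core::{backprop::GradStore, Result, Tensor, Var};

struct MultiStepSgd {
    vars: Vec<Var>,
    learning_rate: f64,
}

impl MultiStepSgd {
    fn step_multi(&mut self, grad_stores: &[GradStore]) -> Result<()> {
        for var in self.vars.iter() {
            // Sum the gradients recorded for this variable, if any.
            let mut summed: Option<Tensor> = None;
            for gs in grad_stores {
                if let Some(g) = gs.get(var) {
                    summed = Some(match summed {
                        Some(acc) => (&acc + g)?,
                        None => g.clone(),
                    });
                }
            }
            if let Some(g) = summed {
                // Single SGD update using the aggregated gradient.
                let delta = g.affine(self.learning_rate, 0.)?;
                let updated = (var.as_tensor() - &delta)?;
                var.set(&updated)?;
            }
        }
        Ok(())
    }
}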

LaurentMazare avatar Dec 27 '23 12:12 LaurentMazare

Maybe in that case you could have some specific optimizer that would perform the aggregation and whose step would take multiple grad-stores?

I didn't consider customizing the optimizer, thanks for the idea. But this approach doesn't eliminate the problem that multiple GradStores can hold gradients for the same variable (for example, in the RNN case), unnecessarily consuming extra memory. With GradStore::extend they can be collapsed together during the backward pass.

agerasev avatar Dec 28 '23 02:12 agerasev

@LaurentMazare, what do you think about this PR? Does it have a chance of being merged? Or maybe it would be better to simply expose GradStore's internal storage, like it's done with VarMap::data()?
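
For the latter option, a hypothetical accessor (mirroring candle_nn::VarMap::data, again assuming the HashMap-based layout) could look like this:

use std::collections::HashMap;

impl GradStore {
    /// Hypothetical accessor exposing the underlying map so callers can
    /// merge or inspect gradients themselves.
    pub fn data(&mut self) -> &mut HashMap<TensorId, Tensor> {
        &mut self.0
    }
}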

agerasev avatar Feb 17 '24 02:02 agerasev