
If param is not used in loss calculation, `Optimizer::update()` causes runtime panic

Open chelsea0x3b opened this issue 3 years ago • 11 comments

Contrived example:

fn main() {
    let mut model: (Linear<5, 5>, Linear<5, 4>) = Default::default();
    // Only the second Linear participates in the forward pass,
    // so no gradients are recorded for model.0's parameters.
    let y = model.1.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();
    let mut opt: Sgd = Default::default();
    opt.update(&mut model, gradients); // panics: no gradients for model.0
}

chelsea0x3b avatar Jul 11 '22 16:07 chelsea0x3b

Slightly worse: you can pass a model into `Sgd::update` that wasn't used in the gradient calculation at all.

fn main() {
    let mut model: (Linear<5, 5>, Linear<5, 4>) = Default::default();
    let y = model.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();
    let mut opt: Sgd = Default::default();
    // A completely unrelated model is accepted without complaint:
    let mut model2: (Linear<5, 5>, Linear<5, 4>) = Default::default();
    // or even: `let mut model2: Linear<5, 5> = Default::default();`
    // or even: `let mut model2: ReLU = Default::default();`
    opt.update(&mut model2, gradients);
}

chelsea0x3b avatar Jul 11 '22 16:07 chelsea0x3b

Another example with SplitInto:

fn main() {
    let mut model: SplitInto<(Linear<5, 4>, Linear<5, 2>)> = Default::default();
    let (_a, b) = model.forward(Tensor1D::zeros().traced());
    let loss = b.square().mean(); // NOTE: _a is not used
    let gradients = loss.backward();
    let mut opt: Sgd = Default::default();
    opt.update(&mut model, gradients);
}

chelsea0x3b avatar Jul 11 '22 16:07 chelsea0x3b

Idea:

  1. Make `OwnedTape` generic over the model
  2. Make `Gradients` generic over the model as well

e.g. backward would look like:

pub fn backward<M, T: Tensor<Dtype = f32, Tape = OwnedTape<M>>>(t: T) -> Gradients<M> { ... }

Then you could have Optimizer:

pub trait Optimizer<M: CanUpdateWithGradients> {
    fn update(&mut self, module: &mut M, gradients: Gradients<M>);
}

chelsea0x3b avatar Jul 12 '22 13:07 chelsea0x3b

Making `Gradients` generic over `M` would end up requiring making `OwnedTape` generic over `M` as well, and making the tape generic over the model pollutes everything. Not sure it's worth it?

Perhaps rename `Gradients` to `GeneralGradients`, and then add `Gradients` back as `struct Gradients<M>(GeneralGradients)`?
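A minimal sketch of that newtype idea (names are illustrative; the real `Gradients` stores tensors, not `Vec<f32>`):

```rust
use std::collections::HashMap;
use std::marker::PhantomData;

// Untyped storage, roughly what exists today: gradients keyed by parameter id.
struct GeneralGradients(HashMap<usize, Vec<f32>>);

// Thin typed wrapper: `M` exists only at the type level, so an optimizer
// expecting `Gradients<ModelA>` rejects a `Gradients<ModelB>` at compile
// time, without the tape internals needing to know about models at all.
struct Gradients<M>(GeneralGradients, PhantomData<M>);

impl<M> Gradients<M> {
    fn new(inner: GeneralGradients) -> Self {
        Gradients(inner, PhantomData)
    }
}
```

This still wouldn't catch the partial-use case (only half the model traced), but it does catch passing a wholly different model type.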

chelsea0x3b avatar Jul 12 '22 13:07 chelsea0x3b

In theory, each operation should actually change the type of the tape, e.g. ReLU would change the type from `Tape<M>` to `Tape<(M, ReLU)>`. Then `backward` would take `Tape<M>` and return `Gradients<M>`. This would enable us to compare the ops the tape has recorded against the model that's using the tape.

Then Optimizer could be:

pub trait Optimizer<M: CanUpdateWithGradients> {
    fn update<Ops>(&mut self, module: &mut M, gradients: Gradients<Ops>)
        where (M::Ops, Ops): SameType;
}

chelsea0x3b avatar Jul 12 '22 14:07 chelsea0x3b

The root of what we really want to know for a given Gradients object is: "Does Gradients have something for each of my model parameters?"

The `SameType` check of `Module::Ops` against `Ops` gets at this in a very restrictive way. The `Gradients` object will hold gradients for many things besides this module's parameters, and we really want to query just for the model's parameters; we don't care what other entries exist.
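That question has a direct runtime formulation. A sketch, assuming parameters carry a unique id and gradients are stored in a map keyed by that id (both assumptions; dfdx's actual storage differs):

```rust
use std::collections::HashMap;

struct Gradients {
    grads: HashMap<usize, Vec<f32>>,
}

impl Gradients {
    /// Returns the ids in `param_ids` that have no recorded gradient.
    /// An empty result means every model parameter is covered; extra
    /// entries in `grads` (intermediate tensors etc.) are ignored,
    /// which is exactly the "don't care what else exists" semantics.
    fn missing(&self, param_ids: &[usize]) -> Vec<usize> {
        param_ids
            .iter()
            .copied()
            .filter(|id| !self.grads.contains_key(id))
            .collect()
    }
}
```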

chelsea0x3b avatar Jul 12 '22 14:07 chelsea0x3b

Another fuzzy direction: make the tape take ownership of the parameters, and have `backward` reconstruct the model somehow, taking advantage of the ownership system. Instead of trying to verify that the gradients can be used with a given model, producing the gradients would require the model, so the two are linked by construction.

Normal use case

fn main() {
    type Model = (Linear<5, 5>, Linear<5, 4>);
    let mut model: Model = Default::default();
    let mut opt: Sgd<Model> = Default::default();

    let y = model.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();

    // would fail because model.forward() takes ownership of model
    // println!("{:?}", model);

    model = opt.update(gradients);
}

Catching unused param error:

fn main() {
    type Model = (Linear<5, 5>, Linear<5, 4>);
    let model: Model = Default::default();

    // only use half of the model during forward
    let y = model.1.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();

    let mut opt: Sgd<Model> = Default::default();

    // NOTE: this .update() should fail, because gradients only captured half of the model
    let updated_model = opt.update(gradients);
}
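The ownership-linked API above could be sketched like this (all names hypothetical; the real gradients would carry the tape's recorded data alongside the model):

```rust
use std::marker::PhantomData;

// Gradients produced by backward() carry the model they came from,
// so they cannot be paired with an unrelated model.
struct Gradients<M> {
    model: M,
    // ... per-parameter gradient data would live here too
}

struct Sgd<M> {
    lr: f32,
    marker: PhantomData<M>,
}

impl<M> Sgd<M> {
    // update consumes the gradients and hands the (stepped) model back,
    // matching `model = opt.update(gradients)` in the sketch above.
    fn update(&mut self, gradients: Gradients<M>) -> M {
        // apply the `lr`-scaled step to each parameter here, then:
        gradients.model
    }
}
```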

chelsea0x3b avatar Jul 12 '22 15:07 chelsea0x3b

FWIW, PyTorch just silently updates only the parameters that were involved in the computation. First we need to decide what dfdx should do: fail or continue? And if fail, at runtime or at compile time?

chelsea0x3b avatar Jul 12 '22 22:07 chelsea0x3b

I think failing at runtime is a reasonable thing to do here (and it seems very difficult to catch at compile time at this point).

To do this better than the `panic!()` that currently comes from an unwrap in `Gradients`, `GradientProvider` should return a `Result`, and `CanUpdateWithGradients` should return a `Result` as well. While unwinding up the module update stack, a location string could be built up (e.g. `location = "0.linear.weight"` if that weight tensor's gradient isn't present). Then all the optimizers could call `.expect` on the result to produce a useful error message.
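A sketch of that error-propagation shape (types simplified; `CanUpdateWithGradients` is a real dfdx name, but the signatures here are assumptions):

```rust
use std::collections::HashMap;

struct Gradients(HashMap<usize, Vec<f32>>);

#[derive(Debug)]
struct UnusedParamError {
    location: String, // e.g. "0.linear.weight"
}

trait CanUpdateWithGradients {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError>;
}

struct Weight {
    id: usize,
}

impl CanUpdateWithGradients for Weight {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError> {
        // A leaf parameter reports itself if its gradient is missing.
        grads.0.get(&self.id).ok_or(UnusedParamError {
            location: "weight".to_string(),
        })?;
        Ok(())
    }
}

struct Linear {
    weight: Weight,
}

impl CanUpdateWithGradients for Linear {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError> {
        // Each module prepends its own path segment while unwinding.
        self.weight.update(grads).map_err(|mut e| {
            e.location = format!("linear.{}", e.location);
            e
        })
    }
}
```

An optimizer would then call `.expect(...)` (or surface the `Result`) so the failure reads like "no gradient for parameter at linear.weight".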

chelsea0x3b avatar Jul 14 '22 15:07 chelsea0x3b

Interesting; although PyTorch itself silently updates only the parameters involved in the computation, I know of frameworks built on top of it that produce a warning about unused parameters in your model.

Returning a `Result` feels like a good solution, since I agree this seems tricky to do at compile time.

vikigenius avatar Jul 21 '22 16:07 vikigenius

@vikigenius thanks for that feedback! Yeah, I think Rust people will be very comfortable with `Result`s. Will probably be merging #107 before the next release; happy for feedback there as well if you want.

chelsea0x3b avatar Jul 21 '22 21:07 chelsea0x3b