If param is not used in loss calculation, `Optimizer::update()` causes runtime panic
Contrived example:
```rust
fn main() {
    let mut model: (Linear<5, 5>, Linear<5, 4>) = Default::default();
    // only the second half of the model participates in the forward pass
    let y = model.1.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();
    let mut opt: Sgd = Default::default();
    opt.update(&mut model, gradients); // panics: no gradients were recorded for model.0
}
```
Slightly worse: you can pass a model into `Sgd::update` that wasn't even used in the gradient calculation.
```rust
fn main() {
    let mut model: (Linear<5, 5>, Linear<5, 4>) = Default::default();
    let y = model.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();
    let mut opt: Sgd = Default::default();
    let mut model2: (Linear<5, 5>, Linear<5, 4>) = Default::default();
    // or even: `let mut model2: Linear<5, 5> = Default::default();`
    // or even: `let mut model2: ReLU = Default::default();`
    opt.update(&mut model2, gradients);
}
```
Another example with `SplitInto`:

```rust
fn main() {
    let mut model: SplitInto<(Linear<5, 4>, Linear<5, 2>)> = Default::default();
    let (_a, b) = model.forward(Tensor1D::zeros().trace());
    let loss = b.square().mean(); // NOTE: a is not used
    let gradients = loss.backward();
    let mut opt: Sgd = Default::default();
    opt.update(&mut model, gradients);
}
```
Idea:
- Make `OwnedTape` generic over the model
- Make `Gradients` also generic over the model

e.g. `backward` would look like:

```rust
pub fn backward<M, T: Tensor<Dtype = f32, Tape = OwnedTape<M>>>(t: T) -> Gradients<M> { ... }
```
Then you could have `Optimizer`:

```rust
pub trait Optimizer<M: CanUpdateWithGradients> {
    fn update(&mut self, module: &mut M, gradients: Gradients<M>);
}
```
Making `Gradients` generic over `M` would require making `OwnedTape` generic over `M` as well, and making the tape generic over the model pollutes everything that touches the tape. Not sure it's worth it?
Perhaps rename `Gradients` to `GeneralGradients`, and then add `struct Gradients<M>(GeneralGradients)` as a thin wrapper?
In theory, each operation should actually change the type of the tape: e.g. `ReLU` would change the type from `Tape<M>` to `Tape<(M, ReLU)>`. Then `backward` would take a `Tape<M>` and return a `Gradients<M>`. This would enable us to compare the ops the tape has recorded against the model that's using the tape.
Then `Optimizer` could be:

```rust
pub trait Optimizer<M: CanUpdateWithGradients> {
    fn update<Ops>(&mut self, module: &mut M, gradients: Gradients<Ops>)
    where
        (M::Ops, Ops): SameType;
}
```
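As a toy illustration of that where-clause (all names here are hypothetical stand-ins, not dfdx's actual API), a `SameType` trait can be made to hold only when both halves of a tuple are the same type, so an update only compiles when the gradients' recorded op list matches the model's:

```rust
use std::marker::PhantomData;

// Hypothetical sketch: `SameType` holds only for tuples whose halves are equal.
trait SameType {}
impl<T> SameType for (T, T) {}

// Stand-in for gradients tagged with the op list the tape recorded.
struct Gradients<Ops>(PhantomData<Ops>);

// Only callable when the model's op list `M` equals the gradients' `Ops`.
fn checked_update<M, Ops>(_gradients: Gradients<Ops>) -> &'static str
where
    (M, Ops): SameType,
{
    "updated"
}

fn main() {
    // compiles: recorded ops match the model's op list
    assert_eq!(checked_update::<u8, u8>(Gradients(PhantomData)), "updated");
    // a mismatch like `checked_update::<u8, u16>(...)` would fail to compile
    println!("ok");
}
```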
The root of what we really want to know for a given Gradients object is: "Does Gradients have something for each of my model parameters?"
The `SameType` check of `M::Ops` and `Ops` gets at this in a very restrictive way. The `Gradients` object will hold gradients for many things besides `Module` parameters, and we really only want to query for the model's parameters; we don't care what other entries exist.
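A minimal sketch of that looser check, with gradients keyed by a hypothetical per-tensor id (not dfdx's real representation): the model is covered as long as each of its parameter ids has an entry, regardless of what else the tape recorded.

```rust
use std::collections::HashMap;

// Toy stand-in: gradients keyed by a per-tensor unique id.
struct Gradients(HashMap<usize, Vec<f32>>);

// The model is covered if every one of its parameter ids has a gradient entry;
// extra entries for intermediate tensors are irrelevant.
fn covers_params(grads: &Gradients, param_ids: &[usize]) -> bool {
    param_ids.iter().all(|id| grads.0.contains_key(id))
}

fn main() {
    let mut g = HashMap::new();
    g.insert(1, vec![0.5]);
    g.insert(2, vec![0.1]);
    g.insert(99, vec![0.0]); // gradient for an intermediate tensor, not a parameter
    let grads = Gradients(g);
    assert!(covers_params(&grads, &[1, 2])); // all model params covered
    assert!(!covers_params(&grads, &[1, 2, 3])); // param 3 never used in forward
    println!("coverage checks passed");
}
```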
Another fuzzy direction: make the tape take ownership of the parameters, so that `backward` could somehow reconstruct the model. The point is to take advantage of the ownership system: instead of trying to verify after the fact that the gradients can be used with the model, producing the gradients should require the model, so the two are linked by construction.
Normal use case:

```rust
fn main() {
    type Model = (Linear<5, 5>, Linear<5, 4>);
    let mut model: Model = Default::default();
    let mut opt: Sgd<Model> = Default::default();
    let y = model.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();
    // would fail because model.forward() takes ownership of model
    // println!("{:?}", model);
    model = opt.update(gradients);
}
```
Catching the unused param error:

```rust
fn main() {
    type Model = (Linear<5, 5>, Linear<5, 4>);
    let model: Model = Default::default();
    // only use half of the model during forward
    let y = model.1.forward(Tensor1D::zeros().traced());
    let gradients = y.square().mean().backward();
    let mut opt: Sgd<Model> = Default::default();
    // NOTE: this .update() should fail, because gradients only captured half of the model
    let updated_model = opt.update(gradients);
}
```
FWIW, PyTorch silently updates only the parameters that were involved in the computation. First we need to decide what dfdx should do: fail or continue? And if fail, at runtime or compile time?
I think failing at runtime is a reasonable thing to do here (and it seems very difficult to catch at compile time at this point).
To do better than the current `panic!()` from an `unwrap` inside `Gradients`, `GradientProvider` should return a `Result`, and `CanUpdateWithGradients` should return a `Result` as well. While unwinding back up the module update stack, a location string could be built up (e.g. `location = "0.linear.weight"` if the weight tensor's gradient isn't present). Then all the optimizers could call `.expect()` on the result to produce a useful error message.
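A toy sketch of that unwinding (the module names, trait shape, and update signature here are illustrative only): each leaf returns `Err` when its gradient is missing, and each parent prepends its field name, so the optimizer can surface a path like `0.weight`.

```rust
use std::collections::HashMap;

// Hypothetical error carrying the accumulated location, e.g. "0.weight".
#[derive(Debug, PartialEq)]
struct UnusedParamError(String);

// Toy gradients keyed by a per-tensor unique id.
struct Gradients(HashMap<usize, f32>);

impl Gradients {
    fn get(&self, id: usize) -> Result<f32, UnusedParamError> {
        self.0
            .get(&id)
            .copied()
            .ok_or_else(|| UnusedParamError(String::new()))
    }
}

trait CanUpdateWithGradients {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError>;
}

struct Param { id: usize, value: f32 }

impl CanUpdateWithGradients for Param {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError> {
        let g = grads.get(self.id)?;
        self.value -= 0.1 * g; // toy sgd step
        Ok(())
    }
}

struct Linear { weight: Param }

impl CanUpdateWithGradients for Linear {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError> {
        // prepend this field's name while unwinding
        self.weight
            .update(grads)
            .map_err(|e| UnusedParamError(format!("weight{}", e.0)))
    }
}

struct Model(Linear); // stand-in for a tuple model

impl CanUpdateWithGradients for Model {
    fn update(&mut self, grads: &Gradients) -> Result<(), UnusedParamError> {
        self.0
            .update(grads)
            .map_err(|e| UnusedParamError(format!("0.{}", e.0)))
    }
}

fn main() {
    let mut model = Model(Linear { weight: Param { id: 7, value: 1.0 } });
    // empty gradients: the weight was never used in the loss
    let err = model.update(&Gradients(HashMap::new())).unwrap_err();
    assert_eq!(err.0, "0.weight");
    println!("missing gradient at: {}", err.0);
}
```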
Interesting. Although PyTorch itself silently updates only the parameters involved in the computation, I know of frameworks built on top of it that produce a warning about unused parameters in your model.
Returning a `Result` feels like a good solution, since I agree this seems tricky to do at compile time.
@vikigenius thanks for that feedback! Yeah, and I think Rust people will be very comfortable with `Result`s. Will probably merge #107 before the next release; happy for feedback there as well if you want.