dfdx

Adding initializers

Status: Open · cBournhonesque opened this issue 2 years ago · 2 comments

We currently have the trait ResetParams for modules, and Randomize for tensors.

I'm trying to see how we can make them more user-friendly by having something similar to Keras's initializers: https://keras.io/api/layers/initializers/#randomnormal-class

  1. Would it be worthwhile to have the ResetParams trait take a Distribution object provided by the user, so that the user has more control over how the network's parameters are reset? (This could also be a separate trait.)

  2. Would it be useful to have an initializers module similar to Keras's? Note that some initializers, such as Xavier, depend on the shape of the module that calls them (Xavier uses the layer's fan-in and fan-out).
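To make both questions concrete, here is a minimal, std-only sketch of a Keras-style initializer trait. Everything in it (`Initializer`, `Constant`, `RandomUniform`, `RandomNormal`, `XavierUniform`, and the toy `Lcg` generator) is hypothetical and not part of dfdx; a real version would more likely accept a `rand` RNG or distribution. It shows how distribution-style initializers ignore the shape, while a shape-aware one like Xavier reads fan-in and fan-out from it.

```rust
/// Tiny deterministic PRNG (64-bit LCG) so the sketch needs no external crates.
/// Hypothetical: a real implementation would take a `rand::Rng` instead.
struct Lcg(u64);

impl Lcg {
    /// Returns a float uniformly distributed in [0, 1).
    fn next_f32(&mut self) -> f32 {
        self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        // Keep the top 24 bits so the result is exactly representable in f32.
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Keras-style initializer: produces values for a parameter of the given shape.
trait Initializer {
    fn fill(&self, shape: &[usize], rng: &mut Lcg) -> Vec<f32>;
}

struct Constant(f32);
struct RandomUniform { low: f32, high: f32 }
struct RandomNormal { mean: f32, stddev: f32 }
/// Glorot/Xavier uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
struct XavierUniform;

fn numel(shape: &[usize]) -> usize {
    shape.iter().product()
}

impl Initializer for Constant {
    fn fill(&self, shape: &[usize], _rng: &mut Lcg) -> Vec<f32> {
        vec![self.0; numel(shape)]
    }
}

impl Initializer for RandomUniform {
    fn fill(&self, shape: &[usize], rng: &mut Lcg) -> Vec<f32> {
        (0..numel(shape))
            .map(|_| self.low + (self.high - self.low) * rng.next_f32())
            .collect()
    }
}

impl Initializer for RandomNormal {
    fn fill(&self, shape: &[usize], rng: &mut Lcg) -> Vec<f32> {
        (0..numel(shape))
            .map(|_| {
                // Box-Muller transform; `1 - u` keeps u1 in (0, 1] so ln() stays finite.
                let u1 = 1.0 - rng.next_f32();
                let u2 = rng.next_f32();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
                self.mean + self.stddev * z
            })
            .collect()
    }
}

impl Initializer for XavierUniform {
    fn fill(&self, shape: &[usize], rng: &mut Lcg) -> Vec<f32> {
        // Assumes a 2D weight laid out as [out_features, in_features].
        let (fan_out, fan_in) = (shape[0] as f32, shape[1] as f32);
        let limit = (6.0 / (fan_in + fan_out)).sqrt();
        RandomUniform { low: -limit, high: limit }.fill(shape, rng)
    }
}
```

Because every initializer goes through the same `fill(shape, rng)` call, a per-module reset could accept a `&dyn Initializer` and pass in each tensor's shape, which covers both the user-supplied-distribution case and the shape-dependent case.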

Posted by cBournhonesque on Nov 01 '22 04:11

Great question. Before ResetParams, modules simply implemented the same Randomize trait that tensors implement, but then you couldn't have initializers that depend on shape, as you mentioned.

I wonder if it's possible to have some sort of fallback impl for initializers. I'm thinking along these lines:

```rust
// NOTE: use like: `StandardInitializers.reset_params(&mut model)`
pub struct StandardInitializers;
impl<const I: usize, const O: usize> ResetParams<Linear<I, O>> for StandardInitializers { ... }
```

Then a user could override with:

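```rust
pub struct MyLinearInitializers;
// Custom behavior for Linear layers...
impl<const I: usize, const O: usize> ResetParams<Linear<I, O>> for MyLinearInitializers { ... }
// ...and fall back to StandardInitializers for everything else.
impl<M> ResetParams<M> for MyLinearInitializers where StandardInitializers: ResetParams<M> { ... }
```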

However, I think Rust would reject this with a conflicting-implementations error (E0119): StandardInitializers already implements ResetParams<Linear<I, O>>, so the blanket impl also covers Linear<I, O> and overlaps with the specific one. There's probably some way to disambiguate them; I would have to think more about it, though.
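One way to sidestep the overlap, sketched here as an alternative design rather than a fix for the blanket impl above, is to use a single trait with one default method per layer kind: the default bodies play the role of StandardInitializers, and a user overrides only the methods they care about. Since there is no blanket impl, nothing can overlap. The `Linear`, `LayerNorm`, and `Model` structs and the `Initializers` trait are hypothetical stand-ins, not dfdx's types, and the default values are placeholders.

```rust
// Hypothetical stand-ins for layer types; these are NOT dfdx's real structs.
struct Linear {
    in_features: usize,
    weight: Vec<f32>, // out_features * in_features values
    bias: Vec<f32>,
}

struct LayerNorm {
    gamma: Vec<f32>,
    beta: Vec<f32>,
}

/// One default method per layer kind; the defaults act as the "standard" fallback.
trait Initializers {
    fn init_linear(&mut self, l: &mut Linear) {
        // Placeholder default: constant 1/sqrt(fan_in) weights, zero bias.
        let w = 1.0 / (l.in_features as f32).sqrt();
        l.weight.fill(w);
        l.bias.fill(0.0);
    }

    fn init_layer_norm(&mut self, ln: &mut LayerNorm) {
        ln.gamma.fill(1.0);
        ln.beta.fill(0.0);
    }
}

struct StandardInitializers;
impl Initializers for StandardInitializers {}

struct MyLinearInitializers;
impl Initializers for MyLinearInitializers {
    // Override only Linear; LayerNorm keeps the default behavior.
    fn init_linear(&mut self, l: &mut Linear) {
        l.weight.fill(0.01);
        l.bias.fill(0.1);
    }
}

struct Model {
    fc: Linear,
    norm: LayerNorm,
}

impl Model {
    fn new(in_features: usize, out_features: usize) -> Self {
        Model {
            fc: Linear {
                in_features,
                weight: vec![0.0; in_features * out_features],
                bias: vec![0.0; out_features],
            },
            norm: LayerNorm { gamma: vec![0.0; out_features], beta: vec![0.0; out_features] },
        }
    }

    /// Each model routes its layers to the matching trait method.
    fn reset_params<I: Initializers>(&mut self, init: &mut I) {
        init.init_linear(&mut self.fc);
        init.init_layer_norm(&mut self.norm);
    }
}
```

The trade-off is that the trait has to list every layer kind up front, so a user-defined layer would need its own method or some generic hook.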

Posted by coreylowman on Nov 01 '22 17:11

This might be doable with TensorCollection as it is now, especially since ResetParams was reworked to use it. Check out the existing implementation of ResetParams:

```rust
struct Resetter;
impl<E: Dtype, D: DeviceStorage> TensorVisitor<E, D> for Resetter {
    type Viewer = ViewTensorMut;
    type Err = D::Err;

    fn visit<S: Shape>(
        &mut self,
        _: String,
        opts: TensorOptions<S, E, D>,
        t: &mut Tensor<S, E, D>,
    ) -> Result<(), D::Err> {
        (opts.reset)(t)
    }
}
pub trait ResetParams<E: Dtype, D: DeviceStorage>: TensorCollection<E, D> {
    fn reset_params(&mut self) {
        self.try_reset_params().unwrap();
    }
    fn try_reset_params(&mut self) -> Result<(), D::Err> {
        Self::iter_tensors(&mut RecursiveWalker {
            m: self,
            f: &mut Resetter,
            path: &mut Vec::new(),
        })
    }
}
impl<E: Dtype, D: DeviceStorage, M: TensorCollection<E, D>> ResetParams<E, D> for M {}
```

So if you wanted a custom initialization, you could copy the above and change the visit method to do something like:

```rust
struct MyCustomInit;
impl<E: Dtype, D: DeviceStorage> TensorVisitor<E, D> for MyCustomInit {
    type Viewer = ViewTensorMut;
    type Err = D::Err;

    fn visit<S: Shape>(
        &mut self,
        path: String,
        opts: TensorOptions<S, E, D>,
        t: &mut Tensor<S, E, D>,
    ) -> Result<(), D::Err> {
        if S::NUM_DIMS == 2 {
            // e.g. Linear weights
            ...
        } else if S::NUM_DIMS == 4 {
            // e.g. Conv2D weights
            ...
        } else if path.contains("weight") {
            ...
        } else {
            // fall back to the module's default reset
            (opts.reset)(t)
        }
    }
}
```

However, I'm not sure we have the correct pub exports for this to work properly from outside the crate, and we should add an example of how to do this before closing this ticket.
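Until that example exists, here is a std-only sketch of the same dispatch logic, runnable without dfdx. `Param`, `ParamVisitor`, and `walk` are hypothetical stand-ins that loosely mirror `Tensor`, `TensorVisitor::visit`, and `RecursiveWalker`; they are not dfdx APIs, and the toy LCG replaces a real RNG. As in `MyCustomInit`, it branches on the number of dimensions first (Xavier uniform for 2D and 4D weights, counting the kernel area for 4D), then on the path, and otherwise falls back to a placeholder default.

```rust
/// Hypothetical mock of what a visitor sees: path in the model, shape, flat data.
struct Param {
    path: String,
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Simplified analogue of a tensor visitor.
trait ParamVisitor {
    fn visit(&mut self, path: &str, shape: &[usize], data: &mut [f32]);
}

/// Stand-in for a recursive walker: hands every parameter to the visitor.
fn walk<V: ParamVisitor>(params: &mut [Param], v: &mut V) {
    for p in params.iter_mut() {
        v.visit(&p.path, &p.shape, &mut p.data);
    }
}

struct MyCustomInit {
    state: u64, // toy LCG state so the sketch needs no crates
}

impl MyCustomInit {
    /// Uniform sample in [-limit, limit).
    fn uniform(&mut self, limit: f32) -> f32 {
        self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        let u = (self.state >> 40) as f32 / (1u64 << 24) as f32; // [0, 1)
        -limit + 2.0 * limit * u
    }
}

impl ParamVisitor for MyCustomInit {
    fn visit(&mut self, path: &str, shape: &[usize], data: &mut [f32]) {
        if shape.len() == 2 {
            // e.g. a Linear weight [out, in]: Xavier uniform.
            let limit = (6.0 / (shape[0] + shape[1]) as f32).sqrt();
            for x in data.iter_mut() {
                *x = self.uniform(limit);
            }
        } else if shape.len() == 4 {
            // e.g. a Conv2D weight [out, in, k, k]: fans are scaled by the kernel area.
            let receptive = shape[2] * shape[3];
            let limit = (6.0 / ((shape[0] + shape[1]) * receptive) as f32).sqrt();
            for x in data.iter_mut() {
                *x = self.uniform(limit);
            }
        } else if path.contains("bias") {
            data.fill(0.0);
        } else {
            // Placeholder for the module's default reset, like `(opts.reset)(t)`.
            data.fill(1.0);
        }
    }
}
```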

Posted by coreylowman on Mar 08 '23 19:03