
Input Wrapper

Open • swfsql opened this issue 2 years ago • 1 comment

This is a draft; it closes #878.
Note: if this design proves useful, it could be implemented as a separate library (the changes are purely additive and don't conflict with anything). For now, to gather feedback, it's simpler and more straightforward to keep it as a draft PR.

  • Add #[input_wrapper].
    • Add the heck dependency to convert names from CamelCase to snake_case.
  • Add layers.
    • Id, which just forwards the input.
    • On, which applies a Module to one field of an input wrapper.
      • Contains a test demonstrating its usage.
    • Add, which calls try_add on its inputs.
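To make the Id and Add layers and the name conversion more concrete, here is a minimal standalone sketch that does not depend on dfdx. It assumes a simplified, hypothetical Module trait (dfdx's real trait also has try_forward_mut and its own Error type), uses std::ops::Add in place of dfdx's try_add, and includes a hand-rolled to_snake_case that only approximates what heck does for simple names such as Split1.

```rust
/// Hypothetical, simplified stand-in for dfdx's fallible `Module` trait.
pub type Error = String;

pub trait Module<X> {
    type Output;
    fn try_forward(&self, x: X) -> Result<Self::Output, Error>;
}

/// `Id`: forwards its input unchanged.
#[derive(Debug, Default, Clone, Copy)]
pub struct Id;

impl<X> Module<X> for Id {
    type Output = X;
    fn try_forward(&self, x: X) -> Result<X, Error> {
        Ok(x)
    }
}

/// `Add`: takes a pair and adds its two elements.
/// `std::ops::Add` stands in for dfdx's `try_add` here (an assumption).
#[derive(Debug, Default, Clone, Copy)]
pub struct Add;

impl<A, B> Module<(A, B)> for Add
where
    A: std::ops::Add<B>,
{
    type Output = <A as std::ops::Add<B>>::Output;
    fn try_forward(&self, (a, b): (A, B)) -> Result<Self::Output, Error> {
        Ok(a + b)
    }
}

/// Rough approximation of heck's CamelCase -> snake_case conversion:
/// inserts `_` before each uppercase letter (except the first) and lowercases it.
/// heck handles more edge cases (acronyms, separators) than this does.
pub fn to_snake_case(name: &str) -> String {
    let mut out = String::new();
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                out.push('_');
            }
            out.extend(c.to_lowercase());
        } else {
            out.push(c);
        }
    }
    out
}
```

For example, Add.try_forward((2, 3)) returns Ok(5), and to_snake_case("Split1") returns "split1", which matches the name of the generated split1 module.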

This is how it gets used: https://github.com/coreylowman/dfdx/blob/52b62649d38482ceea432027cabed08ceca12f52/dfdx/src/nn/layers/on.rs#L41-L64

This is what gets generated from the above:


pub struct Split1<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}
/// Automatically generated by `input_wrapper`. The contained items are visible in your project's documentation.
pub mod split1 {
    use super::Split1;
    /// Indicates the [`Split1::forward`] field.  \nThis field is the `0` value (`0`-based index).
    #[allow(non_camel_case_types)]
    #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct forward;

    /// Indicates the [`Split1::skip`] field.  \nThis field is the `1` value (`0`-based index).
    #[allow(non_camel_case_types)]
    #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct skip;

    /// Indicates a conversion from a (Forward, Skip) tuple into a `Split1<Forward, Skip>`.
    #[derive(
        Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule,
    )]
    pub struct FromTuple;

    /// Indicates a conversion from a `Split1<Forward, Skip>` into a (Forward, Skip) tuple.
    #[derive(
        Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule,
    )]
    pub struct IntoTuple;

    /// Conversion of a tuple into a [`Split1`].
    impl<Forward, Skip> From<(Forward, Skip)> for Split1<Forward, Skip> {
        fn from(x: (Forward, Skip)) -> Self {
            Split1 {
                forward: x.0,
                skip: x.1,
            }
        }
    }
    /// Conversion of a [`Split1`] into a tuple.
    impl<Forward, Skip> From<Split1<Forward, Skip>> for (Forward, Skip) {
        fn from(x: Split1<Forward, Skip>) -> Self {
            (x.forward, x.skip)
        }
    }
    /// Module to convert a tuple into a [`Split1`].
    impl<Forward, Skip> crate::prelude::Module<(Forward, Skip)> for FromTuple {
        type Output = Split1<Forward, Skip>;
        fn try_forward(&self, x: (Forward, Skip)) -> Result<Self::Output, crate::prelude::Error> {
            Ok(x.into())
        }
    }
    /// Module to convert a [`Split1`] into a tuple.
    impl<Forward, Skip> crate::prelude::Module<Split1<Forward, Skip>> for IntoTuple {
        type Output = (Forward, Skip);
        fn try_forward(
            &self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            Ok(x.into())
        }
    }
    /// Module that accesses [`Split1::forward`] and then applies Module `M` to it.
    impl<M: crate::prelude::Module<Forward>, Forward, Skip>
        crate::prelude::Module<Split1<Forward, Skip>> for crate::prelude::On<forward, M>
    {
        type Output = Split1<<M as crate::prelude::Module<Forward>>::Output, Skip>;
        fn try_forward(
            &self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x0 = self.t.try_forward(x0)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
        fn try_forward_mut(
            &mut self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x0 = self.t.try_forward_mut(x0)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
    }
    /// Module that accesses [`Split1::skip`] and then applies Module `M` to it.
    impl<M: crate::prelude::Module<Skip>, Forward, Skip>
        crate::prelude::Module<Split1<Forward, Skip>> for crate::prelude::On<skip, M>
    {
        type Output = Split1<Forward, <M as crate::prelude::Module<Skip>>::Output>;
        fn try_forward(
            &self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x1 = self.t.try_forward(x1)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
        fn try_forward_mut(
            &mut self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x1 = self.t.try_forward_mut(x1)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
    }
}
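The generated items can be exercised without dfdx. The sketch below hand-writes a reduced version of what #[input_wrapper] emits for a Split<Forward, Skip> wrapper, on top of a simplified, hypothetical Module trait and a hypothetical On struct (the PR's real On lives in dfdx's prelude, and its definition isn't shown here; only the inner field name t is taken from the generated code). Double is a toy module invented for the example.

```rust
/// Hypothetical, simplified stand-in for dfdx's `Module` trait.
pub type Error = String;
pub trait Module<X> {
    type Output;
    fn try_forward(&self, x: X) -> Result<Self::Output, Error>;
}

/// Hypothetical `On`: applies the inner module `t` to the field selected by `F`.
#[derive(Debug, Default, Clone)]
pub struct On<F, M> {
    pub field: std::marker::PhantomData<F>,
    pub t: M,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

/// Hand-written equivalents of the items `#[input_wrapper]` generates.
#[allow(non_camel_case_types)]
pub mod split {
    use super::{Error, Module, On, Split};

    #[derive(Debug, Default, Clone)]
    pub struct forward;
    #[derive(Debug, Default, Clone)]
    pub struct skip;
    #[derive(Debug, Default, Clone)]
    pub struct FromTuple;
    #[derive(Debug, Default, Clone)]
    pub struct IntoTuple;

    impl<F, S> Module<(F, S)> for FromTuple {
        type Output = Split<F, S>;
        // Uses x.0 / x.1 instead of a `(forward, skip)` pattern: inside this module,
        // `forward` would resolve to the unit struct, not to a new binding.
        fn try_forward(&self, x: (F, S)) -> Result<Self::Output, Error> {
            Ok(Split { forward: x.0, skip: x.1 })
        }
    }

    impl<F, S> Module<Split<F, S>> for IntoTuple {
        type Output = (F, S);
        fn try_forward(&self, x: Split<F, S>) -> Result<Self::Output, Error> {
            Ok((x.forward, x.skip))
        }
    }

    // `On<forward, M>` transforms only `forward`; `skip` passes through untouched.
    impl<M: Module<F>, F, S> Module<Split<F, S>> for On<forward, M> {
        type Output = Split<M::Output, S>;
        fn try_forward(&self, x: Split<F, S>) -> Result<Self::Output, Error> {
            Ok(Split { forward: self.t.try_forward(x.forward)?, skip: x.skip })
        }
    }

    // `On<skip, M>` transforms only `skip`.
    impl<M: Module<S>, F, S> Module<Split<F, S>> for On<skip, M> {
        type Output = Split<F, M::Output>;
        fn try_forward(&self, x: Split<F, S>) -> Result<Self::Output, Error> {
            Ok(Split { forward: x.forward, skip: self.t.try_forward(x.skip)? })
        }
    }
}

/// Hypothetical toy module that doubles an `i32`.
#[derive(Debug, Default, Clone)]
pub struct Double;
impl Module<i32> for Double {
    type Output = i32;
    fn try_forward(&self, x: i32) -> Result<i32, Error> {
        Ok(x * 2)
    }
}

/// Chains FromTuple -> On<split::forward, Double> -> IntoTuple by hand,
/// roughly the way a Sequential model would run them in order.
pub fn pipeline(x: (i32, i32)) -> Result<(i32, i32), Error> {
    let wrapped = split::FromTuple.try_forward(x)?;
    let on: On<split::forward, Double> = On::default();
    let doubled = on.try_forward(wrapped)?;
    split::IntoTuple.try_forward(doubled)
}
```

With this, pipeline((3, 10)) yields Ok((6, 10)): only the forward field passes through Double, while skip is carried along unchanged. Because the field markers are distinct types, the On<forward, M> and On<skip, M> impls never overlap.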

Posted by swfsql on Nov 06 '23 01:11

To add more context, this is how I was able to define a U-Net:

(Note: in this case I was using a version of dfdx with some other local changes, especially experimental ones.)


#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

impl<Forward, Skip, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Split<Forward, Skip>
where
    (Forward, Skip): TryConcatTensorAlong<Axis<AXIS>>,
{
    type Output = <(Forward, Skip) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
    fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
        let (forward, skip): (Forward, Skip) = self.into();
        (forward, skip).try_concat_tensor_along(ax)
    }
}

/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(ConvBlock)]
pub struct ConvBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
    pub conv_1: Conv2DConstConfig<CH_IN, CH_OUT, 3, 1, 1>,
    pub norm_1: BatchNorm2DConstConfig<CH_OUT>,
    pub a_1: ReLU,
    //
    pub conv_2: Conv2DConstConfig<CH_OUT, CH_OUT, 3, 1, 1>,
    pub norm_2: BatchNorm2DConstConfig<CH_OUT>,
    pub a_2: ReLU,
}

/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// Split {
///     forward: batch * CH_OUT * height/2 * width/2,
///     skip: batch * CH_OUT * height * width,
/// }
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(DownBlock)]
pub struct DownBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
    pub conv: ConvBlockConfig<CH_IN, CH_OUT>,
    //
    pub split: SplitInto<(Id, Id)>,
    pub wrapper: split::FromTuple,
    pub pool: On<split::forward, MaxPool2DConst<2, 2, 0>>,
}

/// From:
/// ```ignore
/// Split {
///     forward: batch * CH_INF * height/2 * width/2,
///     skip: batch * CH_INS * height * width,
/// }
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
///
/// Notes:
/// - `CH_INF` refers to the #channels from [`Split::forward`].
/// - `CH_INS` refers to the #channels from [`Split::skip`], but this parameter is not directly passed to this structure.
/// - `CH_CONCAT` is supposed to be `CH_OUT + CH_INS`.
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(UpBlock)]
pub struct UpBlockConfig<const CH_INF: usize, const CH_OUT: usize, const CH_CONCAT: usize> {
    // for keras' padding='same', the PADDING value must be set to:
    // ((kernel-1) * dilation + 1) // 2 = ((3-1) * 1 + 1) // 2 = 1
    pub conv_trans:
        On<split::forward, ConvTrans2DConstConfig<CH_INF, CH_OUT, 3, 2, 1, 1, 1, 1>>,
    pub bias: On<split::forward, Bias2DConstConfig<CH_OUT>>,
    // concat "forward" and "skip" (in that order) along channels
    pub tuple: split::IntoTuple,
    pub concat: ops::ConcatTensorAlong<Axis<1>>,
    pub conv: ConvBlockConfig<CH_CONCAT, CH_OUT>,
}

/// Just applies `M`.
type Onc0<M> = M;
/// Access `F` and then applies `M`.
type Onc1<F, M> = On<F, M>;
/// Access `F` consecutively 2 times and then applies `M`.
type Onc2<F, M> = On<F, On<F, M>>;
/// Access `F` consecutively 3 times and then applies `M`.
type Onc3<F, M> = On<F, Onc2<F, M>>;
/// Access `F` consecutively 4 times and then applies `M`.
type Onc4<F, M> = On<F, Onc3<F, M>>;
/// Access `F` consecutively 5 times and then applies `M`.
type Onc5<F, M> = On<F, Onc4<F, M>>;

#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Model)]
pub struct ModelConfig {
    // encoder
    pub down_block_0: Onc0<DownBlockConfig<3, 32>>,
    pub down_block_1: Onc1<split::forward, DownBlockConfig<32, 64>>,
    pub down_block_2: Onc2<split::forward, DownBlockConfig<64, 128>>,
    pub down_block_3: Onc3<split::forward, DownBlockConfig<128, 256>>,

    // bottleneck
    // note: this increases the channels but reduces neither the height nor the width
    pub conv_bottle: Onc4<split::forward, ConvBlockConfig<256, 512>>,

    // decoder
    pub up_block_4: Onc3<split::forward, UpBlockConfig<512, 256, 512>>,
    pub up_block_3: Onc2<split::forward, UpBlockConfig<256, 128, 256>>,
    pub up_block_2: Onc1<split::forward, UpBlockConfig<128, 64, 128>>,
    pub up_block_1: Onc0<UpBlockConfig<64, 32, 64>>,

    // yclass channel conversion
    pub conv_2: Conv2DConstConfig<32, 13, 1, 1, 0>,
    pub bias_2: Bias2DConstConfig<13>,
}
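The U-Net above only lines up if every skip connection meets a forward tensor of the same height, and if each declared CH_CONCAT equals CH_OUT plus the skip's channel count. The sketch below walks those sizes with plain integer arithmetic, using the standard convolution and transposed-convolution output-size formulas. Two assumptions: the input height is a hypothetical 64, and the trailing 1 in ConvTrans2DConstConfig<CH_INF, CH_OUT, 3, 2, 1, 1, 1, 1> is read as an output padding of 1 (upstream dfdx's config takes fewer parameters, so this is a guess about the author's local changes). The helper names are hypothetical.

```rust
/// Padding formula from the comment above (Keras' padding='same', stride 1):
/// ((kernel - 1) * dilation + 1) / 2 with integer division.
pub fn same_padding(kernel: i64, dilation: i64) -> i64 {
    ((kernel - 1) * dilation + 1) / 2
}

/// Output length of a convolution (or max-pool) along one spatial axis.
pub fn conv_out(input: i64, kernel: i64, stride: i64, padding: i64, dilation: i64) -> i64 {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

/// Output length of a transposed convolution along one spatial axis.
pub fn conv_transpose_out(
    input: i64,
    kernel: i64,
    stride: i64,
    padding: i64,
    dilation: i64,
    output_padding: i64,
) -> i64 {
    (input - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
}

/// Walks (channels, height) through ModelConfig and checks every skip connection.
/// Returns the final (channels, height).
pub fn walk_unet(height: i64) -> Result<(i64, i64), String> {
    // (CH_IN, CH_OUT) of down_block_0..=3
    let down: [(i64, i64); 4] = [(3, 32), (32, 64), (64, 128), (128, 256)];
    // (CH_INF, CH_OUT, CH_CONCAT) of up_block_4..=1
    let up: [(i64, i64, i64); 4] =
        [(512, 256, 512), (256, 128, 256), (128, 64, 128), (64, 32, 64)];
    let pad = same_padding(3, 1);
    let (mut ch, mut h) = (3, height);
    let mut skips = Vec::new(); // stack of (channels, height) held in `skip` fields

    for (ch_in, ch_out) in down {
        if ch != ch_in {
            return Err(format!("down block expects {ch_in} channels, got {ch}"));
        }
        // ConvBlock: two 3x3 stride-1 convs with 'same' padding keep the height.
        h = conv_out(conv_out(h, 3, 1, pad, 1), 3, 1, pad, 1);
        ch = ch_out;
        skips.push((ch, h)); // SplitInto<(Id, Id)> copies the tensor into `skip`
        h = conv_out(h, 2, 2, 0, 1); // MaxPool2DConst<2, 2, 0> on `forward`
    }
    ch = 512; // bottleneck ConvBlock<256, 512> keeps the height

    for (ch_inf, ch_out, ch_concat) in up {
        let (skip_ch, skip_h) = skips.pop().ok_or("missing skip connection")?;
        if ch != ch_inf {
            return Err(format!("up block expects {ch_inf} channels, got {ch}"));
        }
        // Assumed: kernel 3, stride 2, padding 1, dilation 1, output padding 1.
        h = conv_transpose_out(h, 3, 2, 1, 1, 1);
        if h != skip_h {
            return Err(format!("forward height {h} does not match skip height {skip_h}"));
        }
        if ch_out + skip_ch != ch_concat {
            return Err(format!("CH_CONCAT {ch_concat} != {ch_out} + {skip_ch}"));
        }
        ch = ch_out; // ConvBlock<CH_CONCAT, CH_OUT> keeps the height
    }
    if ch != 32 {
        return Err(format!("final conv expects 32 channels, got {ch}"));
    }
    // Final 1x1 conv: 32 -> 13 channels, height unchanged.
    Ok((13, conv_out(h, 1, 1, 0, 1)))
}
```

Under these assumptions, walk_unet(64) returns Ok((13, 64)), while walk_unet(60) fails at the first up block (a forward height of 6 meets a skip of height 7), which suggests the input height should be a multiple of 16 for the four pooling stages.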

And this is how I defined and trained a simple RNN (based on this exercise). The trained model was able to generate dino-like names.


pub mod model {
    use super::*;

    #[input_wrapper]
    #[derive(Clone, Debug)]
    pub struct Input<A, X> {
        pub a_prev: A,
        pub x: X,
    }

    impl<A, X, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Input<A, X>
    where
        (A, X): TryConcatTensorAlong<Axis<AXIS>>,
    {
        type Output = <(A, X) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
        fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
            (self.a_prev, self.x).try_concat_tensor_along(ax)
        }
    }

    #[input_wrapper]
    #[derive(Clone, Debug)]
    pub struct Output<A, Y> {
        pub a: A,
        pub y: Y,
    }

    impl<AS: Shape, YS: Shape, E: Dtype, D: Device<E>>
        Output<Tensor<AS, E, D, OwnedTape<E, D>>, Tensor<YS, E, D, OwnedTape<E, D>>>
    {
        pub fn merge_tapes_on_y(self) -> Self {
            let (a, at) = self.a.split_tape();
            let (y, yt) = self.y.split_tape();
            Self {
                a: a.leaky_traced(),
                y: y.put_tape(at.merge(yt)),
            }
        }
    }

    //
    /// Input:
    /// ```ignore
    /// Input {
    ///     a_prev: A,
    ///     x: X,
    /// }
    /// ```
    ///
    /// Output:
    /// ```ignore
    /// Output {
    ///     a: A,
    ///     y: Y,
    /// }
    /// ```
    #[derive(Debug, Clone, Default, dfdx::Sequential)]
    #[built(Cell)]
    pub struct CellConstConfig<
        const NA: usize,
        const NX: usize,
        const NY: usize,
        const CONCAT_AXIS: isize,
        const NAPNX: usize = { NA + NX },
    > {
        // doing concat(a_prev, x) dot concat(wa^t, wb^t) + b is the same as
        // doing a_prev dot wa^t + x dot wb^t + b
        //
        // pub amul: On<input::a_prev, MatMulConstConfig<NA, NA>>,
        // pub xmul: On<input::x, MatMulConstConfig<NX, NA>>,
        // pub ax_tuple: input::IntoTuple,
        // pub ax_add: ops::Add,
        // pub bias: Bias1DConstConfig<NA>,
        //
        pub concat_input: ops::ConcatTensorAlong<Axis<CONCAT_AXIS>>,
        pub ax_linear: LinearConstConfig<NAPNX, NA>,
        //
        pub g1: Tanh,
        pub ay_tuple: SplitInto<(Id, Id)>,
        pub ay: output::FromTuple,
        pub ylinear: On<output::y, LinearConstConfig<NA, NY>>,
    }
}

#[test]
fn test_rnn() -> anyhow::Result<()> {
    let dev = Cuda::try_build(0, 0)?;

    const T_: usize = 2;
    const BATCH: usize = 3;
    // for unbatched (1D) tensors, the concat axis is 0
    // for batched (2D) tensors, the concat axis is 1
    const CONCAT_AXIS: isize = 1;
    const NA: usize = 2;
    const NX: usize = 3;
    const NY: usize = 3;
    type XT<T = NoneTape> = Tensor<Rank2<BATCH, NX>, f32, Device_, T>;
    type AT<T = NoneTape> = Tensor<Rank2<BATCH, NA>, f32, Device_, T>;
    type YT<T = NoneTape> = Tensor<Rank2<BATCH, NY>, f32, Device_, T>;

    let mut model =
        dev.build_module::<f32>(model::CellConstConfig::<NA, NX, NY, CONCAT_AXIS>::default());
    let mut grads = model.alloc_grads();

    let mut opt = dfdx::prelude::optim::Adam::new(
        &model,
        AdamConfig {
            lr: 1e-4,
            ..Default::default() // weight_decay: Some(dfdx::nn::optim::WeightDecay::L2(0.001)),
        },
    );

    const EPOCHS: usize = 2;
    for e in 0..EPOCHS {
        let a_prev: AT = dev.zeros();
        let mut x: XT = dev.zeros();

        let mut a_prev_t: AT<_> = a_prev.leaky_traced();

        let mut batch_loss = 0f32;

        for _t in 0..T_ {
            let y_t: YT = dev.sample_uniform();

            let x_t: XT<OwnedTape<f32, Device_>> = x.traced(grads);

            let input = model::Input {
                a_prev: a_prev_t,
                x: x_t,
            };

            let prediction = model.forward_mut(input);
            let prediction = prediction.merge_tapes_on_y();
            let loss_t =
                dfdx::losses::cross_entropy_with_logits_loss(prediction.y, y_t.clone());
            batch_loss += loss_t.array();

            // Note:
            // Running backprop and model update for each timestep t.
            // A different approach would be to run backprop at the last timestep and update once.
            // Or yet do something in between.
            grads = loss_t.backward();
            opt.update(&mut model, &grads).unwrap();

            x = y_t;
            a_prev_t = prediction.a;
        }
        println!("epoch: {}; loss: {}", e, batch_loss);
        // grads.drop_non_leafs();
        model.zero_grads(&mut grads);
    }

    Ok(())
}
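The comment in CellConstConfig says that multiplying concat(a_prev, x) by the full weight gives the same result as the two separate products with the column blocks Wa and Wx. The sketch below checks this numerically for a single unbatched sample in plain Rust (no dfdx), then runs one hypothetical cell step, a = tanh(W·[a_prev; x] + b) followed by y = Wy·a + by, mirroring the concat_input → ax_linear → g1 → ylinear chain. The [out][in] weight layout, the function names, and all numeric values are assumptions for illustration.

```rust
/// Computes v · Wᵀ for a weight stored as rows `w[out][in]` (linear-layer layout).
pub fn matvec_t(v: &[f64], w: &[Vec<f64>]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(v).map(|(wk, vk)| wk * vk).sum::<f64>())
        .collect()
}

/// Concatenated form: [a_prev, x] · Wᵀ + b, with W of shape [NA][NA + NX].
pub fn cell_concat(a_prev: &[f64], x: &[f64], w: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let ax: Vec<f64> = a_prev.iter().chain(x).copied().collect();
    matvec_t(&ax, w).iter().zip(b).map(|(s, bi)| s + bi).collect()
}

/// Split form: a_prev · Waᵀ + x · Wxᵀ + b, where Wa / Wx are the first NA /
/// last NX columns of W.
pub fn cell_split(a_prev: &[f64], x: &[f64], w: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let na = a_prev.len();
    let wa: Vec<Vec<f64>> = w.iter().map(|row| row[..na].to_vec()).collect();
    let wx: Vec<Vec<f64>> = w.iter().map(|row| row[na..].to_vec()).collect();
    let (pa, px) = (matvec_t(a_prev, &wa), matvec_t(x, &wx));
    pa.iter().zip(&px).zip(b).map(|((p, q), bi)| p + q + bi).collect()
}

/// One cell step: new hidden state `a` (tanh) and output logits `y`.
pub fn rnn_step(
    a_prev: &[f64],
    x: &[f64],
    w: &[Vec<f64>],
    b: &[f64],
    wy: &[Vec<f64>],
    by: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let a: Vec<f64> = cell_concat(a_prev, x, w, b).into_iter().map(f64::tanh).collect();
    let y: Vec<f64> = matvec_t(&a, wy).iter().zip(by).map(|(s, bi)| s + bi).collect();
    (a, y)
}
```

The two forms differ only in summation order, so they agree up to floating-point rounding; the concatenated form needs a single matmul, which is why the config uses concat_input followed by one ax_linear instead of the commented-out amul/xmul/ax_add variant.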

Posted by swfsql on Nov 25 '23 04:11

I'll prioritize moving this experiment to a separate crate, but feel free to ping me if anyone has questions or suggestions.

Posted by swfsql on Mar 01 '24 15:03